Unverified commit f03a1738, authored by myamlak and committed by GitHub

Resolution of issue #153: Add compiler warning on comparing int and size_t (#212)



* Turning compare warnings on

* Cleaning part I

* Cleaning part II

* Explicit static_cast to ck::type_convert

* Resolving large tensor size issue.

* format

* revert change to tensor descriptor; promote ElementSpaceSize to 64-bit

* use integer value for GEMM test

* Review remarks

* Review remarks + issues with (un)signed arithmetic

* Format fix

* Format

* Clang-format.

* fix 2gb limit issue
Co-authored-by: Chao Liu <chao.liu2@amd.com>
Co-authored-by: Adam Osewski <aosewski@amd.com>
parent 968bd932
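The pattern this commit targets is the `-Wsign-compare` family of warnings: signed `int` loop counters and indices compared against unsigned values such as `std::vector::size()` or tensor lengths. Below is a minimal sketch of the warning and of the two remedies used throughout the diff, a `std::size_t` counter or a one-time conversion of the unsigned length to the signed index type (the repository uses `ck::type_convert` where this sketch uses `static_cast`); the `values` vector and `count_negative` helper are illustrative only, not code from the repository.

```cpp
#include <cstddef>
#include <vector>

// Illustrative helper: count the negative entries of a vector.
int count_negative(const std::vector<int>& values)
{
    int count = 0;

    // for(int i = 0; i < values.size(); ++i)  // warns under -Wsign-compare:
    //                                         // 'i' is converted to an unsigned type

    // Remedy 1 (used in the example/profiler loops): a std::size_t counter.
    for(std::size_t i = 0; i < values.size(); ++i)
    {
        if(values[i] < 0)
            ++count;
    }

    // Remedy 2 (used in the host verification code): convert the unsigned
    // length to the signed index type once, then compare like with like.
    const int n = static_cast<int>(values.size());
    for(int i = 0; i < n; ++i)
    {
        // a body that genuinely needs a signed index would go here
    }

    return count;
}
```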
@@ -66,7 +66,7 @@ else()
         -Wunreachable-code
         -Wunused
-        -Wno-sign-compare
+        -Wsign-compare
         -Wno-extra-semi-stmt
     )
     if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang")
@@ -140,7 +140,7 @@ class SimpleAppArgs
     int processArgs(int argc, char* argv[])
     {
-        unsigned int ch;
+        int ch;
         while(1)
         {
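For context on the `unsigned int ch` to `int ch` change above: `getopt_long` returns an `int` and signals the end of the options with `-1`, so an `unsigned int` result both trips `-Wsign-compare` and makes the end-of-options test suspect. A minimal sketch under that assumption (the option table here is illustrative, not the profiler's actual argument list):

```cpp
#include <getopt.h>
#include <cstdio>

int main(int argc, char* argv[])
{
    // getopt_long() returns int: an option character, or -1 when all
    // options have been consumed, so 'ch' must be signed.
    int ch;

    static struct option long_options[] = {{"verbose", no_argument, nullptr, 'v'},
                                           {nullptr, 0, nullptr, 0}};

    while((ch = getopt_long(argc, argv, "v", long_options, nullptr)) != -1)
    {
        if(ch == 'v')
            std::printf("verbose output enabled\n");
    }
    return 0;
}
```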
@@ -80,8 +80,8 @@ static void pool_host_verify(const Tensor<InDataType>& in,
                 for(int x = 0; x < window_spatial_lengths[1]; ++x)
                 {
                     int wi = wo * window_strides[1] + x - in_left_pads[1];
-                    if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
-                       wi < in.mDesc.GetLengths()[3])
+                    if(hi >= 0 && hi < ck::type_convert<int>(in.mDesc.GetLengths()[2]) && wi >= 0 &&
+                       wi < ck::type_convert<int>(in.mDesc.GetLengths()[3]))
                     {
                         AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
@@ -131,7 +131,7 @@ int main(int argc, char* argv[])
     std::size_t flop = 0, num_btype = 0;
-    for(int i = 0; i < gemm_shapes.size(); i++)
+    for(std::size_t i = 0; i < gemm_shapes.size(); i++)
     {
         a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
             gemm_shapes[i].M, gemm_shapes[i].K, gemm_shapes[i].StrideA, ALayout{})));
@@ -168,7 +168,7 @@ int main(int argc, char* argv[])
         }
     }
-    for(int i = 0; i < gemm_shapes.size(); i++)
+    for(std::size_t i = 0; i < gemm_shapes.size(); i++)
     {
         a_tensors_device.emplace_back(
             std::make_unique<DeviceMem>(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpace()));
@@ -213,7 +213,7 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
-        for(int i = 0; i < gemm_shapes.size(); i++)
+        for(std::size_t i = 0; i < gemm_shapes.size(); i++)
         {
             c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data());
             auto ref_gemm = ReferenceGemmInstance{};
-#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP
-#define CK_TENSOR_DESCRIPTOR_HELPER_HPP
+#pragma once
 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
 #include "multi_index_transform_helper.hpp"
@@ -35,6 +33,12 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt
 }
 #endif
+// Lengths..., Strides... could be:
+// 1) index_t, which is known at run-time, or
+// 2) Number<>, which is known at compile-time
+// element_space_size could be:
+// 1) long_index_t, or
+// 2) LongNumber<>
 template <typename... Lengths,
           typename... Strides,
           typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
@@ -68,10 +72,10 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Leng
         }
     };
-    const auto element_space_size = f(f, Number<0>{}, Number<1>{});
+    const auto element_space_size = f(f, Number<0>{}, LongNumber<1>{});
 #else
     const auto element_space_size =
-        calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{});
+        calculate_element_space_size_impl(lengths, strides, Number<0>{}, LongNumber<1>{});
 #endif
     return TensorDescriptor<remove_cv_t<decltype(transforms)>,
@@ -82,9 +86,12 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Leng
                             element_space_size};
 }
-// Lengths... can be:
-// 1) index_t, which is known at run-time
+// Lengths... could be:
+// 1) index_t, which is known at run-time, or
 // 2) Number<>, which is known at compile-time
+// element_space_size could be:
+// 1) long_index_t, or
+// 2) LongNumber<>
 template <typename... Lengths>
 __host__ __device__ constexpr auto
 make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
@@ -100,7 +107,7 @@ make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
     constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
-    const auto element_space_size = container_reduce(lengths, math::multiplies{}, Number<1>{});
+    const auto element_space_size = container_reduce(lengths, math::multiplies{}, LongNumber<1>{});
     return TensorDescriptor<remove_cv_t<decltype(transforms)>,
                             remove_cv_t<decltype(low_dim_hidden_idss)>,
@@ -110,6 +117,12 @@ make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
                             element_space_size};
 }
+// Lengths... could be:
+// 1) index_t, which is known at run-time, or
+// 2) Number<>, which is known at compile-time
+// align could be:
+// 1) index_t, or
+// 2) Number<>
 template <typename... Lengths, typename Align>
 __host__ __device__ constexpr auto
 make_naive_tensor_descriptor_aligned(const Tuple<Lengths...>& lengths, Align align)
@@ -146,4 +159,3 @@ make_naive_tensor_descriptor_aligned(const Tuple<Lengths...>& lengths, Align ali
 }
 } // namespace ck
-#endif
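The `LongNumber<1>{}` seeds above mean the accumulated `element_space_size` is carried as a 64-bit `long_index_t` (or `LongNumber<>`), so the product of lengths and strides no longer wraps for tensors with 2^31 or more elements. A self-contained sketch of the arithmetic, using `int32_t`/`int64_t` stand-ins for `ck::index_t`/`ck::long_index_t`:

```cpp
#include <cstdint>
#include <cstdio>
#include <limits>

using index_t      = std::int32_t; // stand-in for ck::index_t
using long_index_t = std::int64_t; // stand-in for ck::long_index_t

int main()
{
    // A packed 4 x 512 x 1024 x 1024 tensor holds exactly 2^31 elements,
    // one more than a signed 32-bit index can represent.
    const index_t lengths[] = {4, 512, 1024, 1024};

    long_index_t element_space_size = 1; // 64-bit accumulator, like LongNumber<1>{}
    for(index_t l : lengths)
        element_space_size *= l;

    const bool fits_in_32_bit = element_space_size <= std::numeric_limits<index_t>::max();
    std::printf("elements: %lld, fits in index_t: %s\n",
                static_cast<long long>(element_space_size),
                fits_in_32_bit ? "yes" : "no");
    return 0;
}
```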
@@ -635,11 +635,12 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
               d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
               c_grid_desc_mblock_mperblock_nblock_nperblock_{},
               d_grid_desc_mblock_mperblock_{},
-              compute_base_ptr_of_batch_{a_grid_desc_ak0_m_ak1_.GetElementSpaceSize(),
-                                         b_grid_desc_bk0_n_bk1_.GetElementSpaceSize(),
-                                         c_grid_desc_m_n_.GetElementSpaceSize(),
-                                         d_grid_desc_m_.GetElementSpaceSize(),
-                                         d_grid_desc_m_.GetElementSpaceSize()},
+              compute_base_ptr_of_batch_{
+                  type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()),
+                  type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
+                  type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize()),
+                  type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize()),
+                  type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize())},
               block_2_ctile_map_{},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
@@ -384,9 +384,10 @@ struct DeviceBatchedGemmXdl
                   DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
               c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
               c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
-              compute_ptr_offset_of_batch_{a_grid_desc_k0_m_k1_.GetElementSpaceSize(),
-                                           b_grid_desc_k0_n_k1_.GetElementSpaceSize(),
-                                           c_grid_desc_m_n_.GetElementSpaceSize()},
+              compute_ptr_offset_of_batch_{
+                  type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
+                  type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
+                  type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())},
               block_2_ctile_map_{},
               M01_{M01},
               N01_{N01},
@@ -697,7 +697,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
         }
         // Gridwise GEMM size
-        for(int i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
+        for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
         {
             if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
                                             arg.b_grid_desc_k0_n_k1_container_[i],
@@ -1412,7 +1412,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
         }
         // Gridwise GEMM size
-        for(int i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
+        for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
         {
             if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
                                             arg.b_grid_desc_k0_n_k1_container_[i],
@@ -861,17 +861,11 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
     static bool IsSupportedArgument(const Argument& arg)
     {
         // Input tensors can't be bigger than 2GB each.
-        constexpr std::size_t GB2 = 2 * 1e9;
-        if(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() > GB2)
-        {
-            return false;
-        }
-        if(arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() > GB2)
-        {
-            return false;
-        }
-        if(arg.c_grid_desc_m_n_.GetElementSpaceSize() > GB2)
+        constexpr ck::long_index_t GB2 = (ck::long_index_t{1} << 31);
+        if(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) > GB2 ||
+           arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) > GB2 ||
+           arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) > GB2)
         {
             return false;
         }
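The rewritten check above compares byte sizes, not element counts: each tensor's `GetElementSpaceSize()` (now a 64-bit value) is multiplied by the element size and tested against 2^31 bytes, and `GB2` is defined as a power of two rather than `2 * 1e9`. A hedged sketch of the same logic with a hypothetical `fits_in_2gb` helper:

```cpp
#include <cstdint>
#include <cstdio>

using long_index_t = std::int64_t; // stand-in for ck::long_index_t

// Hypothetical helper mirroring the IsSupportedArgument() check above:
// reject a tensor whose byte size exceeds 2 GiB.
bool fits_in_2gb(long_index_t element_space_size, long_index_t bytes_per_element)
{
    constexpr long_index_t GB2 = long_index_t{1} << 31; // 2 GiB, not 2 * 1e9
    return element_space_size * bytes_per_element <= GB2;
}

int main()
{
    const long_index_t elements_3gb = 3 * (long_index_t{1} << 29); // 1.5 Gi elements
    const long_index_t elements_2gb = long_index_t{1} << 30;       // 1 Gi elements

    std::printf("%d\n", fits_in_2gb(elements_3gb, 2) ? 1 : 0); // fp16 tensor of 3 GiB -> 0
    std::printf("%d\n", fits_in_2gb(elements_2gb, 2) ? 1 : 0); // fp16 tensor of 2 GiB -> 1
    return 0;
}
```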
@@ -372,17 +372,18 @@ struct DeviceGroupedGemmXdl
         {
             grid_size_ = 0;
-            group_count_ = static_cast<int>(gemm_shapes.size());
-            if(!(group_count_ == p_a.size() && group_count_ == p_b.size() &&
-                 group_count_ == p_c.size()))
+            group_count_ = ck::type_convert<ck::index_t>(gemm_shapes.size());
+            if(!(group_count_ == ck::type_convert<ck::index_t>(p_a.size()) &&
+                 group_count_ == ck::type_convert<ck::index_t>(p_b.size()) &&
+                 group_count_ == ck::type_convert<ck::index_t>(p_c.size())))
             {
                 throw std::runtime_error("wrong! group_count_ != P_a/b/c.size");
             }
             gemm_desc_kernel_arg_.reserve(group_count_);
-            for(index_t i = 0; i < gemm_shapes.size(); i++)
+            for(std::size_t i = 0; i < gemm_shapes.size(); i++)
             {
                 const index_t M = gemm_shapes[i].M;
                 const index_t N = gemm_shapes[i].N;
@@ -563,7 +564,7 @@ struct DeviceGroupedGemmXdl
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(arg.gemm_desc_kernel_arg_.size() != arg.group_count_)
+        if(ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_)
             return false;
         else
             return true;
@@ -8,5 +8,8 @@ namespace ck {
 template <index_t N>
 using Number = integral_constant<index_t, N>;
+template <index_t N>
+using LongNumber = integral_constant<long_index_t, N>;
 } // namespace ck
 #endif
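`LongNumber<>` is the 64-bit counterpart of `Number<>`: an `integral_constant` over `long_index_t`, so a compile-time element space size keeps its full 64-bit value. A sketch of how such an alias behaves, built on `std::integral_constant` and assuming the non-type parameter is itself a `long_index_t`:

```cpp
#include <cstdint>
#include <type_traits>

using index_t      = std::int32_t; // stand-in for ck::index_t
using long_index_t = std::int64_t; // stand-in for ck::long_index_t

template <index_t N>
using Number = std::integral_constant<index_t, N>;

template <long_index_t N>
using LongNumber = std::integral_constant<long_index_t, N>;

int main()
{
    // The value lives in the type, so the arithmetic below happens at
    // compile time and, for LongNumber, in 64 bits.
    constexpr auto n  = Number<1024>{};
    constexpr auto ln = LongNumber<(long_index_t{1} << 32)>{};

    static_assert(std::is_same_v<LongNumber<0>::value_type, long_index_t>, "64-bit constant");
    static_assert(ln.value * n.value == (long_index_t{1} << 42), "compile-time product");
    return 0;
}
```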
@@ -158,5 +158,11 @@ __host__ __device__ constexpr auto make_static_buffer(Number<N>)
     return StaticBuffer<AddressSpace, T, N, true>{};
 }
+template <AddressSpaceEnum AddressSpace, typename T, long_index_t N>
+__host__ __device__ constexpr auto make_static_buffer(LongNumber<N>)
+{
+    return StaticBuffer<AddressSpace, T, N, true>{};
+}
 } // namespace ck
 #endif
@@ -211,7 +211,8 @@ struct ReductionHost
             AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
             IndexDataType accuIndex = 0;
-            for(IndexDataType i = 0; i < reduce_dim_indexes.size(); i++)
+            for(IndexDataType i = 0; i < ck::type_convert<IndexDataType>(reduce_dim_indexes.size());
+                i++)
             {
                 auto offset_reduce =
                     get_offset_from_index<NumReduceDim>(reduceStrides, reduce_dim_indexes[i]);
@@ -246,7 +247,9 @@ struct ReductionHost
                 auto offset_invariant =
                     get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index);
-                for(IndexDataType i = 0; i < reduce_dim_indexes.size(); i++)
+                for(IndexDataType i = 0;
+                    i < ck::type_convert<IndexDataType>(reduce_dim_indexes.size());
+                    i++)
                 {
                     auto offset_reduce =
                         get_offset_from_index<NumReduceDim>(reduceStrides, reduce_dim_indexes[i]);
@@ -154,7 +154,7 @@ struct ParallelTensorFunctor
     {
         std::array<std::size_t, NDIM> indices;
-        for(int idim = 0; idim < NDIM; ++idim)
+        for(std::size_t idim = 0; idim < NDIM; ++idim)
         {
             indices[idim] = i / mStrides[idim];
             i -= indices[idim] * mStrides[idim];
@@ -316,7 +316,7 @@ float check_error(const Tensor<T>& ref, const Tensor<T>& result)
     constexpr float eps = 1e-10;
-    for(int i = 0; i < ref.mData.size(); ++i)
+    for(std::size_t i = 0; i < ref.mData.size(); ++i)
     {
         float ref_v = ck::type_convert<float>(ref.mData[i]);
         float result_v = ck::type_convert<float>(result.mData[i]);
@@ -70,18 +70,25 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
            constexpr auto I1 = Number<1>{};
            auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
                float v_acc = 0;
-               for(int n = 0; n < arg.out_n_k_ho_wo_.mDesc.GetLengths()[0]; ++n)
+               for(std::size_t n = 0; n < arg.out_n_k_ho_wo_.mDesc.GetLengths()[0]; ++n)
                {
-                   for(int ho = 0; ho < arg.out_n_k_ho_wo_.mDesc.GetLengths()[2]; ++ho)
+                   for(std::size_t ho = 0; ho < arg.out_n_k_ho_wo_.mDesc.GetLengths()[2]; ++ho)
                    {
-                       int hi = ho * arg.conv_strides_[I0] + y * arg.conv_dilations_[I0] -
-                                arg.in_left_pads_[I0];
-                       for(int wo = 0; wo < arg.out_n_k_ho_wo_.mDesc.GetLengths()[3]; ++wo)
+                       auto hi = ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I0]) +
+                                 ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[I0]) -
+                                 ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
+                       for(std::size_t wo = 0; wo < arg.out_n_k_ho_wo_.mDesc.GetLengths()[3]; ++wo)
                        {
-                           int wi = wo * arg.conv_strides_[I1] + x * arg.conv_dilations_[I1] -
-                                    arg.in_left_pads_[I1];
-                           if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 &&
-                              wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
+                           auto wi =
+                               ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I1]) +
+                               ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[I1]) -
+                               ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]);
+                           if(hi >= 0 &&
+                              ck::type_convert<std::size_t>(hi) <
+                                  arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] &&
+                              wi >= 0 &&
+                              ck::type_convert<std::size_t>(wi) <
+                                  arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
                            {
                                float v_out;
                                float v_in;
@@ -78,15 +78,18 @@ struct ReferenceConvBwdData : public device::BaseOperator
                    AccDataType v_acc = 0;
-                   for(int x = 0; x < X; ++x)
+                   for(std::size_t x = 0; x < X; ++x)
                    {
-                       int w_tmp = wi + arg.in_left_pads_[0] - x * arg.conv_dilations_[0];
+                       auto w_tmp = ck::type_convert<ck::long_index_t>(wi) +
+                                    ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
+                                    ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]);
                        if(w_tmp % arg.conv_strides_[0] == 0)
                        {
-                           int wo = w_tmp / arg.conv_strides_[0];
-                           if(wo >= 0 && wo < Wo)
+                           auto wo = ck::type_convert<ck::long_index_t>(w_tmp) /
+                                     ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
+                           if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
                            {
-                               for(int k = 0; k < K; ++k)
+                               for(std::size_t k = 0; k < K; ++k)
                                {
                                    AccDataType v_out = 0;
                                    AccDataType v_wei = 0;
@@ -128,24 +131,32 @@ struct ReferenceConvBwdData : public device::BaseOperator
                    AccDataType v_acc = 0;
-                   for(int y = 0; y < Y; ++y)
+                   for(std::size_t y = 0; y < Y; ++y)
                    {
-                       int h_tmp = hi + arg.in_left_pads_[0] - y * arg.conv_dilations_[0];
+                       auto h_tmp = ck::type_convert<ck::long_index_t>(hi) +
+                                    ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
+                                    ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]);
                        if(h_tmp % arg.conv_strides_[0] == 0)
                        {
-                           int ho = h_tmp / arg.conv_strides_[0];
-                           if(ho >= 0 && ho < Ho)
+                           auto ho = ck::type_convert<ck::long_index_t>(h_tmp) /
+                                     ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
+                           if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
                            {
-                               for(int x = 0; x < X; ++x)
+                               for(std::size_t x = 0; x < X; ++x)
                                {
-                                   int w_tmp =
-                                       wi + arg.in_left_pads_[1] - x * arg.conv_dilations_[1];
+                                   auto w_tmp =
+                                       ck::type_convert<ck::long_index_t>(wi) +
+                                       ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) -
+                                       ck::type_convert<ck::long_index_t>(x *
+                                                                          arg.conv_dilations_[1]);
                                    if(w_tmp % arg.conv_strides_[1] == 0)
                                    {
-                                       int wo = w_tmp / arg.conv_strides_[1];
-                                       if(wo >= 0 && wo < Wo)
+                                       auto wo = ck::type_convert<ck::long_index_t>(w_tmp) /
+                                                 ck::type_convert<ck::long_index_t>(
+                                                     arg.conv_strides_[1]);
+                                       if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
                                        {
-                                           for(int k = 0; k < K; ++k)
+                                           for(std::size_t k = 0; k < K; ++k)
                                            {
                                                AccDataType v_out = 0;
                                                AccDataType v_wei = 0;
@@ -194,33 +205,49 @@ struct ReferenceConvBwdData : public device::BaseOperator
                    AccDataType v_acc = 0;
-                   for(int z = 0; z < Z; ++z)
+                   for(std::size_t z = 0; z < Z; ++z)
                    {
-                       int d_tmp = di + arg.in_left_pads_[0] - z * arg.conv_dilations_[0];
+                       auto d_tmp = ck::type_convert<ck::long_index_t>(di) +
+                                    ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
+                                    ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]);
                        if(d_tmp % arg.conv_strides_[0] == 0)
                        {
-                           int do_ = d_tmp / arg.conv_strides_[0];
-                           if(do_ >= 0 && do_ < Do)
+                           auto do_ = ck::type_convert<ck::long_index_t>(d_tmp) /
+                                      ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
+                           if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do)
                            {
-                               for(int y = 0; y < Y; ++y)
+                               for(std::size_t y = 0; y < Y; ++y)
                                {
-                                   int h_tmp =
-                                       hi + arg.in_left_pads_[1] - y * arg.conv_dilations_[1];
+                                   auto h_tmp =
+                                       ck::type_convert<ck::long_index_t>(hi) +
+                                       ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) -
+                                       ck::type_convert<ck::long_index_t>(y *
+                                                                          arg.conv_dilations_[1]);
                                    if(h_tmp % arg.conv_strides_[1] == 0)
                                    {
-                                       int ho = h_tmp / arg.conv_strides_[1];
-                                       if(ho >= 0 && ho < Ho)
+                                       auto ho = ck::type_convert<ck::long_index_t>(h_tmp) /
+                                                 ck::type_convert<ck::long_index_t>(
+                                                     arg.conv_strides_[1]);
+                                       if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
                                        {
-                                           for(int x = 0; x < X; ++x)
+                                           for(std::size_t x = 0; x < X; ++x)
                                            {
-                                               int w_tmp = wi + arg.in_left_pads_[2] -
-                                                           x * arg.conv_dilations_[2];
+                                               auto w_tmp =
+                                                   ck::type_convert<ck::long_index_t>(wi) +
+                                                   ck::type_convert<ck::long_index_t>(
+                                                       arg.in_left_pads_[2]) -
+                                                   ck::type_convert<ck::long_index_t>(
+                                                       x * arg.conv_dilations_[2]);
                                                if(w_tmp % arg.conv_strides_[2] == 0)
                                                {
-                                                   int wo = w_tmp / arg.conv_strides_[2];
-                                                   if(wo >= 0 && wo < Wo)
+                                                   auto wo =
+                                                       ck::type_convert<ck::long_index_t>(w_tmp) /
+                                                       ck::type_convert<ck::long_index_t>(
+                                                           arg.conv_strides_[2]);
+                                                   if(wo >= 0 &&
+                                                      ck::type_convert<std::size_t>(wo) < Wo)
                                                    {
-                                                       for(int k = 0; k < K; ++k)
+                                                       for(std::size_t k = 0; k < K; ++k)
                                                        {
                                                            AccDataType v_out = 0;
                                                            AccDataType v_wei = 0;
@@ -88,13 +88,16 @@ struct ReferenceConvFwd : public device::BaseOperator
            auto f_ncw = [&](auto n, auto k, auto wo) {
                float v_acc = 0;
-               for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
+               for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
                {
-                   for(int x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x)
+                   for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x)
                    {
-                       int wi = wo * arg.conv_strides_[0] + x * arg.conv_dilations_[0] -
-                                arg.in_left_pads_[0];
-                       if(wi >= 0 && wi < arg.input_.mDesc.GetLengths()[2])
+                       auto wi =
+                           ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[0]) +
+                           ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]) -
+                           ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
+                       if(wi >= 0 &&
+                          ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2])
                        {
                            float v_in;
                            float v_wei;
@@ -128,18 +131,26 @@ struct ReferenceConvFwd : public device::BaseOperator
            auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
                float v_acc = 0;
-               for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
+               for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
                {
-                   for(int y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y)
+                   for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y)
                    {
-                       int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
-                                arg.in_left_pads_[0];
-                       for(int x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x)
+                       auto hi =
+                           ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) +
+                           ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) -
+                           ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
+                       for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x)
                        {
-                           int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
-                                    arg.in_left_pads_[1];
-                           if(hi >= 0 && hi < arg.input_.mDesc.GetLengths()[2] && wi >= 0 &&
-                              wi < arg.input_.mDesc.GetLengths()[3])
+                           auto wi =
+                               ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
+                               ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) -
+                               ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
+                           if(hi >= 0 &&
+                              ck::type_convert<std::size_t>(hi) <
+                                  arg.input_.mDesc.GetLengths()[2] &&
+                              wi >= 0 &&
+                              ck::type_convert<std::size_t>(wi) <
+                                  arg.input_.mDesc.GetLengths()[3])
                            {
                                float v_in;
                                float v_wei;
@@ -174,23 +185,37 @@ struct ReferenceConvFwd : public device::BaseOperator
            auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) {
                float v_acc = 0;
-               for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
+               for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
                {
-                   for(int z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z)
+                   for(std::size_t z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z)
                    {
-                       int di = d_o * arg.conv_strides_[0] + z * arg.conv_dilations_[0] -
-                                arg.in_left_pads_[0];
-                       for(int y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y)
+                       auto di =
+                           ck::type_convert<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
+                           ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]) -
+                           ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
+                       for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y)
                        {
-                           int hi = ho * arg.conv_strides_[1] + y * arg.conv_dilations_[1] -
-                                    arg.in_left_pads_[1];
-                           for(int x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x)
+                           auto hi =
+                               ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[1]) +
+                               ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[1]) -
+                               ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
+                           for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x)
                            {
-                               int wi = wo * arg.conv_strides_[2] +
-                                        x * arg.conv_dilations_[2] - arg.in_left_pads_[2];
-                               if(di >= 0 && di < arg.input_.mDesc.GetLengths()[2] &&
-                                  hi >= 0 && hi < arg.input_.mDesc.GetLengths()[3] &&
-                                  wi >= 0 && wi < arg.input_.mDesc.GetLengths()[4])
+                               auto wi =
+                                   ck::type_convert<ck::long_index_t>(wo *
+                                                                      arg.conv_strides_[2]) +
+                                   ck::type_convert<ck::long_index_t>(x *
+                                                                      arg.conv_dilations_[2]) -
+                                   ck::type_convert<ck::long_index_t>(arg.in_left_pads_[2]);
+                               if(di >= 0 &&
+                                  ck::type_convert<std::size_t>(di) <
+                                      arg.input_.mDesc.GetLengths()[2] &&
+                                  hi >= 0 &&
+                                  ck::type_convert<std::size_t>(hi) <
+                                      arg.input_.mDesc.GetLengths()[3] &&
+                                  wi >= 0 &&
+                                  ck::type_convert<std::size_t>(wi) <
+                                      arg.input_.mDesc.GetLengths()[4])
                                {
                                    float v_in;
                                    float v_wei;
@@ -73,18 +73,25 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
            auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
                float v_acc = 0;
-               for(int c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c)
+               for(std::size_t c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c)
                {
-                   for(int y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y)
+                   for(std::size_t y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y)
                    {
-                       int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
-                                arg.in_left_pads_[0];
-                       for(int x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x)
+                       auto hi = ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) +
+                                 ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) -
+                                 ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
+                       for(std::size_t x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x)
                        {
-                           int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
-                                    arg.in_left_pads_[1];
-                           if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 &&
-                              wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
+                           auto wi =
+                               ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
+                               ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) -
+                               ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
+                           if(hi >= 0 &&
+                              ck::type_convert<std::size_t>(hi) <
+                                  arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] &&
+                              wi >= 0 &&
+                              ck::type_convert<std::size_t>(wi) <
+                                  arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
                            {
                                float v_in;
                                float v_wei;
@@ -76,18 +76,25 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator
            auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
                float v_acc = 0;
-               for(int c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c)
+               for(std::size_t c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c)
                {
-                   for(int y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y)
+                   for(std::size_t y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y)
                    {
-                       int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
-                                arg.in_left_pads_[0];
-                       for(int x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x)
+                       auto hi = ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) +
+                                 ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) -
+                                 ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
+                       for(std::size_t x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x)
                        {
-                           int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
-                                    arg.in_left_pads_[1];
-                           if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 &&
-                              wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
+                           auto wi =
+                               ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
+                               ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) -
+                               ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
+                           if(hi >= 0 &&
+                              ck::type_convert<std::size_t>(hi) <
+                                  arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] &&
+                              wi >= 0 &&
+                              ck::type_convert<std::size_t>(wi) <
+                                  arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
                            {
                                float v_in;
                                float v_wei;