"include/vscode:/vscode.git/clone" did not exist on "b4e77c8cc53eae7dc42b60aede0a01a5991a7f72"
Commit 5076982b authored by Chao Liu's avatar Chao Liu
Browse files

format

parent 8c03672b
...@@ -391,9 +391,8 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, ...@@ -391,9 +391,8 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
remove_cv_t<decltype(all_low_dim_hidden_idss)>, remove_cv_t<decltype(all_low_dim_hidden_idss)>,
remove_cv_t<decltype(all_up_dim_hidden_idss)>, remove_cv_t<decltype(all_up_dim_hidden_idss)>,
remove_cv_t<decltype(new_visible_dim_hidden_ids)>, remove_cv_t<decltype(new_visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{all_transforms, remove_cv_t<decltype(element_space_size)>>{
element_space_size, all_transforms, element_space_size, real_size};
real_size};
} }
template <typename TensorDesc, typename VisibleIndex> template <typename TensorDesc, typename VisibleIndex>
......
...@@ -72,8 +72,8 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Leng ...@@ -72,8 +72,8 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Leng
const auto element_space_size = f(f, Number<0>{}, Number<1>{}); const auto element_space_size = f(f, Number<0>{}, Number<1>{});
#else #else
const auto real_size = const auto real_size = calculate_element_space_size_impl(
calculate_element_space_size_impl(lengths, strides, Number<0>{}, integral_constant<std::size_t, 1ul>{}); lengths, strides, Number<0>{}, integral_constant<std::size_t, 1ul>{});
const auto element_space_size = const auto element_space_size =
calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{}); calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{});
...@@ -84,9 +84,8 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Leng ...@@ -84,9 +84,8 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Leng
remove_cv_t<decltype(low_dim_hidden_idss)>, remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>, remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>, remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{transforms, remove_cv_t<decltype(element_space_size)>>{
element_space_size, transforms, element_space_size, real_size};
real_size};
} }
// Lengths... can be: // Lengths... can be:
...@@ -116,9 +115,8 @@ make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths) ...@@ -116,9 +115,8 @@ make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
remove_cv_t<decltype(low_dim_hidden_idss)>, remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>, remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>, remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{transforms, remove_cv_t<decltype(element_space_size)>>{
element_space_size, transforms, element_space_size, real_size};
real_size};
} }
template <typename... Lengths, typename Align> template <typename... Lengths, typename Align>
......
...@@ -211,7 +211,8 @@ struct ReductionHost ...@@ -211,7 +211,8 @@ struct ReductionHost
AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>(); AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
IndexDataType accuIndex = 0; IndexDataType accuIndex = 0;
for(IndexDataType i = 0; i < ck::type_convert<IndexDataType>(reduce_dim_indexes.size()); i++) for(IndexDataType i = 0; i < ck::type_convert<IndexDataType>(reduce_dim_indexes.size());
i++)
{ {
auto offset_reduce = auto offset_reduce =
get_offset_from_index<NumReduceDim>(reduceStrides, reduce_dim_indexes[i]); get_offset_from_index<NumReduceDim>(reduceStrides, reduce_dim_indexes[i]);
...@@ -246,7 +247,9 @@ struct ReductionHost ...@@ -246,7 +247,9 @@ struct ReductionHost
auto offset_invariant = auto offset_invariant =
get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index); get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index);
for(IndexDataType i = 0; i < ck::type_convert<IndexDataType>(reduce_dim_indexes.size()); i++) for(IndexDataType i = 0;
i < ck::type_convert<IndexDataType>(reduce_dim_indexes.size());
i++)
{ {
auto offset_reduce = auto offset_reduce =
get_offset_from_index<NumReduceDim>(reduceStrides, reduce_dim_indexes[i]); get_offset_from_index<NumReduceDim>(reduceStrides, reduce_dim_indexes[i]);
......
...@@ -70,18 +70,26 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -70,18 +70,26 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
auto f_kcyx = [&](auto k, auto c, auto y, auto x) { auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
float v_acc = 0; float v_acc = 0;
for(int n = 0; n < ck::type_convert<int>(arg.out_n_k_ho_wo_.mDesc.GetLengths()[0]); ++n) for(int n = 0; n < ck::type_convert<int>(arg.out_n_k_ho_wo_.mDesc.GetLengths()[0]);
++n)
{ {
for(int ho = 0; ho < ck::type_convert<int>(arg.out_n_k_ho_wo_.mDesc.GetLengths()[2]); ++ho) for(int ho = 0;
ho < ck::type_convert<int>(arg.out_n_k_ho_wo_.mDesc.GetLengths()[2]);
++ho)
{ {
int hi = ho * arg.conv_strides_[I0] + y * arg.conv_dilations_[I0] - int hi = ho * arg.conv_strides_[I0] + y * arg.conv_dilations_[I0] -
arg.in_left_pads_[I0]; arg.in_left_pads_[I0];
for(int wo = 0; wo < ck::type_convert<int>(arg.out_n_k_ho_wo_.mDesc.GetLengths()[3]); ++wo) for(int wo = 0;
wo < ck::type_convert<int>(arg.out_n_k_ho_wo_.mDesc.GetLengths()[3]);
++wo)
{ {
int wi = wo * arg.conv_strides_[I1] + x * arg.conv_dilations_[I1] - int wi = wo * arg.conv_strides_[I1] + x * arg.conv_dilations_[I1] -
arg.in_left_pads_[I1]; arg.in_left_pads_[I1];
if(hi >= 0 && hi < ck::type_convert<int>(arg.in_n_c_hi_wi_.mDesc.GetLengths()[2]) && if(hi >= 0 &&
wi >= 0 && wi < ck::type_convert<int>(arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])) hi <
ck::type_convert<int>(arg.in_n_c_hi_wi_.mDesc.GetLengths()[2]) &&
wi >= 0 &&
wi < ck::type_convert<int>(arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]))
{ {
float v_out; float v_out;
float v_in; float v_in;
......
...@@ -88,13 +88,16 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -88,13 +88,16 @@ struct ReferenceConvFwd : public device::BaseOperator
auto f_ncw = [&](auto n, auto k, auto wo) { auto f_ncw = [&](auto n, auto k, auto wo) {
float v_acc = 0; float v_acc = 0;
for(int c = 0; c < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[1]); ++c) for(int c = 0; c < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[1]);
++c)
{ {
for(int x = 0; x < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[2]); ++x) for(int x = 0; x < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[2]);
++x)
{ {
int wi = wo * arg.conv_strides_[0] + x * arg.conv_dilations_[0] - int wi = wo * arg.conv_strides_[0] + x * arg.conv_dilations_[0] -
arg.in_left_pads_[0]; arg.in_left_pads_[0];
if(wi >= 0 && wi < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[2])) if(wi >= 0 &&
wi < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[2]))
{ {
float v_in; float v_in;
float v_wei; float v_wei;
...@@ -128,17 +131,23 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -128,17 +131,23 @@ struct ReferenceConvFwd : public device::BaseOperator
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
float v_acc = 0; float v_acc = 0;
for(int c = 0; c < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[1]); ++c) for(int c = 0; c < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[1]);
++c)
{ {
for(int y = 0; y < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[2]); ++y) for(int y = 0; y < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[2]);
++y)
{ {
int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] - int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
arg.in_left_pads_[0]; arg.in_left_pads_[0];
for(int x = 0; x < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[3]); ++x) for(int x = 0;
x < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[3]);
++x)
{ {
int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] - int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
arg.in_left_pads_[1]; arg.in_left_pads_[1];
if(hi >= 0 && hi < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[2]) && wi >= 0 && if(hi >= 0 &&
hi < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[2]) &&
wi >= 0 &&
wi < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[3])) wi < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[3]))
{ {
float v_in; float v_in;
...@@ -174,23 +183,34 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -174,23 +183,34 @@ struct ReferenceConvFwd : public device::BaseOperator
auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) { auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) {
float v_acc = 0; float v_acc = 0;
for(int c = 0; c < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[1]); ++c) for(int c = 0; c < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[1]);
++c)
{ {
for(int z = 0; z < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[2]); ++z) for(int z = 0; z < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[2]);
++z)
{ {
int di = d_o * arg.conv_strides_[0] + z * arg.conv_dilations_[0] - int di = d_o * arg.conv_strides_[0] + z * arg.conv_dilations_[0] -
arg.in_left_pads_[0]; arg.in_left_pads_[0];
for(int y = 0; y < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[3]); ++y) for(int y = 0;
y < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[3]);
++y)
{ {
int hi = ho * arg.conv_strides_[1] + y * arg.conv_dilations_[1] - int hi = ho * arg.conv_strides_[1] + y * arg.conv_dilations_[1] -
arg.in_left_pads_[1]; arg.in_left_pads_[1];
for(int x = 0; x < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[4]); ++x) for(int x = 0;
x < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[4]);
++x)
{ {
int wi = wo * arg.conv_strides_[2] + int wi = wo * arg.conv_strides_[2] +
x * arg.conv_dilations_[2] - arg.in_left_pads_[2]; x * arg.conv_dilations_[2] - arg.in_left_pads_[2];
if(di >= 0 && di < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[2]) && if(di >= 0 &&
hi >= 0 && hi < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[3]) && di < ck::type_convert<int>(
wi >= 0 && wi < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[4])) arg.input_.mDesc.GetLengths()[2]) &&
hi >= 0 &&
hi < ck::type_convert<int>(
arg.input_.mDesc.GetLengths()[3]) &&
wi >= 0 &&
wi < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[4]))
{ {
float v_in; float v_in;
float v_wei; float v_wei;
......
...@@ -73,17 +73,25 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator ...@@ -73,17 +73,25 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
float v_acc = 0; float v_acc = 0;
for(int c = 0; c < ck::type_convert<int>(arg.wei_k_c_y_x_.mDesc.GetLengths()[1]); ++c) for(int c = 0; c < ck::type_convert<int>(arg.wei_k_c_y_x_.mDesc.GetLengths()[1]);
++c)
{ {
for(int y = 0; y < ck::type_convert<int>(arg.wei_k_c_y_x_.mDesc.GetLengths()[2]); ++y) for(int y = 0;
y < ck::type_convert<int>(arg.wei_k_c_y_x_.mDesc.GetLengths()[2]);
++y)
{ {
int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] - int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
arg.in_left_pads_[0]; arg.in_left_pads_[0];
for(int x = 0; x < ck::type_convert<int>(arg.wei_k_c_y_x_.mDesc.GetLengths()[3]); ++x) for(int x = 0;
x < ck::type_convert<int>(arg.wei_k_c_y_x_.mDesc.GetLengths()[3]);
++x)
{ {
int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] - int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
arg.in_left_pads_[1]; arg.in_left_pads_[1];
if(hi >= 0 && hi < ck::type_convert<int>(arg.in_n_c_hi_wi_.mDesc.GetLengths()[2]) && wi >= 0 && if(hi >= 0 &&
hi <
ck::type_convert<int>(arg.in_n_c_hi_wi_.mDesc.GetLengths()[2]) &&
wi >= 0 &&
wi < ck::type_convert<int>(arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])) wi < ck::type_convert<int>(arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]))
{ {
float v_in; float v_in;
......
...@@ -76,17 +76,25 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator ...@@ -76,17 +76,25 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
float v_acc = 0; float v_acc = 0;
for(int c = 0; c < ck::type_convert<int>(arg.wei_k_c_y_x_.mDesc.GetLengths()[1]); ++c) for(int c = 0; c < ck::type_convert<int>(arg.wei_k_c_y_x_.mDesc.GetLengths()[1]);
++c)
{ {
for(int y = 0; y < ck::type_convert<int>(arg.wei_k_c_y_x_.mDesc.GetLengths()[2]); ++y) for(int y = 0;
y < ck::type_convert<int>(arg.wei_k_c_y_x_.mDesc.GetLengths()[2]);
++y)
{ {
int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] - int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
arg.in_left_pads_[0]; arg.in_left_pads_[0];
for(int x = 0; x < ck::type_convert<int>(arg.wei_k_c_y_x_.mDesc.GetLengths()[3]); ++x) for(int x = 0;
x < ck::type_convert<int>(arg.wei_k_c_y_x_.mDesc.GetLengths()[3]);
++x)
{ {
int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] - int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
arg.in_left_pads_[1]; arg.in_left_pads_[1];
if(hi >= 0 && hi < ck::type_convert<int>(arg.in_n_c_hi_wi_.mDesc.GetLengths()[2]) && wi >= 0 && if(hi >= 0 &&
hi <
ck::type_convert<int>(arg.in_n_c_hi_wi_.mDesc.GetLengths()[2]) &&
wi >= 0 &&
wi < ck::type_convert<int>(arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])) wi < ck::type_convert<int>(arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]))
{ {
float v_in; float v_in;
......
...@@ -78,8 +78,8 @@ ConvParams::ConvParams(ck::index_t n_dim, ...@@ -78,8 +78,8 @@ ConvParams::ConvParams(ck::index_t n_dim,
ck::type_convert<ck::index_t>(input_left_pads.size()) != num_dim_spatial || ck::type_convert<ck::index_t>(input_left_pads.size()) != num_dim_spatial ||
ck::type_convert<ck::index_t>(input_right_pads.size()) != num_dim_spatial) ck::type_convert<ck::index_t>(input_right_pads.size()) != num_dim_spatial)
{ {
throw(std::runtime_error( throw(
"ConvParams::GetOutputSpatialLengths: " std::runtime_error("ConvParams::GetOutputSpatialLengths: "
"parameter size is different from number of declared dimensions!")); "parameter size is different from number of declared dimensions!"));
} }
} }
...@@ -93,8 +93,8 @@ std::vector<ck::index_t> ConvParams::GetOutputSpatialLengths() const ...@@ -93,8 +93,8 @@ std::vector<ck::index_t> ConvParams::GetOutputSpatialLengths() const
ck::type_convert<ck::index_t>(input_left_pads.size()) != num_dim_spatial || ck::type_convert<ck::index_t>(input_left_pads.size()) != num_dim_spatial ||
ck::type_convert<ck::index_t>(input_right_pads.size()) != num_dim_spatial) ck::type_convert<ck::index_t>(input_right_pads.size()) != num_dim_spatial)
{ {
throw(std::runtime_error( throw(
"ConvParams::GetOutputSpatialLengths: " std::runtime_error("ConvParams::GetOutputSpatialLengths: "
"parameter size is different from number of declared dimensions!")); "parameter size is different from number of declared dimensions!"));
} }
...@@ -103,8 +103,7 @@ std::vector<ck::index_t> ConvParams::GetOutputSpatialLengths() const ...@@ -103,8 +103,7 @@ std::vector<ck::index_t> ConvParams::GetOutputSpatialLengths() const
{ {
// XEff = (X - 1) * conv_dilation_w + 1; // XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const ck::index_t idx_eff = const ck::index_t idx_eff = (filter_spatial_lengths[i] - 1) * conv_filter_dilations[i] + 1;
(filter_spatial_lengths[i] - 1) * conv_filter_dilations[i] + 1;
out_spatial_len[i] = out_spatial_len[i] =
(input_spatial_lengths[i] + input_left_pads[i] + input_right_pads[i] - idx_eff) / (input_spatial_lengths[i] + input_left_pads[i] + input_right_pads[i] - idx_eff) /
conv_filter_strides[i] + conv_filter_strides[i] +
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment