Commit 1eb7d83b authored by myamlak's avatar myamlak
Browse files

Resolving large tensor size issue.

parent 97f133d2
...@@ -114,11 +114,12 @@ struct TensorDescriptor ...@@ -114,11 +114,12 @@ struct TensorDescriptor
__host__ __device__ constexpr TensorDescriptor() = default; __host__ __device__ constexpr TensorDescriptor() = default;
__host__ __device__ constexpr TensorDescriptor(const Transforms& transforms, __host__ __device__ constexpr TensorDescriptor(const Transforms& transforms,
ElementSpaceSize element_space_size) ElementSpaceSize element_space_size,
std::size_t real_size)
: transforms_{transforms}, : transforms_{transforms},
element_size_{InitializeElementSize(transforms)}, element_size_{InitializeElementSize(transforms)},
element_space_size_{element_space_size} element_space_size_{element_space_size},
real_size_{real_size}
{ {
static_assert(Transforms::Size() == ntransform_ && static_assert(Transforms::Size() == ntransform_ &&
LowerDimensionIdss::Size() == ntransform_ && LowerDimensionIdss::Size() == ntransform_ &&
...@@ -154,6 +155,8 @@ struct TensorDescriptor ...@@ -154,6 +155,8 @@ struct TensorDescriptor
__host__ __device__ constexpr auto GetElementSpaceSize() const { return element_space_size_; } __host__ __device__ constexpr auto GetElementSpaceSize() const { return element_space_size_; }
__host__ __device__ constexpr auto GetRealSize() const { return real_size_; }
template <typename Idx> template <typename Idx>
__host__ __device__ constexpr index_t CalculateOffset(const Idx& idx) const __host__ __device__ constexpr index_t CalculateOffset(const Idx& idx) const
{ {
...@@ -213,6 +216,9 @@ struct TensorDescriptor ...@@ -213,6 +216,9 @@ struct TensorDescriptor
Transforms transforms_; Transforms transforms_;
ElementSize element_size_; ElementSize element_size_;
ElementSpaceSize element_space_size_; ElementSpaceSize element_space_size_;
private:
std::size_t real_size_;
}; };
template <index_t NDimHidden, typename VisibleDimensionIds> template <index_t NDimHidden, typename VisibleDimensionIds>
...@@ -379,12 +385,15 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, ...@@ -379,12 +385,15 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
const auto element_space_size = old_tensor_desc.GetElementSpaceSize(); const auto element_space_size = old_tensor_desc.GetElementSpaceSize();
const auto real_size = old_tensor_desc.GetRealSize();
return TensorDescriptor<remove_cv_t<decltype(all_transforms)>, return TensorDescriptor<remove_cv_t<decltype(all_transforms)>,
remove_cv_t<decltype(all_low_dim_hidden_idss)>, remove_cv_t<decltype(all_low_dim_hidden_idss)>,
remove_cv_t<decltype(all_up_dim_hidden_idss)>, remove_cv_t<decltype(all_up_dim_hidden_idss)>,
remove_cv_t<decltype(new_visible_dim_hidden_ids)>, remove_cv_t<decltype(new_visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{all_transforms, remove_cv_t<decltype(element_space_size)>>{all_transforms,
element_space_size}; element_space_size,
real_size};
} }
template <typename TensorDesc, typename VisibleIndex> template <typename TensorDesc, typename VisibleIndex>
......
...@@ -68,10 +68,16 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Leng ...@@ -68,10 +68,16 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Leng
} }
}; };
const auto real_size = f(f, Number<0>{}, integral_constant<std::size_t, 1ul>{});
const auto element_space_size = f(f, Number<0>{}, Number<1>{}); const auto element_space_size = f(f, Number<0>{}, Number<1>{});
#else #else
const auto real_size =
calculate_element_space_size_impl(lengths, strides, Number<0>{}, integral_constant<std::size_t, 1ul>{});
const auto element_space_size = const auto element_space_size =
calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{}); calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{});
calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{});
#endif #endif
return TensorDescriptor<remove_cv_t<decltype(transforms)>, return TensorDescriptor<remove_cv_t<decltype(transforms)>,
...@@ -79,7 +85,8 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Leng ...@@ -79,7 +85,8 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Leng
remove_cv_t<decltype(up_dim_hidden_idss)>, remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>, remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{transforms, remove_cv_t<decltype(element_space_size)>>{transforms,
element_space_size}; element_space_size,
real_size};
} }
// Lengths... can be: // Lengths... can be:
...@@ -100,6 +107,9 @@ make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths) ...@@ -100,6 +107,9 @@ make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{}; constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
const auto real_size =
container_reduce(lengths, math::multiplies{}, integral_constant<std::size_t, 1ul>{});
const auto element_space_size = container_reduce(lengths, math::multiplies{}, Number<1>{}); const auto element_space_size = container_reduce(lengths, math::multiplies{}, Number<1>{});
return TensorDescriptor<remove_cv_t<decltype(transforms)>, return TensorDescriptor<remove_cv_t<decltype(transforms)>,
...@@ -107,7 +117,8 @@ make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths) ...@@ -107,7 +117,8 @@ make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
remove_cv_t<decltype(up_dim_hidden_idss)>, remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>, remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{transforms, remove_cv_t<decltype(element_space_size)>>{transforms,
element_space_size}; element_space_size,
real_size};
} }
template <typename... Lengths, typename Align> template <typename... Lengths, typename Align>
......
...@@ -864,15 +864,15 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -864,15 +864,15 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
// Input tensors can't be bigger than 2GB each. // Input tensors can't be bigger than 2GB each.
constexpr std::size_t GB2 = 1e9; constexpr std::size_t GB2 = 1e9;
if(ck::type_convert<std::size_t>(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize()) > GB2) if(arg.a_grid_desc_k0_m_k1_.GetRealSize() > GB2)
{ {
return false; return false;
} }
if(ck::type_convert<std::size_t>(arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize()) > GB2) if(arg.b_grid_desc_k0_n_k1_.GetRealSize() > GB2)
{ {
return false; return false;
} }
if(ck::type_convert<std::size_t>(arg.c_grid_desc_m_n_.GetElementSpaceSize()) > GB2) if(arg.c_grid_desc_m_n_.GetRealSize() > GB2)
{ {
return false; return false;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment