Commit addf20e4 authored by Chao Liu

revert change to tensor descriptor; promote ElementSpaceSize to 64-bit

parent 5076982b
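Context for the change (an illustration, not part of the commit): a descriptor's element space size grows with the product of its lengths and strides, so a 32-bit `index_t` can overflow once a tensor approaches 2^31 elements. A minimal sketch of that failure mode, using `int32_t`/`int64_t` as stand-ins for `index_t`/`long_index_t` and purely hypothetical lengths:

```cpp
#include <cstdint>
#include <cstdio>

int main()
{
    // Hypothetical lengths of a large 4-D tensor (not taken from the commit).
    const std::int64_t lengths[4] = {128, 256, 512, 512};

    // Accumulate the element count in 64 bits, as the promoted element_space_size does.
    std::int64_t size64 = 1;
    for(const auto l : lengths)
        size64 *= l; // 8'589'934'592 elements = 2^33

    // Storing the same count in a 32-bit integer silently loses the value
    // (modular wrap-around on the usual two's-complement targets: 2^33 mod 2^32 = 0).
    const auto size32 = static_cast<std::int32_t>(size64);

    std::printf("64-bit element count: %lld\n", static_cast<long long>(size64));
    std::printf("32-bit element count: %d\n", size32);
    return 0;
}
```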
@@ -114,12 +114,11 @@ struct TensorDescriptor
     __host__ __device__ constexpr TensorDescriptor() = default;

     __host__ __device__ constexpr TensorDescriptor(const Transforms& transforms,
-                                                   ElementSpaceSize element_space_size,
-                                                   std::size_t real_size)
+                                                   ElementSpaceSize element_space_size)
         : transforms_{transforms},
           element_size_{InitializeElementSize(transforms)},
-          element_space_size_{element_space_size},
-          real_size_{real_size}
+          element_space_size_{element_space_size}
     {
         static_assert(Transforms::Size() == ntransform_ &&
                       LowerDimensionIdss::Size() == ntransform_ &&

@@ -155,8 +154,6 @@ struct TensorDescriptor
     __host__ __device__ constexpr auto GetElementSpaceSize() const { return element_space_size_; }

-    __host__ __device__ constexpr auto GetRealSize() const { return real_size_; }
-
     template <typename Idx>
     __host__ __device__ constexpr index_t CalculateOffset(const Idx& idx) const
     {

@@ -216,9 +213,6 @@ struct TensorDescriptor
     Transforms transforms_;
     ElementSize element_size_;
     ElementSpaceSize element_space_size_;
-
-    private:
-    std::size_t real_size_;
 };

 template <index_t NDimHidden, typename VisibleDimensionIds>

@@ -385,14 +379,12 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
     const auto element_space_size = old_tensor_desc.GetElementSpaceSize();

-    const auto real_size = old_tensor_desc.GetRealSize();
-
     return TensorDescriptor<remove_cv_t<decltype(all_transforms)>,
                             remove_cv_t<decltype(all_low_dim_hidden_idss)>,
                             remove_cv_t<decltype(all_up_dim_hidden_idss)>,
                             remove_cv_t<decltype(new_visible_dim_hidden_ids)>,
-                            remove_cv_t<decltype(element_space_size)>>{
-        all_transforms, element_space_size, real_size};
+                            remove_cv_t<decltype(element_space_size)>>{all_transforms,
+                                                                       element_space_size};
 }

 template <typename TensorDesc, typename VisibleIndex>
......
-#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP
-#define CK_TENSOR_DESCRIPTOR_HELPER_HPP
+#pragma once

 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
 #include "multi_index_transform_helper.hpp"

@@ -35,6 +33,12 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt
 }
 #endif

+// Lengths..., Strides... could be:
+// 1) index_t, which is known at run-time, or
+// 2) Number<>, which is known at compile-time
+// element_space_size could be:
+// 1) long_index_t, or
+// 2) LongNumber<>
 template <typename... Lengths,
           typename... Strides,
           typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>

@@ -68,29 +72,26 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Leng
         }
     };

-    const auto real_size = f(f, Number<0>{}, integral_constant<std::size_t, 1ul>{});
-    const auto element_space_size = f(f, Number<0>{}, Number<1>{});
+    const auto element_space_size = f(f, Number<0>{}, LongNumber<1>{});
 #else
-    const auto real_size = calculate_element_space_size_impl(
-        lengths, strides, Number<0>{}, integral_constant<std::size_t, 1ul>{});
     const auto element_space_size =
-        calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{});
+        calculate_element_space_size_impl(lengths, strides, Number<0>{}, LongNumber<1>{});
 #endif

     return TensorDescriptor<remove_cv_t<decltype(transforms)>,
                             remove_cv_t<decltype(low_dim_hidden_idss)>,
                             remove_cv_t<decltype(up_dim_hidden_idss)>,
                             remove_cv_t<decltype(visible_dim_hidden_ids)>,
-                            remove_cv_t<decltype(element_space_size)>>{
-        transforms, element_space_size, real_size};
+                            remove_cv_t<decltype(element_space_size)>>{transforms,
+                                                                       element_space_size};
 }

-// Lengths... can be:
-// 1) index_t, which is known at run-time
+// Lengths... could be:
+// 1) index_t, which is known at run-time, or
 // 2) Number<>, which is known at compile-time
+// element_space_size could be:
+// 1) long_index_t, or
+// 2) LongNumber<>
 template <typename... Lengths>
 __host__ __device__ constexpr auto
 make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)

@@ -106,19 +107,22 @@ make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
     constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};

-    const auto real_size =
-        container_reduce(lengths, math::multiplies{}, integral_constant<std::size_t, 1ul>{});
-    const auto element_space_size = container_reduce(lengths, math::multiplies{}, Number<1>{});
+    const auto element_space_size = container_reduce(lengths, math::multiplies{}, LongNumber<1>{});

     return TensorDescriptor<remove_cv_t<decltype(transforms)>,
                             remove_cv_t<decltype(low_dim_hidden_idss)>,
                             remove_cv_t<decltype(up_dim_hidden_idss)>,
                             remove_cv_t<decltype(visible_dim_hidden_ids)>,
-                            remove_cv_t<decltype(element_space_size)>>{
-        transforms, element_space_size, real_size};
+                            remove_cv_t<decltype(element_space_size)>>{transforms,
+                                                                       element_space_size};
 }

+// Lengths... could be:
+// 1) index_t, which is known at run-time, or
+// 2) Number<>, which is known at compile-time
+// align could be:
+// 1) index_t, or
+// 2) Number<>
 template <typename... Lengths, typename Align>
 __host__ __device__ constexpr auto
 make_naive_tensor_descriptor_aligned(const Tuple<Lengths...>& lengths, Align align)

@@ -155,4 +159,3 @@ make_naive_tensor_descriptor_aligned(const Tuple<Lengths...>& lengths, Align ali
 }

 } // namespace ck
-#endif
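A rough sketch (assumptions: `Number<>`/`LongNumber<>` behave like `std::integral_constant` aliases and `container_reduce` folds a product, as the hunks above suggest) of why the reductions are now seeded with `LongNumber<1>{}`: a 64-bit seed keeps the whole product 64-bit, whether the lengths are compile-time constants or run-time `index_t` values. The real code folds the constants themselves; this toy version folds their values:

```cpp
#include <cstdint>
#include <type_traits>

// Local stand-ins for ck::Number<> and ck::LongNumber<>.
template <std::int32_t N> using Num  = std::integral_constant<std::int32_t, N>;
template <std::int64_t N> using LNum = std::integral_constant<std::int64_t, N>;

// Tiny product fold standing in for container_reduce(lengths, math::multiplies{}, init).
template <typename Init>
constexpr auto product(Init init) { return init; }

template <typename Init, typename L0, typename... Ls>
constexpr auto product(Init init, L0 l0, Ls... ls)
{
    return product(init * l0, ls...);
}

int main()
{
    // Compile-time lengths: the 64-bit seed makes the result a 64-bit constant.
    constexpr auto ct = product(LNum<1>{}(), Num<65536>{}(), Num<65536>{}(), Num<4>{}());
    static_assert(std::is_same_v<decltype(ct), const std::int64_t>);
    static_assert(ct == (std::int64_t{1} << 34)); // would wrap in 32-bit arithmetic

    // Run-time lengths: the same 64-bit seed still promotes the whole product.
    std::int32_t h = 65536, w = 65536, c = 4;
    const auto rt = product(std::int64_t{1}, h, w, c);
    static_assert(std::is_same_v<decltype(rt), const std::int64_t>);
    return rt == ct ? 0 : 1;
}
```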
@@ -639,11 +639,12 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
               d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
               c_grid_desc_mblock_mperblock_nblock_nperblock_{},
               d_grid_desc_mblock_mperblock_{},
-              compute_base_ptr_of_batch_{a_grid_desc_ak0_m_ak1_.GetElementSpaceSize(),
-                                         b_grid_desc_bk0_n_bk1_.GetElementSpaceSize(),
-                                         c_grid_desc_m_n_.GetElementSpaceSize(),
-                                         d_grid_desc_m_.GetElementSpaceSize(),
-                                         d_grid_desc_m_.GetElementSpaceSize()},
+              compute_base_ptr_of_batch_{
+                  type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()),
+                  type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
+                  type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize()),
+                  type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize()),
+                  type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize())},
               block_2_ctile_map_{},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
......
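In the batched-GEMM hunk above and the one that follows, the descriptors now report a 64-bit element space size, but the batch-offset helpers keep per-batch strides as `index_t`, hence the `type_convert<index_t>(...)` wrappers. A hedged sketch of that narrowing step, with a hand-rolled checked cast standing in for `type_convert` (whose exact overflow behavior is not asserted here); the struct below is a simplified stand-in for the `compute_base_ptr_of_batch_` member, not the real helper:

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

using index_t      = std::int32_t; // assumption: CK's 32-bit index type
using long_index_t = std::int64_t; // assumption: CK's 64-bit index type

// Stand-in for type_convert<index_t>(...): narrow only after checking the value fits.
inline index_t narrow_to_index(long_index_t v)
{
    assert(v >= 0 && v <= std::numeric_limits<index_t>::max());
    return static_cast<index_t>(v);
}

// Simplified analogue of a batch-offset helper: strides stored as index_t,
// base pointers computed back in 64 bits so large batch counts do not wrap.
struct BatchOffsets
{
    index_t a_stride_;
    index_t b_stride_;
    index_t c_stride_;

    long_index_t GetABasePtr(index_t g) const { return long_index_t{g} * a_stride_; }
};

int main()
{
    // Hypothetical per-matrix element space sizes (64-bit after this commit).
    const long_index_t a_size = 256ll * 1024 * 1024;
    const long_index_t b_size = 128ll * 1024 * 1024;
    const long_index_t c_size = 64ll * 1024 * 1024;

    const BatchOffsets offsets{
        narrow_to_index(a_size), narrow_to_index(b_size), narrow_to_index(c_size)};

    return offsets.GetABasePtr(3) == 3 * a_size ? 0 : 1;
}
```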
@@ -384,9 +384,10 @@ struct DeviceBatchedGemmXdl
                   DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
               c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
               c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
-              compute_ptr_offset_of_batch_{a_grid_desc_k0_m_k1_.GetElementSpaceSize(),
-                                           b_grid_desc_k0_n_k1_.GetElementSpaceSize(),
-                                           c_grid_desc_m_n_.GetElementSpaceSize()},
+              compute_ptr_offset_of_batch_{
+                  type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
+                  type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
+                  type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())},
               block_2_ctile_map_{},
               M01_{M01},
               N01_{N01},
......
@@ -862,17 +862,11 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
     static bool IsSupportedArgument(const Argument& arg)
     {
         // Input tensors can't be bigger than 2GB each.
-        constexpr std::size_t GB2 = 1e9;
+        constexpr ck::long_index_t GB2 = (ck::long_index_t{1} << 31);

-        if(arg.a_grid_desc_k0_m_k1_.GetRealSize() > GB2)
-        {
-            return false;
-        }
-
-        if(arg.b_grid_desc_k0_n_k1_.GetRealSize() > GB2)
-        {
-            return false;
-        }
-
-        if(arg.c_grid_desc_m_n_.GetRealSize() > GB2)
+        if(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) > GB2 ||
+           arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) > GB2 ||
+           arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) > GB2)
         {
             return false;
         }
......
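The rewritten guard above states the per-tensor limit in bytes: the 64-bit element space size times `sizeof(DataType)` must not exceed 2^31. A small worked check under assumed element counts and data types (fp16 A, fp32 C), not taken from the commit; doing the multiplication in 64 bits is what keeps the comparison itself from overflowing:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

using long_index_t = std::int64_t; // assumption: CK's 64-bit index type

// Mirrors the shape of the check in IsSupportedArgument: each tensor must stay within 2 GB.
bool fits_in_2gb(long_index_t element_space_size, std::size_t bytes_per_element)
{
    constexpr long_index_t GB2 = long_index_t{1} << 31; // 2'147'483'648 bytes
    return element_space_size * static_cast<long_index_t>(bytes_per_element) <= GB2;
}

int main()
{
    const long_index_t a_elements = 900'000'000; // fp16 -> 1.8e9 bytes, allowed
    const long_index_t c_elements = 700'000'000; // fp32 -> 2.8e9 bytes, rejected

    std::printf("A fits: %d\n", fits_in_2gb(a_elements, sizeof(std::uint16_t))); // 1
    std::printf("C fits: %d\n", fits_in_2gb(c_elements, sizeof(float)));         // 0
    return 0;
}
```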
@@ -8,5 +8,8 @@ namespace ck {
 template <index_t N>
 using Number = integral_constant<index_t, N>;

+template <index_t N>
+using LongNumber = integral_constant<long_index_t, N>;
+
 } // namespace ck
 #endif
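The new `LongNumber<>` alias mirrors `Number<>` but carries `long_index_t` values, which is what lets compile-time element space sizes (and the `make_static_buffer` overload added in the next hunk) work in 64 bits. A toy check, assuming `integral_constant` matches `std::integral_constant` semantics; the aliases here are local stand-ins, not the CK definitions:

```cpp
#include <cstdint>
#include <type_traits>

using index_t      = std::int32_t; // assumption: CK's 32-bit index type
using long_index_t = std::int64_t; // assumption: CK's 64-bit index type

// Local stand-ins mirroring the shape of the number.hpp aliases.
template <index_t N> using Number     = std::integral_constant<index_t, N>;
template <index_t N> using LongNumber = std::integral_constant<long_index_t, N>;

// The carried value type is what changes: arithmetic seeded from a LongNumber
// value happens in 64 bits, so large sizes do not wrap.
static_assert(std::is_same_v<Number<1>::value_type, index_t>);
static_assert(std::is_same_v<LongNumber<1>::value_type, long_index_t>);

constexpr auto big = LongNumber<1>::value * 65536 * 65536; // 2^32, fine in 64-bit
static_assert(big == (long_index_t{1} << 32));

int main() { return 0; }
```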
@@ -158,5 +158,11 @@ __host__ __device__ constexpr auto make_static_buffer(Number<N>)
     return StaticBuffer<AddressSpace, T, N, true>{};
 }

+template <AddressSpaceEnum AddressSpace, typename T, long_index_t N>
+__host__ __device__ constexpr auto make_static_buffer(LongNumber<N>)
+{
+    return StaticBuffer<AddressSpace, T, N, true>{};
+}
+
 } // namespace ck
 #endif