Commit addf20e4 authored by Chao Liu's avatar Chao Liu
Browse files

revert change to tensor descriptor; promote lementSpaceSize to 64bit

parent 5076982b
......@@ -114,12 +114,11 @@ struct TensorDescriptor
__host__ __device__ constexpr TensorDescriptor() = default;
__host__ __device__ constexpr TensorDescriptor(const Transforms& transforms,
ElementSpaceSize element_space_size,
std::size_t real_size)
ElementSpaceSize element_space_size)
: transforms_{transforms},
element_size_{InitializeElementSize(transforms)},
element_space_size_{element_space_size},
real_size_{real_size}
element_space_size_{element_space_size}
{
static_assert(Transforms::Size() == ntransform_ &&
LowerDimensionIdss::Size() == ntransform_ &&
......@@ -155,8 +154,6 @@ struct TensorDescriptor
__host__ __device__ constexpr auto GetElementSpaceSize() const { return element_space_size_; }
__host__ __device__ constexpr auto GetRealSize() const { return real_size_; }
template <typename Idx>
__host__ __device__ constexpr index_t CalculateOffset(const Idx& idx) const
{
......@@ -216,9 +213,6 @@ struct TensorDescriptor
Transforms transforms_;
ElementSize element_size_;
ElementSpaceSize element_space_size_;
private:
std::size_t real_size_;
};
template <index_t NDimHidden, typename VisibleDimensionIds>
......@@ -385,14 +379,12 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
const auto element_space_size = old_tensor_desc.GetElementSpaceSize();
const auto real_size = old_tensor_desc.GetRealSize();
return TensorDescriptor<remove_cv_t<decltype(all_transforms)>,
remove_cv_t<decltype(all_low_dim_hidden_idss)>,
remove_cv_t<decltype(all_up_dim_hidden_idss)>,
remove_cv_t<decltype(new_visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{
all_transforms, element_space_size, real_size};
remove_cv_t<decltype(element_space_size)>>{all_transforms,
element_space_size};
}
template <typename TensorDesc, typename VisibleIndex>
......
#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP
#define CK_TENSOR_DESCRIPTOR_HELPER_HPP
#pragma once
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "multi_index_transform_helper.hpp"
......@@ -35,6 +33,12 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt
}
#endif
// Lengths..., Strides... could be:
// 1) index_t, which is known at run-time, or
// 2) Number<>, which is known at compile-time
// element_space_size could be:
// 1) long_index_t, or
// 2) LongNumber<>
template <typename... Lengths,
typename... Strides,
typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
......@@ -68,29 +72,26 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Leng
}
};
const auto real_size = f(f, Number<0>{}, integral_constant<std::size_t, 1ul>{});
const auto element_space_size = f(f, Number<0>{}, Number<1>{});
const auto element_space_size = f(f, Number<0>{}, LongNumber<1>{});
#else
const auto real_size = calculate_element_space_size_impl(
lengths, strides, Number<0>{}, integral_constant<std::size_t, 1ul>{});
const auto element_space_size =
calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{});
calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{});
calculate_element_space_size_impl(lengths, strides, Number<0>{}, LongNumber<1>{});
#endif
return TensorDescriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{
transforms, element_space_size, real_size};
remove_cv_t<decltype(element_space_size)>>{transforms,
element_space_size};
}
// Lengths... can be:
// 1) index_t, which is known at run-time
// Lengths... could be:
// 1) index_t, which is known at run-time, or
// 2) Number<>, which is known at compile-time
// element_space_size could be:
// 1) long_index_t, or
// 2) LongNumber<>
template <typename... Lengths>
__host__ __device__ constexpr auto
make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
......@@ -106,19 +107,22 @@ make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
const auto real_size =
container_reduce(lengths, math::multiplies{}, integral_constant<std::size_t, 1ul>{});
const auto element_space_size = container_reduce(lengths, math::multiplies{}, Number<1>{});
const auto element_space_size = container_reduce(lengths, math::multiplies{}, LongNumber<1>{});
return TensorDescriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{
transforms, element_space_size, real_size};
remove_cv_t<decltype(element_space_size)>>{transforms,
element_space_size};
}
// Lengths... could be:
// 1) index_t, which is known at run-time, or
// 2) Number<>, which is known at compile-time
// align could be:
// 1) index_t, or
// 2) Number<>
template <typename... Lengths, typename Align>
__host__ __device__ constexpr auto
make_naive_tensor_descriptor_aligned(const Tuple<Lengths...>& lengths, Align align)
......@@ -155,4 +159,3 @@ make_naive_tensor_descriptor_aligned(const Tuple<Lengths...>& lengths, Align ali
}
} // namespace ck
#endif
......@@ -639,11 +639,12 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
d_grid_desc_mblock_mperblock_{},
compute_base_ptr_of_batch_{a_grid_desc_ak0_m_ak1_.GetElementSpaceSize(),
b_grid_desc_bk0_n_bk1_.GetElementSpaceSize(),
c_grid_desc_m_n_.GetElementSpaceSize(),
d_grid_desc_m_.GetElementSpaceSize(),
d_grid_desc_m_.GetElementSpaceSize()},
compute_base_ptr_of_batch_{
type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()),
type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize()),
type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize()),
type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize())},
block_2_ctile_map_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
......
......@@ -384,9 +384,10 @@ struct DeviceBatchedGemmXdl
DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
compute_ptr_offset_of_batch_{a_grid_desc_k0_m_k1_.GetElementSpaceSize(),
b_grid_desc_k0_n_k1_.GetElementSpaceSize(),
c_grid_desc_m_n_.GetElementSpaceSize()},
compute_ptr_offset_of_batch_{
type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
......
......@@ -862,17 +862,11 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static bool IsSupportedArgument(const Argument& arg)
{
// Input tensors can't be bigger than 2GB each.
constexpr std::size_t GB2 = 1e9;
constexpr ck::long_index_t GB2 = (ck::long_index_t{1} << 31);
if(arg.a_grid_desc_k0_m_k1_.GetRealSize() > GB2)
{
return false;
}
if(arg.b_grid_desc_k0_n_k1_.GetRealSize() > GB2)
{
return false;
}
if(arg.c_grid_desc_m_n_.GetRealSize() > GB2)
if(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) > GB2 ||
arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) > GB2 ||
arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) > GB2)
{
return false;
}
......
......@@ -8,5 +8,8 @@ namespace ck {
template <index_t N>
using Number = integral_constant<index_t, N>;
template <index_t N>
using LongNumber = integral_constant<long_index_t, N>;
} // namespace ck
#endif
......@@ -158,5 +158,11 @@ __host__ __device__ constexpr auto make_static_buffer(Number<N>)
return StaticBuffer<AddressSpace, T, N, true>{};
}
template <AddressSpaceEnum AddressSpace, typename T, long_index_t N>
__host__ __device__ constexpr auto make_static_buffer(LongNumber<N>)
{
return StaticBuffer<AddressSpace, T, N, true>{};
}
} // namespace ck
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment