Commit 89140d16 authored by Chao Liu's avatar Chao Liu

tweaking

parent 58584c29
......@@ -122,7 +122,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
#if 0 // debug
constexpr auto out_n_k_hop_wop_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(
......@@ -142,6 +143,18 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
#else
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
Embed<Sequence<Ydot, Htilda>,
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>>{},
Embed<Sequence<Xdot, Wtilda>,
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
#endif
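// Note (illustrative only, assuming Embed<UpLengths, Coefficients> maps an upper index to
// the lower index coefficients[0]*idx_up[0] + ... + coefficients[n-1]*idx_up[n-1] + coefficients[n]):
// the transform above decomposes the output spatial dimensions as
//   ho = (-ConvDilationH / hcf_stride_dilation_h) * ydot + htilda
//   wo = (-ConvDilationW / hcf_stride_dilation_w) * xdot + wtilda
// so the (Ydot, Htilda, Xdot, Wtilda) view can be merged into the GEMM dimensions below.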
constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
......
......@@ -41,11 +41,18 @@ struct PassThrough
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
#if 0
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
{
return true;
}
#else
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
#endif
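// Note (not part of the original comments): unlike the old per-index query above, the new
// predicate asks a compile-time question about the transform as a whole (whether every valid
// upper index is guaranteed to map to a valid lower index), so callers can skip run-time
// range checks for transforms like PassThrough that always return true.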
};
// LowerLengths: Sequence<...>
......@@ -156,6 +163,7 @@ struct Merge
return idx_low;
}
#if 0
// idx_low_diff depends on idx_low_old, so idx_low needs to be up-to-date
// If idx_up_diff is known at compile-time, many calculations can be optimized
// away by the compiler
......@@ -239,6 +247,108 @@ struct Merge
return idx_low_new - idx_low_old;
}
#else
// idx_low_diff depends on idx_low_old, so idx_low needs to be up-to-date.
// If idx_up_diff is known at compile-time, many calculations can be optimized
// away by the compiler.
// This function assumes idx_low_old is not out of bounds.
__host__ __device__ static constexpr auto
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
const UpperIndex& /* idx_up_old */,
const LowerIndex& idx_low_old)
{
if(idx_up_diff[0] == 0)
{
return make_zero_array<index_t, nDimLow>();
}
else
{
// CalculateLowerIndex(idx_up_diff) has multiple integer divisions.
// If idx_up_diff is known at compile-time, the calculation can
// be done at compile-time. However, if idx_up_diff is only known
// at run-time, then the calculation will also be computed at
// run-time, and can be very expensive.
LowerIndex idx_low_diff_tmp = CalculateLowerIndex(idx_up_diff);
// find out the last low dimension that changed
index_t last_changed_low_dim = 0;
static_for<0, nDimLow, 1>{}([&](auto i) {
if(idx_low_diff_tmp[i] != 0)
{
last_changed_low_dim = i;
}
});
LowerIndex idx_low_new = idx_low_old + idx_low_diff_tmp;
if(idx_up_diff[0] > 0)
{
// do carry check on each low dimension in reversed order
// starting from the first digit that changed
// don't check the highest dimension
bool carry = false;
static_for<nDimLow - 1, 0, -1>{}([&](auto i) {
if(i <= last_changed_low_dim)
{
if(carry)
{
++idx_low_new(i);
}
carry = false;
if(idx_low_new[i] >= LowerLengths::At(i))
{
idx_low_new(i) -= LowerLengths::At(i);
carry = true;
}
}
});
// highest dimension, no out-of-bound check
if(carry)
{
++idx_low_new(0);
}
}
else
{
// do borrow check on each low dimension in reversed order
// starting from the first digit that changed
// don't check the highest dimension
bool borrow = false;
static_for<nDimLow - 1, 0, -1>{}([&](auto i) {
if(i <= last_changed_low_dim)
{
if(borrow)
{
--idx_low_new(i);
}
borrow = false;
if(idx_low_new[i] < 0)
{
idx_low_new(i) += LowerLengths::At(i);
borrow = true;
}
}
});
// highest dimension, no out-of-bound check
if(borrow)
{
--idx_low_new(0);
}
}
return idx_low_new - idx_low_old;
}
}
#endif
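// Worked example (illustrative only, assuming LowerLengths = Sequence<2, 3, 4>, so the
// merged upper length is 24 and dimension 2 is the least-significant digit):
//   idx_low_old = {0, 2, 3}, idx_up_diff = {+1}
//   CalculateLowerIndex(idx_up_diff) gives idx_low_diff_tmp = {0, 0, 1}
//   idx_low_new before the carry pass is {0, 2, 4}
//   carry pass: dim 2 wraps 4 -> 0 (carry), dim 1 wraps 3 -> 0 (carry), dim 0 becomes 1
//   returned lower-index diff = {1, 0, 0} - {0, 2, 3} = {1, -2, -3}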
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
......
......@@ -98,7 +98,29 @@ struct NativeTensorCoordinate
return tensor_desc_type::CalculateOffsetDiff(idx_diff);
}
__host__ __device__ static constexpr bool IsUpperIndexMappedToValidOffset() { return true; }
#if 0 // debug
__host__ __device__ static constexpr bool HasValidOffset() { return true; }
#else
// evaluated at run-time
__host__ __device__ constexpr bool IsUpperIndexValid() const
{
return tensor_desc_type::IsUpperIndexValid(GetUpperIndex());
}
// evaluated at run-time
__host__ __device__ constexpr bool IsOffsetValid() const
{
// For native tensor, offset is valid if upper-index is valid
return IsUpperIndexValid();
}
// evaluated at compile-time
__host__ __device__ static constexpr bool IsOffsetValidAssumingUpperIndexIsValid()
{
// For native tensor, offset is valid if upper-index is valid
return true;
}
#endif
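// Note (not part of the original comments): for a native (non-transformed) coordinate the
// three queries collapse: IsUpperIndexValid() does the only real work (a run-time range
// check against the descriptor lengths), IsOffsetValid() simply forwards to it, and
// IsOffsetValidAssumingUpperIndexIsValid() is trivially true, because an in-range index
// always produces an in-range offset for a native tensor.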
private:
// mIndex may be saved and updated, however, the value of some (or all) of its entries may
......@@ -143,6 +165,8 @@ struct TransformedTensorCoordinate
__host__ __device__ constexpr const UpperIndex& GetUpperIndex() const { return mIndexUp; }
__host__ __device__ constexpr const LowerIndex& GetLowerIndex() const { return mIndexLow.GetIndex(); }
__host__ __device__ constexpr const UpperIndex& GetIndex() const { return GetUpperIndex(); }
__host__ __device__ constexpr const index_t& GetOffset() const
......@@ -206,11 +230,38 @@ struct TransformedTensorCoordinate
return GetLowerCoordinate().CalculateOffsetDiff(idx_low_diff);
}
#if 0 // debug
__host__ __device__ constexpr bool IsUpperIndexMappedToValidOffset() const
{
return tensor_desc_type::IsUpperIndexMappedToValidLowerIndex(GetIndex()) &&
mCoordLow.IsUpperIndexMappedToValidOffset();
}
#else
// evaluated at run-time
__host__ __device__ constexpr bool IsUpperIndexValid() const
{
return tensor_desc_type::IsUpperIndexValid(GetUpperIndex());
}
// evaluated at run-time
__host__ __device__ constexpr bool IsOffsetValid() const
{
return IsUpperIndexValid() && GetLowerCoordinate().IsOffsetValid();
}
// most evaluation is done at compile-time
__host__ __device__ constexpr bool IsLowerIndexValidAssumingUpperIndexIsValid() const
{
return tensor_desc_type::IsLowerIndexValidAssumingUpperIndexIsValid(GetLowerIndex());
}
// most evaluation is done at compile-time
__host__ __device__ constexpr bool IsOffsetValidAssumingUpperIndexIsValid() const
{
return IsLowerIndexValidAssumingUpperIndexIsValid() &&
GetLowerCoordinate().IsOffsetValidAssumingUpperIndexIsValid();
}
#endif
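// Usage sketch (hypothetical caller, not part of this commit): a copy routine that already
// knows its upper index is in range can rely on the cheaper check, e.g.
//   if(coord.IsOffsetValidAssumingUpperIndexIsValid()) { p_dst[coord.GetOffset()] = v; }
// while a fully general caller would use IsOffsetValid(), which also performs the run-time
// range check of the upper index.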
private:
// mIndexUp may be calculated and updated, however, the value of some (or all) of its entries
......
......@@ -120,11 +120,30 @@ struct NativeTensorDescriptor
return Tuple<>{};
}
#if 0 // debug
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidOffset(const Index& /* idx */)
{
return true;
}
#else
__host__ __device__ static constexpr bool IsUpperIndexValid(const Index& idx)
{
bool flag = true;
for(index_t i = 0; i < nDim; ++i)
{
flag = flag && idx[i] >= 0 && idx[i] < GetLengths()[i];
}
return flag;
}
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidOffset(const Index& idx)
{
return IsUpperIndexValid(idx) && IsValidUpperIndexAlwaysMappedToValidOffset();
}
#endif
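// Example (illustrative only): for a native descriptor with lengths {8, 16},
// IsUpperIndexValid returns false for index (3, 20) because 20 >= 16, and true for (3, 15).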
};
// Tensor descriptor for "transformed tensor"
......@@ -467,6 +486,7 @@ struct TransformedTensorDescriptor
}
#endif
#if 0
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up)
{
......@@ -494,6 +514,48 @@ struct TransformedTensorDescriptor
GetLowerTensorDescriptor().IsUpperIndexMappedToValidOffset(
CalculateLowerIndex(idx_up));
}
#else
// evaluated at run-time
__host__ __device__ constexpr bool IsUpperIndexValid(const UpperIndex& idx_up) const
{
bool flag = true;
for(index_t i = 0; i < nDim; ++i)
{
flag = flag && idx_up[i] >= 0 && idx_up[i] < GetLengths()[i];
}
return flag;
}
// this function tells you: Is lower-index valid, assuming upper index is valid?
__host__ __device__ constexpr bool
IsLowerIndexValidAssumingUpperIndexIsValid(const LowerIndex& idx_low) const
{
bool flag = true;
static_for<0, nTransform, 1>{}([&](auto itran) {
constexpr auto tran = Transforms{}.At(itran);
// only range-check a transformation if it does not always map a valid upper index to a valid lower index
if(!tran.IsValidUpperIndexAlwaysMappedToValidLowerIndex())
{
const auto idx_low_part =
to_array(pick_array_element(idx_low, LowerDimensionIds{}.At(itran)));
constexpr auto lengths_low_part =
GetLowerTensorDescriptor().GetLengths()(LowerDimensionIds{}.At(itran));
for(index_t i = 0; i < LowerDimensionIds{}.At(itran).Size(); ++i)
{
flag = flag && idx_low_part[i] >= 0 && idx_low_part[i] < lengths_low_part[i];
}
}
});
return flag;
}
#endif
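// Note (assumption, not spelled out in this diff): the per-transform range check above only
// runs for transforms that can take a valid upper index out of bounds in the lower space,
// for example a padding transform whose pad region maps outside [0, length). Transforms such
// as PassThrough report IsValidUpperIndexAlwaysMappedToValidLowerIndex() == true and are
// skipped inside the static_for at compile time.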
};
} // namespace ck
......
......@@ -110,7 +110,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// Check whether the src data's mapping is valid; only the first element of this src
// vector is checked, since move_data handles the whole vector from one offset. It is the
// user's responsibility to make sure all elements of the src vector share the same
// valid/invalid mapping.
if(src_coord.IsUpperIndexMappedToValidOffset())
if(src_coord.HasValidOffset())
{
move_data<SrcData,
SrcDataPerRead,
......@@ -142,7 +142,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// Check whether the dst data's mapping is valid; only the first element of this dst
// vector is checked, since move_data handles the whole vector from one offset. It is the
// user's responsibility to make sure all elements of the dst vector share the same
// valid/invalid mapping.
if(dst_coord.IsUpperIndexMappedToValidOffset())
if(dst_coord.HasValidOffset())
{
move_data<DstData,
DstDataPerWrite,
......@@ -260,7 +260,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// src
// vector. It is the user's responsibility to make sure all elements of the src vector
// share the same valid/invalid mapping.
if(src_coord.IsUpperIndexMappedToValidOffset())
if(src_coord.HasValidOffset())
{
move_data<SrcData,
SrcDataPerRead,
......@@ -299,7 +299,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// dst
// vector. It is the user's responsibility to make sure all elements of the dst vector
// share the same valid/invalid mapping.
if(dst_coord.IsUpperIndexMappedToValidOffset())
if(dst_coord.HasValidOffset())
{
move_data<DstData,
DstDataPerWrite,
......@@ -399,7 +399,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// src
// vector. It is the user's responsibility to make sure all elements of the src vector
// share the same valid/invalid mapping.
if(src_coord.IsUpperIndexMappedToValidOffset())
if(src_coord.HasValidOffset())
{
move_data<SrcData,
SrcDataPerRead,
......@@ -444,7 +444,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// dst
// vector. It is the user's responsibility to make sure all elements of the dst vector
// share the same valid/invalid mapping.
if(dst_coord.IsUpperIndexMappedToValidOffset())
if(dst_coord.HasValidOffset())
{
move_data<DstData,
DstDataPerWrite,
......
......@@ -29,9 +29,10 @@ struct static_for
{
__host__ __device__ constexpr static_for()
{
static_assert(NBegin <= NEnd, "wrongs! should have NBegin <= NEnd");
static_assert((NEnd - NBegin) % Increment == 0,
static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0,
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd),
"Wrong! should satisfy (Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd)");
}
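// Illustrative note (assuming NEnd is exclusive, as in the descending static_for used
// elsewhere in this commit): with the relaxed assertions both loop directions are accepted, e.g.
//   static_for<0, 4, 1>{}([&](auto i) { /* i = 0, 1, 2, 3 */ });
//   static_for<3, -1, -1>{}([&](auto i) { /* i = 3, 2, 1, 0 */ });
// while static_for<0, 3, 2> is still rejected because (3 - 0) % 2 != 0.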
template <class F>
......
......@@ -21,7 +21,7 @@ int main(int argc, char* argv[])
{
using namespace ck;
#if 0
#if 1
// 1x1
constexpr index_t N = 256;
constexpr index_t C = 1024;
......@@ -36,7 +36,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
#elif 1
// 1x7
constexpr index_t N = 128;
constexpr index_t C = 1024;
......@@ -51,7 +51,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>;
#elif 1
#elif 0
// 3x3, 34x34
constexpr index_t N = 64;
constexpr index_t C = 256;
......@@ -64,21 +64,6 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 1
// 3x3, 34x34
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 34;
constexpr index_t WI = 34;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
......@@ -306,7 +291,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 1
#elif 0
// 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 128;
......