"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "820ec821f1e2932ed78df540eb90bdd0b8297e55"
Commit e1832bcb authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Fix for getLength

parent 81b79a77
...@@ -65,9 +65,9 @@ int main() ...@@ -65,9 +65,9 @@ int main()
{ {
// Layout traverse in row-major // Layout traverse in row-major
std::cout << "Note: Layout traverse in column-major" << std::endl; std::cout << "Note: Layout traverse in column-major" << std::endl;
// Basic descriptor 0, 1, 2, ... 30, 31 (runtime descriptor) // Basic descriptor 0, 1, 2, ... 30, 31 (compile-time descriptor)
// (dims:4,8 strides:1,4) // (dims:4,8 strides:1,4)
const auto shape_4x8 = ck::make_tuple(4, 8); const auto shape_4x8 = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{});
const auto layout_4x8_s1x4 = ck::tensor_transform_wrapper::make_layout(shape_4x8); const auto layout_4x8_s1x4 = ck::tensor_transform_wrapper::make_layout(shape_4x8);
std::cout << "dims:4,8 strides:1,4" << std::endl; std::cout << "dims:4,8 strides:1,4" << std::endl;
Print2d(layout_4x8_s1x4); Print2d(layout_4x8_s1x4);
...@@ -75,12 +75,10 @@ int main() ...@@ -75,12 +75,10 @@ int main()
constexpr ck::index_t offset_0x0 = layout_4x8_s1x4.template operator()<Cord0x0Type>(); constexpr ck::index_t offset_0x0 = layout_4x8_s1x4.template operator()<Cord0x0Type>();
std::cout << "Constexpr calculated [0, 0] offset:" << offset_0x0 << std::endl; std::cout << "Constexpr calculated [0, 0] offset:" << offset_0x0 << std::endl;
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor) // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (runtime descriptor)
// dims:4,(2,4) strides:2,(1,8) // dims:4,(2,4) strides:2,(1,8)
const auto shape_4x2x4 = const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4));
ck::make_tuple(ck::Number<4>{}, ck::make_tuple(ck::Number<2>{}, ck::Number<4>{})); const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8));
const auto strides_s2x1x8 =
ck::make_tuple(ck::Number<2>{}, ck::make_tuple(ck::Number<1>{}, ck::Number<8>{}));
const auto layout_4x2x4_s2x1x8 = const auto layout_4x2x4_s2x1x8 =
ck::tensor_transform_wrapper::make_layout(shape_4x2x4, strides_s2x1x8); ck::tensor_transform_wrapper::make_layout(shape_4x2x4, strides_s2x1x8);
......
...@@ -317,9 +317,9 @@ struct Layout ...@@ -317,9 +317,9 @@ struct Layout
const auto elem = shape_.At(Number<IDim>{}); const auto elem = shape_.At(Number<IDim>{});
if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value) if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
{ {
using UnrolledElement = decltype(UnrollNestedTuple(elem)); const auto unrolled_element = UnrollNestedTuple(elem);
return TupleReduce<I0.value, UnrolledElement::Size()>( return TupleReduce<I0.value, unrolled_element.Size()>(
[](auto x, auto y) { return x * y; }, UnrolledElement{}); [](auto x, auto y) { return x * y; }, unrolled_element);
} }
else else
{ {
...@@ -339,9 +339,9 @@ struct Layout ...@@ -339,9 +339,9 @@ struct Layout
const auto elem = shape_.At(Number<IDim>{}); const auto elem = shape_.At(Number<IDim>{});
if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value) if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
{ {
using UnrolledElement = decltype(UnrollNestedTuple(elem)); const auto unrolled_element = UnrollNestedTuple(elem);
return TupleReduce<I0.value, UnrolledElement::Size()>( return TupleReduce<I0.value, unrolled_element.Size()>(
[](auto x, auto y) { return x * y; }, UnrolledElement{}); [](auto x, auto y) { return x * y; }, unrolled_element);
} }
else else
{ {
...@@ -356,9 +356,9 @@ struct Layout ...@@ -356,9 +356,9 @@ struct Layout
*/ */
__host__ __device__ constexpr index_t GetLength() const __host__ __device__ constexpr index_t GetLength() const
{ {
using UnrolledShape = decltype(UnrollNestedTuple(shape_)); const auto unrolled_shape = UnrollNestedTuple(shape_);
return TupleReduce<I0.value, UnrolledShape::Size()>([](auto x, auto y) { return x * y; }, return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
UnrolledShape{}); unrolled_shape);
} }
/** /**
...@@ -368,9 +368,9 @@ struct Layout ...@@ -368,9 +368,9 @@ struct Layout
*/ */
__host__ __device__ constexpr index_t GetLength() __host__ __device__ constexpr index_t GetLength()
{ {
using UnrolledShape = decltype(UnrollNestedTuple(shape_)); const auto unrolled_shape = UnrollNestedTuple(shape_);
return TupleReduce<I0.value, UnrolledShape::Size()>([](auto x, auto y) { return x * y; }, return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
UnrolledShape{}); unrolled_shape);
} }
/** /**
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment