Commit 0dc0af2e authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Fix comments

parent 53fdf365
......@@ -19,6 +19,7 @@ None
- Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804)
- Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799)
- Support for Batched Gemm DL (#732)
- Introduce wrapper sublibrary (limited functionality) (#1071)
### Changes
- Changed the grouped convolution API to maintain consistency with other convolution kernels (#817)
......
add_executable(client_tensor_transform tensor_transform.cpp)
target_link_libraries(client_tensor_transform PRIVATE composable_kernel::device_other_operations)
add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp)
target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations)
......@@ -34,7 +34,7 @@ Current CK library are structured into 4 layers:
* "Templated Tile Operators" layer
* "Templated Kernel and Invoker" layer
* "Instantiated Kernel and Invoker" layer
* "Wrapper for tensor tranforms operations"
* "Wrapper for tensor transform operations"
* "Client API" layer
.. image:: data/ck_layer.png
......
......@@ -5,8 +5,11 @@ Wrapper
-------------------------------------
Description
-------------------------------------
Note: Wrapper is currently under development. At the moment, its functionality
is limited.
.. note::
The wrapper is under development and its functionality is limited.
CK provides a lightweight wrapper for more complex operations implemented in
the library. It allows indexing of nested layouts using a simple interface
......@@ -33,7 +36,6 @@ Example:
Output::
dims:4,(2,4) strides:2,(1,8)
Print2d
0 1 8 9 16 17 24 25
2 3 10 11 18 19 26 27
4 5 12 13 20 21 28 29
......
......@@ -109,7 +109,7 @@ struct Layout
// Iterate over shape tuple elements:
// 1. If corresponding idx element is tuple then return (will be unrolled)
// 2. If no, pack in tuple. It will be restored during unroll.
auto unrolled_shape_via_idx = generate_tuple(
auto aligned_shape = generate_tuple(
[&](auto i) {
if constexpr(is_detected<is_tuple,
tuple_element_t<i, Tuple<IdxDims...>>>::value)
......@@ -124,7 +124,7 @@ struct Layout
Number<Tuple<IdxDims...>::Size()>{});
// Unroll and process next step
return AlignShapeToIdx(UnrollNestedTuple<0, 1>(unrolled_shape_via_idx),
return AlignShapeToIdx(UnrollNestedTuple<0, 1>(aligned_shape),
UnrollNestedTuple<0, 1>(idx));
}
}
......@@ -144,7 +144,7 @@ struct Layout
desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
}
// Merge nested shape dims
// Merge nested shape dims. Merge nested shape dims when idx is also nested.
// Input desc shape: 2, 2, 2, 2, 2, 2
// Example idx: 1, 1, 1, 1
// Example shape: 2, (2, 2), 2, (2, 2)
......@@ -205,10 +205,9 @@ struct Layout
static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
"Idx rank and Shape rank must be the same (except 1d).");
// Unroll while IdxDims is nested
const auto unrolled_shape_via_idx = AlignShapeToIdx(shape, idx);
const auto aligned_shape = AlignShapeToIdx(shape, idx);
// Transform correct form of shape
return CreateMergedDescriptor(
unrolled_shape_via_idx, UnrollNestedTuple(idx), descriptor_);
return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), descriptor_);
}
}
......@@ -224,7 +223,7 @@ struct Layout
}
public:
// If stride not passed, deduce from GenerateColumnMajorPackedStrides
// If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`.
using DeducedStrides =
std::conditional_t<is_same_v<Strides, Tuple<>>,
remove_cvref_t<decltype(GenerateColumnMajorPackedStrides(Shape{}))>,
......
......@@ -356,7 +356,7 @@ TEST(TestLayoutHelpers, SizeAndGet)
EXPECT_EQ(ck::wrapper::size<1>(layout_runtime), d1 * d0);
EXPECT_EQ(ck::wrapper::size<1>(layout_compiletime), d1 * d0);
// Acces via new layout (using get on layout)
// Access through new layout (using get with layout object)
EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(layout_runtime)), d4 * d3);
EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(layout_compiletime)), d4 * d3);
EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_runtime)), d2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment