Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
4f0b87dc
Commit
4f0b87dc
authored
Dec 04, 2023
by
Artur Wojcik
Browse files
Merge branch 'uif2-initial' into uif2-migraphx
parents
696f0839
bc5b84b1
Changes
66
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
307 additions
and
34 deletions
+307
-34
client_example/17_grouped_gemm_fastgelu/CMakeLists.txt
client_example/17_grouped_gemm_fastgelu/CMakeLists.txt
+1
-1
client_example/18_groupnorm/CMakeLists.txt
client_example/18_groupnorm/CMakeLists.txt
+1
-1
client_example/19_pool/CMakeLists.txt
client_example/19_pool/CMakeLists.txt
+4
-4
client_example/20_splitk_gemm/CMakeLists.txt
client_example/20_splitk_gemm/CMakeLists.txt
+1
-1
client_example/21_grouped_gemm_bias/CMakeLists.txt
client_example/21_grouped_gemm_bias/CMakeLists.txt
+1
-1
client_example/22_grouped_gemm/CMakeLists.txt
client_example/22_grouped_gemm/CMakeLists.txt
+3
-3
client_example/22_im2col_col2im/CMakeLists.txt
client_example/22_im2col_col2im/CMakeLists.txt
+2
-2
client_example/23_elementwise_transpose/CMakeLists.txt
client_example/23_elementwise_transpose/CMakeLists.txt
+1
-1
client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/CMakeLists.txt
..._grouped_convnd_fwd_scaleadd_scaleadd_relu/CMakeLists.txt
+4
-4
client_example/24_grouped_convnd_fwd_scaleadd_ab/CMakeLists.txt
..._example/24_grouped_convnd_fwd_scaleadd_ab/CMakeLists.txt
+4
-4
client_example/CMakeLists.txt
client_example/CMakeLists.txt
+1
-1
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
+2
-2
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp
+2
-2
example/27_layernorm2d_fwd/run_layernorm_example.inc
example/27_layernorm2d_fwd/run_layernorm_example.inc
+2
-2
example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc
example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc
+2
-2
example/44_elementwise_permute/CMakeLists.txt
example/44_elementwise_permute/CMakeLists.txt
+3
-1
example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc
example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc
+2
-2
example/64_tensor_transforms/CMakeLists.txt
example/64_tensor_transforms/CMakeLists.txt
+2
-0
example/64_tensor_transforms/tensor_transform.cpp
example/64_tensor_transforms/tensor_transform.cpp
+150
-0
example/64_tensor_transforms/tensor_transform_using_wrapper.cpp
...e/64_tensor_transforms/tensor_transform_using_wrapper.cpp
+119
-0
No files found.
client_example/17_grouped_gemm_fastgelu/CMakeLists.txt
View file @
4f0b87dc
add_executable
(
client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp
)
target_link_libraries
(
client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_operations
)
\ No newline at end of file
target_link_libraries
(
client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations
)
\ No newline at end of file
client_example/18_groupnorm/CMakeLists.txt
View file @
4f0b87dc
add_executable
(
client_groupnorm_swish groupnorm_swish.cpp
)
target_link_libraries
(
client_groupnorm_swish PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_groupnorm_swish PRIVATE composable_kernel::device_
other_
operations
)
client_example/19_pool/CMakeLists.txt
View file @
4f0b87dc
add_executable
(
client_max_pool2d_fwd max_pool2d_fwd.cpp
)
target_link_libraries
(
client_max_pool2d_fwd PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_max_pool2d_fwd PRIVATE composable_kernel::device_
other_
operations
)
add_executable
(
client_max_pool2d_bwd max_pool2d_bwd.cpp
)
target_link_libraries
(
client_max_pool2d_bwd PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_max_pool2d_bwd PRIVATE composable_kernel::device_
other_
operations
)
add_executable
(
client_avg_pool3d_fwd avg_pool3d_fwd.cpp
)
target_link_libraries
(
client_avg_pool3d_fwd PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_avg_pool3d_fwd PRIVATE composable_kernel::device_
other_
operations
)
add_executable
(
client_avg_pool3d_bwd avg_pool3d_bwd.cpp
)
target_link_libraries
(
client_avg_pool3d_bwd PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_avg_pool3d_bwd PRIVATE composable_kernel::device_
other_
operations
)
client_example/20_splitk_gemm/CMakeLists.txt
View file @
4f0b87dc
if
((
DTYPES MATCHES
"fp8"
AND DTYPES MATCHES
"fp16"
)
OR NOT DEFINED DTYPES
)
add_executable
(
client_splitK_gemm splitK_gemm_fp16_f8.cpp
)
target_link_libraries
(
client_splitK_gemm PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_splitK_gemm PRIVATE composable_kernel::device_
gemm_
operations
)
endif
()
client_example/21_grouped_gemm_bias/CMakeLists.txt
View file @
4f0b87dc
add_executable
(
client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp
)
target_link_libraries
(
client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_
gemm_
operations
)
client_example/22_grouped_gemm/CMakeLists.txt
View file @
4f0b87dc
add_executable
(
client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp
)
target_link_libraries
(
client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_
gemm_
operations
)
add_executable
(
client_grouped_gemm_fixed_nk_fp8 grouped_gemm_fixed_nk_fp8.cpp
)
target_link_libraries
(
client_grouped_gemm_fixed_nk_fp8 PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_grouped_gemm_fixed_nk_fp8 PRIVATE composable_kernel::device_
gemm_
operations
)
add_executable
(
client_grouped_gemm_fixed_nk_i8 grouped_gemm_fixed_nk_i8.cpp
)
target_link_libraries
(
client_grouped_gemm_fixed_nk_i8 PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_grouped_gemm_fixed_nk_i8 PRIVATE composable_kernel::device_
gemm_
operations
)
client_example/22_im2col_col2im/CMakeLists.txt
View file @
4f0b87dc
add_executable
(
client_image_to_column image_to_column.cpp
)
target_link_libraries
(
client_image_to_column PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_image_to_column PRIVATE composable_kernel::device_
other_
operations
)
add_executable
(
client_column_to_image column_to_image.cpp
)
target_link_libraries
(
client_column_to_image PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_column_to_image PRIVATE composable_kernel::device_
other_
operations
)
client_example/23_elementwise_transpose/CMakeLists.txt
View file @
4f0b87dc
add_executable
(
client_elementwise_transpose3d elementwise_transpose_3d.cpp
)
target_link_libraries
(
client_elementwise_transpose3d PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_elementwise_transpose3d PRIVATE composable_kernel::device_
other_
operations
)
client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/CMakeLists.txt
View file @
4f0b87dc
add_executable
(
client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32 grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp
)
target_link_libraries
(
client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32 PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32 PRIVATE composable_kernel::device_
conv_
operations
)
add_executable
(
client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp16 grouped_conv_fwd_scaleadd_scaleadd_relu_fp16.cpp
)
target_link_libraries
(
client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp16 PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp16 PRIVATE composable_kernel::device_
conv_
operations
)
add_executable
(
client_grouped_convnd_fwd_scaleadd_scaleadd_relu_bf16 grouped_conv_fwd_scaleadd_scaleadd_relu_bf16.cpp
)
target_link_libraries
(
client_grouped_convnd_fwd_scaleadd_scaleadd_relu_bf16 PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_grouped_convnd_fwd_scaleadd_scaleadd_relu_bf16 PRIVATE composable_kernel::device_
conv_
operations
)
add_executable
(
client_grouped_convnd_fwd_scaleadd_scaleadd_relu_int8 grouped_conv_fwd_scaleadd_scaleadd_relu_int8.cpp
)
target_link_libraries
(
client_grouped_convnd_fwd_scaleadd_scaleadd_relu_int8 PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_grouped_convnd_fwd_scaleadd_scaleadd_relu_int8 PRIVATE composable_kernel::device_
conv_
operations
)
client_example/24_grouped_convnd_fwd_scaleadd_ab/CMakeLists.txt
View file @
4f0b87dc
add_executable
(
client_grouped_convnd_fwd_scaleadd_ab_fp32 grouped_conv_fwd_scaleadd_ab_fp32.cpp
)
target_link_libraries
(
client_grouped_convnd_fwd_scaleadd_ab_fp32 PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_grouped_convnd_fwd_scaleadd_ab_fp32 PRIVATE composable_kernel::device_
conv_
operations
)
add_executable
(
client_grouped_convnd_fwd_scaleadd_ab_fp16 grouped_conv_fwd_scaleadd_ab_fp16.cpp
)
target_link_libraries
(
client_grouped_convnd_fwd_scaleadd_ab_fp16 PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_grouped_convnd_fwd_scaleadd_ab_fp16 PRIVATE composable_kernel::device_
conv_
operations
)
add_executable
(
client_grouped_convnd_fwd_scaleadd_ab_bf16 grouped_conv_fwd_scaleadd_ab_bf16.cpp
)
target_link_libraries
(
client_grouped_convnd_fwd_scaleadd_ab_bf16 PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_grouped_convnd_fwd_scaleadd_ab_bf16 PRIVATE composable_kernel::device_
conv_
operations
)
add_executable
(
client_grouped_convnd_fwd_scaleadd_ab_int8 grouped_conv_fwd_scaleadd_ab_int8.cpp
)
target_link_libraries
(
client_grouped_convnd_fwd_scaleadd_ab_int8 PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_grouped_convnd_fwd_scaleadd_ab_int8 PRIVATE composable_kernel::device_
conv_
operations
)
client_example/CMakeLists.txt
View file @
4f0b87dc
...
...
@@ -47,7 +47,7 @@ else()
endif
()
endif
()
find_package
(
composable_kernel COMPONENTS device_operations
)
find_package
(
composable_kernel COMPONENTS device_
other_operations device_gemm_operations device_conv_operations device_contraction_operations device_reduction_
operations
)
find_package
(
hip REQUIRED PATHS /opt/rocm /opt/rocm/llvm $ENV{HIP_PATH} $ENV{ROCM_PATH}
)
message
(
STATUS
"Build with HIP
${
hip_VERSION
}
"
)
...
...
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
View file @
4f0b87dc
...
...
@@ -299,8 +299,8 @@ int main(int argc, char* argv[])
for
(
int
i
=
0
;
i
<
problem_size
.
group_count
;
i
++
)
{
problem_size
.
Ms
.
push_back
(
256
+
256
*
i
);
problem_size
.
Ns
.
push_back
(
128
+
128
*
i
);
problem_size
.
Ks
.
push_back
(
128
+
64
*
i
);
problem_size
.
Ns
.
push_back
(
256
);
problem_size
.
Ks
.
push_back
(
128
);
problem_size
.
stride_As
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_Bs
.
push_back
(
problem_size
.
Ks
[
i
]);
...
...
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp
View file @
4f0b87dc
...
...
@@ -300,8 +300,8 @@ int main(int argc, char* argv[])
for
(
int
i
=
0
;
i
<
problem_size
.
group_count
;
i
++
)
{
problem_size
.
Ms
.
push_back
(
256
+
256
*
i
);
problem_size
.
Ns
.
push_back
(
128
+
128
*
i
);
problem_size
.
Ks
.
push_back
(
128
+
64
*
i
);
problem_size
.
Ns
.
push_back
(
256
);
problem_size
.
Ks
.
push_back
(
128
);
problem_size
.
stride_As
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_Bs
.
push_back
(
problem_size
.
Ks
[
i
]);
...
...
example/27_layernorm2d_fwd/run_layernorm_example.inc
View file @
4f0b87dc
example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc
View file @
4f0b87dc
example/44_elementwise_permute/CMakeLists.txt
View file @
4f0b87dc
...
...
@@ -5,4 +5,6 @@ add_example_executable(example_elementwise_permute_4D_fp16_row elementwise_permu
add_example_executable
(
example_elementwise_permute_4D_fp32_col elementwise_permute_4D_fp32_col.cpp
)
add_example_executable
(
example_elementwise_permute_4D_fp16_col elementwise_permute_4D_fp16_col.cpp
)
add_example_executable
(
example_elementwise_permute elementwise_permute.cpp
)
add_example_executable
(
example_elementwise_permute_3d elementwise_permute_3d.cpp
)
if
((
NOT GPU_TARGETS MATCHES
"gfx940"
)
AND
(
NOT GPU_TARGETS MATCHES
"gfx941"
)
AND
(
NOT GPU_TARGETS MATCHES
"gfx942"
))
add_example_executable
(
example_elementwise_permute_3d elementwise_permute_3d.cpp
)
endif
()
example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc
View file @
4f0b87dc
example/64_tensor_transforms/CMakeLists.txt
0 → 100644
View file @
4f0b87dc
add_example_executable
(
example_tensor_transform tensor_transform.cpp
)
add_example_executable
(
example_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp
)
example/64_tensor_transforms/tensor_transform.cpp
0 → 100644
View file @
4f0b87dc
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "ck/ck.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
static
constexpr
auto
I0
=
ck
::
Number
<
0
>
{};
static
constexpr
auto
I1
=
ck
::
Number
<
1
>
{};
static
constexpr
auto
I2
=
ck
::
Number
<
2
>
{};
using
DataType
=
int
;
template
<
typename
Desc
>
void
Print1d
(
const
Desc
&
desc
)
{
std
::
cout
<<
"Print1d"
<<
std
::
endl
;
for
(
ck
::
index_t
w
=
0
;
w
<
desc
.
GetLength
(
I0
);
w
++
)
{
std
::
cout
<<
desc
.
CalculateOffset
(
ck
::
make_tuple
(
w
))
<<
" "
;
}
std
::
cout
<<
std
::
endl
;
}
template
<
typename
Desc
>
void
Print2d
(
const
Desc
&
desc
)
{
std
::
cout
<<
"Print2d"
<<
std
::
endl
;
for
(
ck
::
index_t
h
=
0
;
h
<
desc
.
GetLength
(
I0
);
h
++
)
{
for
(
ck
::
index_t
w
=
0
;
w
<
desc
.
GetLength
(
I1
);
w
++
)
{
std
::
cout
<<
desc
.
CalculateOffset
(
ck
::
make_tuple
(
h
,
w
))
<<
" "
;
}
std
::
cout
<<
std
::
endl
;
}
}
template
<
typename
Desc
>
void
Print3dCustom
(
const
Desc
&
desc
)
{
std
::
cout
<<
"Print3dCustom"
<<
std
::
endl
;
for
(
ck
::
index_t
d
=
0
;
d
<
desc
.
GetLength
(
I0
);
d
++
)
{
for
(
ck
::
index_t
h
=
0
;
h
<
desc
.
GetLength
(
I1
);
h
++
)
{
for
(
ck
::
index_t
w
=
0
;
w
<
desc
.
GetLength
(
I2
);
w
++
)
{
std
::
cout
<<
desc
.
CalculateOffset
(
ck
::
make_tuple
(
d
,
h
,
w
))
<<
" "
;
}
std
::
cout
<<
std
::
endl
;
}
std
::
cout
<<
std
::
endl
;
}
}
int
main
()
{
// Tensor descriptor traverse in row-major (need to reverse dims)
std
::
cout
<<
"Note: Tensor descriptor traverse in row-major"
<<
std
::
endl
;
// Basic descriptor 0, 1, 2, ... 30, 31
// (dims:4,8 strides:1,4)
const
auto
desc_4x8_s1x4
=
ck
::
make_naive_tensor_descriptor
(
ck
::
make_tuple
(
ck
::
Number
<
4
>
{},
ck
::
Number
<
8
>
{}),
ck
::
make_tuple
(
ck
::
Number
<
1
>
{},
ck
::
Number
<
4
>
{}));
std
::
cout
<<
"dims:4,8 strides:1,4"
<<
std
::
endl
;
Print2d
(
desc_4x8_s1x4
);
using
Cord1x1Type
=
ck
::
Tuple
<
ck
::
Number
<
1
>
,
ck
::
Number
<
1
>>
;
constexpr
ck
::
index_t
offset_1x1
=
desc_4x8_s1x4
.
CalculateOffset
(
Cord1x1Type
{});
std
::
cout
<<
"Constexpr calculated [1, 1] offset:"
<<
offset_1x1
<<
std
::
endl
;
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:4,(2,4) strides:2,(1,8)
const
auto
desc_4x2x4_s2x1x8
=
ck
::
make_naive_tensor_descriptor
(
ck
::
make_tuple
(
4
,
2
,
4
),
ck
::
make_tuple
(
2
,
1
,
8
));
// Transform to 2d (column-major, need to to reverse dims)
const
auto
desc_4x2x4_s2x1x8_merged
=
ck
::
transform_tensor_descriptor
(
desc_4x2x4_s2x1x8
,
ck
::
make_tuple
(
ck
::
make_pass_through_transform
(
4
),
ck
::
make_merge_transform
(
ck
::
make_tuple
(
4
,
2
))),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
2
,
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
std
::
cout
<<
"dims:4,(2,4) strides:2,(1,8)"
<<
std
::
endl
;
Print2d
(
desc_4x2x4_s2x1x8_merged
);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:(2,2),(2,4) strides:((1,4),(2,8)
const
auto
desc_2x2x2x4_s1x4x2x8
=
ck
::
make_naive_tensor_descriptor
(
ck
::
make_tuple
(
2
,
2
,
2
,
4
),
ck
::
make_tuple
(
1
,
4
,
2
,
8
));
// Transform to 2d
const
auto
desc_2x2x2x4_s1x4x2x8_double_merged_2d
=
ck
::
transform_tensor_descriptor
(
desc_2x2x2x4_s1x4x2x8
,
ck
::
make_tuple
(
ck
::
make_merge_transform
(
ck
::
make_tuple
(
2
,
2
)),
ck
::
make_merge_transform
(
ck
::
make_tuple
(
4
,
2
))),
ck
::
make_tuple
(
ck
::
Sequence
<
1
,
0
>
{},
ck
::
Sequence
<
3
,
2
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
// Transform to 3d
const
auto
desc_2x2x2x4_s1x4x2x8_double_merged_3d
=
ck
::
transform_tensor_descriptor
(
desc_2x2x2x4_s1x4x2x8
,
ck
::
make_tuple
(
ck
::
make_pass_through_transform
(
2
),
ck
::
make_pass_through_transform
(
2
),
ck
::
make_merge_transform
(
ck
::
make_tuple
(
4
,
2
))),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
3
,
2
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
2
>
{}));
std
::
cout
<<
"dims:(2,2),(2,4) strides:(1,4),(2,8)"
<<
std
::
endl
;
Print2d
(
desc_2x2x2x4_s1x4x2x8_double_merged_2d
);
Print3dCustom
(
desc_2x2x2x4_s1x4x2x8_double_merged_3d
);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:((2,2),2),4 strides:((1,4),2),8
// Transform to 2d
const
auto
desc_2x2x2x4_s1x4x2x8_nested
=
ck
::
make_naive_tensor_descriptor
(
ck
::
make_tuple
(
2
,
2
,
2
,
4
),
ck
::
make_tuple
(
1
,
4
,
2
,
8
));
const
auto
desc_2x2x2x4_s1x4x2x8_nested_merged_3d
=
ck
::
transform_tensor_descriptor
(
desc_2x2x2x4_s1x4x2x8_nested
,
ck
::
make_tuple
(
ck
::
make_merge_transform
(
ck
::
make_tuple
(
2
,
2
)),
ck
::
make_pass_through_transform
(
2
),
ck
::
make_pass_through_transform
(
4
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
,
0
>
{},
ck
::
Sequence
<
2
>
{},
ck
::
Sequence
<
3
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
2
>
{}));
const
auto
desc_2x2x2x4_s1x4x2x8_nested_merged_1d
=
ck
::
transform_tensor_descriptor
(
desc_2x2x2x4_s1x4x2x8_nested
,
ck
::
make_tuple
(
ck
::
make_merge_transform
(
ck
::
make_tuple
(
4
,
2
,
2
,
2
))),
ck
::
make_tuple
(
ck
::
Sequence
<
3
,
2
,
1
,
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{}));
const
auto
desc_2x2x2x4_s1x4x2x8_nested_merged_2d
=
ck
::
transform_tensor_descriptor
(
desc_2x2x2x4_s1x4x2x8_nested_merged_3d
,
ck
::
make_tuple
(
ck
::
make_merge_transform
(
ck
::
make_tuple
(
2
,
4
)),
ck
::
make_pass_through_transform
(
4
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
,
0
>
{},
ck
::
Sequence
<
2
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
std
::
cout
<<
"dims:((2,2),2),4 strides:((1,4),2),8"
<<
std
::
endl
;
Print1d
(
desc_2x2x2x4_s1x4x2x8_nested_merged_1d
);
Print2d
(
desc_2x2x2x4_s1x4x2x8_nested_merged_2d
);
Print3dCustom
(
desc_2x2x2x4_s1x4x2x8_nested_merged_3d
);
return
0
;
}
example/64_tensor_transforms/tensor_transform_using_wrapper.cpp
0 → 100644
View file @
4f0b87dc
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "ck/ck.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/sequence.hpp"
#include "tensor_transform_wrapper.hpp"
using
DataType
=
int
;
template
<
typename
Layout
>
void
Print1d
(
const
Layout
&
layout
)
{
std
::
cout
<<
"Print1d"
<<
std
::
endl
;
for
(
ck
::
index_t
w
=
0
;
w
<
ck
::
tensor_transform_wrapper
::
size
(
layout
);
w
++
)
{
std
::
cout
<<
layout
(
ck
::
make_tuple
(
w
))
<<
" "
;
}
std
::
cout
<<
std
::
endl
;
}
template
<
typename
Layout
>
void
Print2d
(
const
Layout
&
layout
)
{
std
::
cout
<<
"Print2d"
<<
std
::
endl
;
for
(
ck
::
index_t
h
=
0
;
h
<
ck
::
tensor_transform_wrapper
::
size
<
0
>
(
layout
);
h
++
)
{
for
(
ck
::
index_t
w
=
0
;
w
<
ck
::
tensor_transform_wrapper
::
size
<
1
>
(
layout
);
w
++
)
{
std
::
cout
<<
layout
(
ck
::
make_tuple
(
h
,
w
))
<<
" "
;
}
std
::
cout
<<
std
::
endl
;
}
}
// Print in (x,y),z pattern
template
<
typename
Layout
>
void
Print3dCustom
(
const
Layout
&
layout
)
{
std
::
cout
<<
"Print3dCustom"
<<
std
::
endl
;
for
(
ck
::
index_t
d
=
0
;
d
<
ck
::
tensor_transform_wrapper
::
size
<
0
>
(
ck
::
tensor_transform_wrapper
::
get
<
0
>
(
layout
));
d
++
)
{
for
(
ck
::
index_t
h
=
0
;
h
<
ck
::
tensor_transform_wrapper
::
size
<
1
>
(
ck
::
tensor_transform_wrapper
::
get
<
0
>
(
layout
));
h
++
)
{
for
(
ck
::
index_t
w
=
0
;
w
<
ck
::
tensor_transform_wrapper
::
size
<
1
>
(
layout
);
w
++
)
{
std
::
cout
<<
layout
(
ck
::
make_tuple
(
ck
::
make_tuple
(
d
,
h
),
w
))
<<
" "
;
}
std
::
cout
<<
std
::
endl
;
}
std
::
cout
<<
std
::
endl
;
}
}
int
main
()
{
// Layout traverse in row-major
std
::
cout
<<
"Note: Layout traverse in column-major"
<<
std
::
endl
;
// Basic descriptor 0, 1, 2, ... 30, 31 (compile-time descriptor)
// (dims:4,8 strides:1,4)
const
auto
shape_4x8
=
ck
::
make_tuple
(
ck
::
Number
<
4
>
{},
ck
::
Number
<
8
>
{});
const
auto
layout_4x8_s1x4
=
ck
::
tensor_transform_wrapper
::
make_layout
(
shape_4x8
);
std
::
cout
<<
"dims:4,8 strides:1,4"
<<
std
::
endl
;
Print2d
(
layout_4x8_s1x4
);
using
Cord1x1Type
=
ck
::
Tuple
<
ck
::
Number
<
1
>
,
ck
::
Number
<
1
>>
;
constexpr
ck
::
index_t
offset_1x1
=
layout_4x8_s1x4
.
template
operator
()
<
Cord1x1Type
>();
std
::
cout
<<
"Constexpr calculated [1, 1] offset:"
<<
offset_1x1
<<
std
::
endl
;
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (runtime descriptor)
// dims:4,(2,4) strides:2,(1,8)
const
auto
shape_4x2x4
=
ck
::
make_tuple
(
4
,
ck
::
make_tuple
(
2
,
4
));
const
auto
strides_s2x1x8
=
ck
::
make_tuple
(
2
,
ck
::
make_tuple
(
1
,
8
));
const
auto
layout_4x2x4_s2x1x8
=
ck
::
tensor_transform_wrapper
::
make_layout
(
shape_4x2x4
,
strides_s2x1x8
);
std
::
cout
<<
"dims:4,(2,4) strides:2,(1,8)"
<<
std
::
endl
;
Print2d
(
layout_4x2x4_s2x1x8
);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:(2,2),(2,4) strides:((1,4),(2,8)
const
auto
shape_2x2x2x4
=
ck
::
make_tuple
(
ck
::
make_tuple
(
ck
::
Number
<
2
>
{},
ck
::
Number
<
2
>
{}),
ck
::
make_tuple
(
ck
::
Number
<
2
>
{},
ck
::
Number
<
4
>
{}));
const
auto
strides_s1x4x2x8
=
ck
::
make_tuple
(
ck
::
make_tuple
(
ck
::
Number
<
1
>
{},
ck
::
Number
<
4
>
{}),
ck
::
make_tuple
(
ck
::
Number
<
2
>
{},
ck
::
Number
<
8
>
{}));
static
const
auto
layout_2x2x2x4_s1x4x2x8
=
ck
::
tensor_transform_wrapper
::
make_layout
(
shape_2x2x2x4
,
strides_s1x4x2x8
);
std
::
cout
<<
"dims:(2,2),(2,4) strides:(1,4),(2,8)"
<<
std
::
endl
;
Print2d
(
layout_2x2x2x4_s1x4x2x8
);
Print3dCustom
(
layout_2x2x2x4_s1x4x2x8
);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:((2,2),2),4 strides:((1,4),2),8
// Transform to 2d
const
auto
shape_2x2x2x4_nested
=
ck
::
make_tuple
(
ck
::
make_tuple
(
ck
::
make_tuple
(
ck
::
Number
<
2
>
{},
ck
::
Number
<
2
>
{}),
ck
::
Number
<
2
>
{}),
ck
::
Number
<
4
>
{});
const
auto
strides_s1x4x2x8_nested
=
ck
::
make_tuple
(
ck
::
make_tuple
(
ck
::
make_tuple
(
ck
::
Number
<
1
>
{},
ck
::
Number
<
4
>
{}),
ck
::
Number
<
2
>
{}),
ck
::
Number
<
8
>
{});
static
const
auto
layout_2x2x2x4_s1x4x2x8_nested
=
ck
::
tensor_transform_wrapper
::
make_layout
(
shape_2x2x2x4_nested
,
strides_s1x4x2x8_nested
);
std
::
cout
<<
"dims:((2,2),2),4 strides:((1,4),2),8"
<<
std
::
endl
;
Print1d
(
layout_2x2x2x4_s1x4x2x8_nested
);
Print2d
(
layout_2x2x2x4_s1x4x2x8_nested
);
Print3dCustom
(
layout_2x2x2x4_s1x4x2x8_nested
);
return
0
;
}
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment