gaoqiong / composable_kernel_ROCM / Commits / 91949d6e

Commit 91949d6e, authored Dec 06, 2023 by Artur Wojcik

    Merge branch 'develop' into uif2-initial

Parents: bb819abc, 836b7e55

Changes: 21. Showing 20 changed files with 501 additions and 174 deletions (+501, -174).
.github/dependabot.yml                                                   +6    -0
.gitignore                                                               +0    -1
.readthedocs.yaml                                                        +5    -5
CHANGELOG.md                                                             +1    -0
client_example/25_tensor_transforms/CMakeLists.txt                       +4    -0
client_example/25_tensor_transforms/tensor_transform.cpp                 +0    -0
client_example/25_tensor_transforms/tensor_transform_using_wrapper.cpp   +13   -18
docs/conf.py                                                             +19   -8
docs/doxygen/Doxyfile                                                    +4    -2
docs/index.rst                                                           +2    -0
docs/sphinx/_toc.yml.in                                                  +3    -3
docs/sphinx/requirements.in                                              +1    -1
docs/sphinx/requirements.txt                                             +2    -4
docs/wrapper.rst                                                         +54   -0
example/64_tensor_transforms/CMakeLists.txt                              +0    -2
include/ck/utility/tuple_helper.hpp                                      +12   -0
include/ck/wrapper/layout.hpp                                            +51   -130
include/ck/wrapper/layout_utils.hpp                                      +321  -0
test/CMakeLists.txt                                                      +1    -0
test/wrapper/CMakeLists.txt                                              +2    -0
.github/dependabot.yml

@@ -10,3 +10,9 @@ updates:
     open-pull-requests-limit: 10
     schedule:
       interval: "daily"
+    labels:
+      - "documentation"
+      - "dependencies"
+      - "ci:docs-only"
+    reviewers:
+      - "samjwu"
.gitignore

@@ -54,7 +54,6 @@ _images/
 _static/
 _templates/
 _toc.yml
-docBin/
 _doxygen/

 # JetBrains IDE
.readthedocs.yaml

@@ -3,11 +3,6 @@
 version: 2

-build:
-  os: ubuntu-22.04
-  tools:
-    python: "3.8"
-
 sphinx:
   configuration: docs/conf.py
@@ -16,3 +11,8 @@ formats: [htmlzip, pdf, epub]
 python:
   install:
     - requirements: docs/sphinx/requirements.txt
+
+build:
+  os: ubuntu-22.04
+  tools:
+    python: "3.8"
CHANGELOG.md

@@ -19,6 +19,7 @@ None
 - Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804)
 - Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799)
 - Support for Batched Gemm DL (#732)
+- Introduce wrapper sublibrary (limited functionality) (#1071)

 ### Changes
 - Changed the grouped convolution API to maintain consistency with other convolution kernels (#817)
client_example/25_tensor_transforms/CMakeLists.txt (new file, 0 → 100644)

+add_executable(client_tensor_transform tensor_transform.cpp)
+target_link_libraries(client_tensor_transform PRIVATE composable_kernel::device_other_operations)
+add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp)
+target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations)
example/64_tensor_transforms/tensor_transform.cpp → client_example/25_tensor_transforms/tensor_transform.cpp

File moved.
example/64_tensor_transforms/tensor_transform_using_wrapper.cpp → client_example/25_tensor_transforms/tensor_transform_using_wrapper.cpp

@@ -9,7 +9,7 @@
 #include "ck/utility/tuple.hpp"
 #include "ck/utility/sequence.hpp"
-#include "tensor_transform_wrapper.hpp"
+#include "ck/wrapper/layout.hpp"

 using DataType = int;
@@ -17,7 +17,7 @@ template <typename Layout>
 void Print1d(const Layout& layout)
 {
     std::cout << "Print1d" << std::endl;
-    for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size(layout); w++)
+    for(ck::index_t w = 0; w < ck::wrapper::size(layout); w++)
     {
         std::cout << layout(ck::make_tuple(w)) << " ";
     }
@@ -28,9 +28,9 @@ template <typename Layout>
 void Print2d(const Layout& layout)
 {
     std::cout << "Print2d" << std::endl;
-    for(ck::index_t h = 0; h < ck::tensor_transform_wrapper::size<0>(layout); h++)
+    for(ck::index_t h = 0; h < ck::wrapper::size<0>(layout); h++)
     {
-        for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size<1>(layout); w++)
+        for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++)
         {
             std::cout << layout(ck::make_tuple(h, w)) << " ";
         }
@@ -43,15 +43,11 @@ template <typename Layout>
 void Print3dCustom(const Layout& layout)
 {
     std::cout << "Print3dCustom" << std::endl;
-    for(ck::index_t d = 0;
-        d < ck::tensor_transform_wrapper::size<0>(ck::tensor_transform_wrapper::get<0>(layout));
-        d++)
+    for(ck::index_t d = 0; d < ck::wrapper::size<0>(ck::wrapper::get<0>(layout)); d++)
     {
-        for(ck::index_t h = 0;
-            h < ck::tensor_transform_wrapper::size<1>(ck::tensor_transform_wrapper::get<0>(layout));
-            h++)
+        for(ck::index_t h = 0; h < ck::wrapper::size<1>(ck::wrapper::get<0>(layout)); h++)
         {
-            for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size<1>(layout); w++)
+            for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++)
             {
                 std::cout << layout(ck::make_tuple(ck::make_tuple(d, h), w)) << " ";
             }
@@ -68,7 +64,7 @@ int main()
     // Basic descriptor 0, 1, 2, ... 30, 31 (compile-time descriptor)
     // (dims:4,8 strides:1,4)
     const auto shape_4x8 = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{});
-    const auto layout_4x8_s1x4 = ck::tensor_transform_wrapper::make_layout(shape_4x8);
+    const auto layout_4x8_s1x4 = ck::wrapper::make_layout(shape_4x8);
     std::cout << "dims:4,8 strides:1,4" << std::endl;
     Print2d(layout_4x8_s1x4);
     using Cord1x1Type = ck::Tuple<ck::Number<1>, ck::Number<1>>;
@@ -79,8 +75,7 @@ int main()
     // dims:4,(2,4) strides:2,(1,8)
     const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4));
     const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8));
-    const auto layout_4x2x4_s2x1x8 =
-        ck::tensor_transform_wrapper::make_layout(shape_4x2x4, strides_s2x1x8);
+    const auto layout_4x2x4_s2x1x8 = ck::wrapper::make_layout(shape_4x2x4, strides_s2x1x8);
     std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl;
     Print2d(layout_4x2x4_s2x1x8);
@@ -92,7 +87,7 @@ int main()
     const auto strides_s1x4x2x8 = ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}),
                                                  ck::make_tuple(ck::Number<2>{}, ck::Number<8>{}));
     static const auto layout_2x2x2x4_s1x4x2x8 =
-        ck::tensor_transform_wrapper::make_layout(shape_2x2x2x4, strides_s1x4x2x8);
+        ck::wrapper::make_layout(shape_2x2x2x4, strides_s1x4x2x8);
     std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl;
     Print2d(layout_2x2x2x4_s1x4x2x8);
@@ -108,7 +103,7 @@ int main()
         ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}), ck::Number<2>{}),
         ck::Number<8>{});
     static const auto layout_2x2x2x4_s1x4x2x8_nested =
-        ck::tensor_transform_wrapper::make_layout(shape_2x2x2x4_nested, strides_s1x4x2x8_nested);
+        ck::wrapper::make_layout(shape_2x2x2x4_nested, strides_s1x4x2x8_nested);
     std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl;
     Print1d(layout_2x2x2x4_s1x4x2x8_nested);
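For reference: when no strides are passed (the shape_4x8 case above), make_layout falls back to packed column-major strides, i.e. stride[0] = 1 and each later stride is the product of the preceding dims. A minimal standalone sketch of that rule (plain C++, no CK dependencies; column_major_packed_strides is an illustrative name, not a CK function):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Packed column-major strides: stride[0] = 1, stride[i] = shape[0] * ... * shape[i-1].
    std::vector<int> column_major_packed_strides(const std::vector<int>& shape)
    {
        std::vector<int> strides(shape.size(), 1);
        for(std::size_t i = 1; i < shape.size(); ++i)
        {
            strides[i] = strides[i - 1] * shape[i - 1];
        }
        return strides;
    }

    int main()
    {
        for(int s : column_major_packed_strides({4, 8}))
        {
            std::cout << s << " "; // prints "1 4", matching "dims:4,8 strides:1,4" above
        }
    }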
docs/conf.py

@@ -4,23 +4,34 @@
 # list see the documentation:
 # https://www.sphinx-doc.org/en/master/usage/configuration.html

-import subprocess
+import re

 from rocm_docs import ROCmDocs

-name = "Composable Kernel"
-get_version = r'sed -n -e "s/^rocm_setup_version(.* \([0-9\.]\{1,\}\).*/\1/p" ../CMakeLists.txt'
-version = subprocess.getoutput(get_version)
-if len(version) > 0:
-    name = f"{name} {version}"
+html_theme_options = {"flavor": "list"}
+
+with open('../CMakeLists.txt', encoding='utf-8') as f:
+    match = re.search(r'.*set\(version ([0-9.]+)[^0-9.]+', f.read())
+if not match:
+    raise ValueError("VERSION not found!")
+version_number = match[1]
+left_nav_title = f"Composable Kernel {version_number} Documentation"
+
+# for PDF output on Read the Docs
+project = "Composable Kernel Documentation"
+author = "Advanced Micro Devices, Inc."
+copyright = "Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved."
+version = version_number
+release = version_number

 external_toc_path = "./sphinx/_toc.yml"

-docs_core = ROCmDocs(f"{name} Documentation")
-docs_core.run_doxygen(doxygen_root="doxygen", doxygen_path="doxygen/docBin/xml")
+docs_core = ROCmDocs(left_nav_title)
+docs_core.run_doxygen(doxygen_root="doxygen", doxygen_path="doxygen/xml")
 docs_core.setup()

+external_projects_current_project = "composable_kernel"
+
 mathjax3_config = {
     'tex': {
         'macros': {
docs/doxygen/Doxyfile

@@ -58,7 +58,7 @@ PROJECT_LOGO =
 # entered, it will be relative to the location where doxygen was started. If
 # left blank the current directory will be used.

-OUTPUT_DIRECTORY = docBin
+OUTPUT_DIRECTORY = .

 # If the CREATE_SUBDIRS tag is set to YES then doxygen will create 4096 sub-
 # directories (in 2 levels) under the output directory of each output format and
@@ -778,7 +778,9 @@ WARN_LOGFILE =
 INPUT = ../../include/ck/tensor_operation/gpu/grid \
         ../../include/ck/tensor_operation/gpu/block \
         ../../include/ck/tensor_operation/gpu/thread \
-        ../../library/include/ck/library/utility
+        ../../library/include/ck/library/utility \
+        ../../include/ck/wrapper

 # This tag can be used to specify the character encoding of the source files
 # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses
docs/index.rst

@@ -34,6 +34,7 @@ Current CK library are structured into 4 layers:
 * "Templated Tile Operators" layer
 * "Templated Kernel and Invoker" layer
 * "Instantiated Kernel and Invoker" layer
+* "Wrapper for tensor transform operations"
 * "Client API" layer

 .. image:: data/ck_layer.png
@@ -50,6 +51,7 @@ The following is a list of CK documents in the suggested reading order:
    tutorial_hello_world
    dockerhub
+   wrapper
    Supported_Primitives_Guide
    API_Reference_Guide
    Contributors_Guide
docs/sphinx/_toc.yml.in

@@ -5,6 +5,6 @@ defaults:
   maxdepth: 6
 root: index
 subtrees:
 - caption: About
   entries:
   - file: license

(+3, -3; no visible difference with whitespace changes hidden.)
docs/sphinx/requirements.in

-rocm-docs-core>=0.20.0
+rocm-docs-core==0.29.0
 sphinxcontrib-bibtex==2.6.1
docs/sphinx/requirements.txt

@@ -96,9 +96,7 @@ pygments==2.14.0
     #   pydata-sphinx-theme
     #   sphinx
 pyjwt[crypto]==2.6.0
-    # via
-    #   pygithub
-    #   pyjwt
+    # via pygithub
 pynacl==1.5.0
     # via pygithub
 pytz==2023.3.post1
@@ -113,7 +111,7 @@ requests==2.28.2
     # via
     #   pygithub
     #   sphinx
-rocm-docs-core==0.27.0
+rocm-docs-core==0.29.0
     # via -r requirements.in
 six==1.16.0
     # via
docs/wrapper.rst (new file, 0 → 100644)

===============
Wrapper
===============

-------------------------------------
Description
-------------------------------------

.. note::
    The wrapper is under development and its functionality is limited.

CK provides a lightweight wrapper for more complex operations implemented in
the library. It allows indexing of nested layouts using a simple interface
(avoiding complex descriptor transformations).

Example:

.. code-block:: c

    const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4));
    const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8));
    const auto layout = ck::wrapper::make_layout(shape_4x2x4, strides_s2x1x8);

    std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl;
    for(ck::index_t h = 0; h < ck::wrapper::size<0>(layout); h++)
    {
        for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++)
        {
            std::cout << layout(ck::make_tuple(h, w)) << " ";
        }
        std::cout << std::endl;
    }

Output::

    dims:4,(2,4) strides:2,(1,8)
    0 1 8 9 16 17 24 25
    2 3 10 11 18 19 26 27
    4 5 12 13 20 21 28 29
    6 7 14 15 22 23 30 31

-------------------------------------
Layout
-------------------------------------

.. doxygenstruct:: ck::wrapper::Layout

-------------------------------------
Layout helpers
-------------------------------------

.. doxygenfile:: layout_utils.hpp
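As a sanity check on the output above: with dims 4,(2,4) and strides 2,(1,8), the merged second index w walks the nested group column-major, splitting into w0 = w % 2 (stride 1) and w1 = w / 2 (stride 8), so offset = h*2 + w0*1 + w1*8. A standalone sketch (plain C++, no CK dependencies) that reproduces the documented table from that formula:

    #include <iostream>

    int main()
    {
        // dims:4,(2,4) strides:2,(1,8); offset = h*2 + (w % 2)*1 + (w / 2)*8
        for(int h = 0; h < 4; h++)
        {
            for(int w = 0; w < 8; w++)
            {
                std::cout << h * 2 + (w % 2) * 1 + (w / 2) * 8 << " ";
            }
            std::cout << std::endl;
        }
    }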
example/64_tensor_transforms/CMakeLists.txt (deleted, 100644 → 0)

-add_example_executable(example_tensor_transform tensor_transform.cpp)
-add_example_executable(example_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp)
include/ck/utility/tuple_helper.hpp

@@ -166,4 +166,16 @@ __host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
     return (is_detected<is_tuple, Ts>::value || ...);
 }

+template <index_t depth = 0, typename T>
+__host__ __device__ constexpr auto TupleDepth(const T&)
+{
+    return depth;
+}
+
+template <index_t depth = 0, typename... Ts>
+__host__ __device__ constexpr auto TupleDepth(const Tuple<Ts...>&)
+{
+    return math::max(TupleDepth<depth + 1>(Ts{})...);
+}
+
 } // namespace ck
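The new TupleDepth uses overload recursion: a scalar reports the depth accumulated so far, and a tuple recurses one level deeper, taking the maximum over its elements. A standalone sketch of the same idea on std::tuple (illustrative only, not CK code; like the CK version, it assumes element types are default-constructible):

    #include <algorithm>
    #include <tuple>

    template <int Depth = 0, typename T>
    constexpr int tuple_depth(const T&);
    template <int Depth = 0, typename... Ts>
    constexpr int tuple_depth(const std::tuple<Ts...>&);

    // Scalar: the nesting level at which recursion stopped.
    template <int Depth, typename T>
    constexpr int tuple_depth(const T&) { return Depth; }

    // Tuple: one level deeper, then the maximum over all elements.
    template <int Depth, typename... Ts>
    constexpr int tuple_depth(const std::tuple<Ts...>&)
    {
        return std::max({Depth + 1, tuple_depth<Depth + 1>(Ts{})...});
    }

    static_assert(tuple_depth(0) == 0);
    static_assert(tuple_depth(std::tuple<int, int>{}) == 1);
    static_assert(tuple_depth(std::tuple<std::tuple<int, int>, int>{}) == 2);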
example/64_tensor_transforms/tensor_transform_wrapper.hpp → include/ck/wrapper/layout.hpp

@@ -3,27 +3,13 @@
 #pragma once

-#include "ck/ck.hpp"
-#include "ck/utility/number.hpp"
-#include "ck/utility/tuple.hpp"
-#include "ck/utility/tuple_helper.hpp"
-#include "ck/utility/sequence.hpp"
-#include "ck/utility/sequence_helper.hpp"
-#include "ck/utility/is_detected.hpp"
-#include "ck/tensor_description/tensor_descriptor.hpp"
-#include "ck/tensor_description/tensor_descriptor_helper.hpp"
-#include "ck/tensor_description/multi_index_transform_helper.hpp"
+#include "ck/wrapper/layout_utils.hpp"

 namespace ck {
-namespace tensor_transform_wrapper {
+namespace wrapper {

 /**
- * \brief Layout wrapper
- *
- * \details
- * Layout wrapper that performs the tensor descriptor logic.
+ * \brief Layout wrapper that performs the tensor descriptor logic.
  *
  * \tparam Shape Tuple of Number<> (for compile-time layout) or index_t
  *         (dynamic layout). It is possible to pass nested shapes
@@ -32,21 +18,19 @@
  * (dynamic layout). Stride tuple should be nested if shape tuple is
  * nested.
  */
-template <typename Shape, typename Strides = Tuple<>>
+template <typename Shape, typename Strides>
 struct Layout
 {
     private:
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};

-    template <typename T>
-    using is_tuple = decltype(std::declval<T&>().IsTuple());
-
     // Generate packed (column-major) strides if not passed
     template <typename... Ts>
     __host__ __device__ constexpr static auto
-    GenerateColumnMajorPackedStrides(const Tuple<Ts...>& tuple)
+    GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
     {
+        const auto unrolled_shape = UnrollNestedTuple(shape);
         return generate_tuple(
             [&](auto i) {
                 if constexpr(i.value == 0)
@@ -56,10 +40,10 @@ struct Layout
                 else
                 {
                     return TupleReduce<I0.value, i.value>([](auto x, auto y) { return x * y; },
-                                                          tuple);
+                                                          unrolled_shape);
                 }
             },
-            Number<Tuple<Ts...>::Size()>{});
+            Number<decltype(unrolled_shape)::Size()>{});
     }

     // Generate LowerDims in Compile-time for MergeTransform using passed Type
@@ -112,7 +96,7 @@ struct Layout
     // Example shape: (2, (2, 2)), 2, (2, 2)
     // Unrolled shape: 2, (2, 2), 2, (2, 2)
     template <typename... ShapeDims, typename... IdxDims>
-    __host__ __device__ constexpr static auto UnrollShapeViaIdx(const Tuple<ShapeDims...>& shape,
-                                                                const Tuple<IdxDims...>& idx)
+    __host__ __device__ constexpr static auto AlignShapeToIdx(const Tuple<ShapeDims...>& shape,
+                                                              const Tuple<IdxDims...>& idx)
     {
         if constexpr(!IsNestedTuple(Tuple<IdxDims...>{}))
@@ -125,7 +109,7 @@ struct Layout
         // Iterate over shape tuple elements:
         // 1. If corresponding idx element is tuple then return (will be unrolled)
         // 2. If no, pack in tuple. It will be restored during unroll.
-        auto unrolled_shape_via_idx = generate_tuple(
+        auto aligned_shape = generate_tuple(
             [&](auto i) {
                 if constexpr(is_detected<is_tuple,
                                          tuple_element_t<i, Tuple<IdxDims...>>>::value)
@@ -140,7 +124,7 @@ struct Layout
             Number<Tuple<IdxDims...>::Size()>{});
         // Unroll and process next step
-        return UnrollShapeViaIdx(UnrollNestedTuple<0, 1>(unrolled_shape_via_idx),
-                                 UnrollNestedTuple<0, 1>(idx));
+        return AlignShapeToIdx(UnrollNestedTuple<0, 1>(aligned_shape),
+                               UnrollNestedTuple<0, 1>(idx));
     }
 }
@@ -150,12 +134,9 @@ struct Layout
                                  DescriptorToMerge& desc)
     {
         // Reverse each element in tuple
-        using ReversedUnrolledShape = decltype(TupleReverse(UnrollNestedTuple(shape)));
-        const auto merge_elems      = ReversedUnrolledShape{};
+        const auto merge_elems = TupleReverse(UnrollNestedTuple(shape));
         // Generate reverted indexes (column major traverse)
         using MergeElemsSequence =
-            typename arithmetic_sequence_gen<0, ReversedUnrolledShape::Size(), 1>::type;
+            typename arithmetic_sequence_gen<0, merge_elems.Size(), 1>::type;
         const auto lower_dims = make_tuple(MergeElemsSequence::Reverse());
         const auto upper_dims = make_tuple(Sequence<0>{});
         // Merge to 1d
@@ -163,14 +144,14 @@ struct Layout
             desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
     }

-    // Merge nested shape dims
+    // Merge nested shape dims when idx is also nested.
     // Input desc shape: 2, 2, 2, 2, 2, 2
     // Example idx: 1, 1, 1, 1
     // Example shape: 2, (2, 2), 2, (2, 2)
     // Merged shape: 2, 4, 2, 4
     template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge>
     __host__ __device__ constexpr static auto
-    MakeMerges(
+    CreateMergedDescriptor(
         const Tuple<ShapeDims...>& shape, const Tuple<IdxDims...>&, DescriptorToMerge& desc)
     {
         const auto transforms = generate_tuple(
             [&](auto i) {
@@ -224,9 +205,9 @@ struct Layout
         static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
                       "Idx rank and Shape rank must be the same (except 1d).");
         // Unroll while IdxDims is nested
-        const auto unrolled_shape_via_idx = UnrollShapeViaIdx(shape, idx);
+        const auto aligned_shape = AlignShapeToIdx(shape, idx);
         // Transform correct form of shape
-        return MakeMerges(unrolled_shape_via_idx, UnrollNestedTuple(idx), descriptor_);
+        return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), descriptor_);
     }
 }
@@ -235,25 +216,20 @@ struct Layout
                                  const LayoutStrides& strides)
     {
         const auto unrolled_shape = UnrollNestedTuple(shape);
-        if constexpr(ck::is_same_v<LayoutStrides, Tuple<>>)
-        {
-            // If shape is packed
-            const auto column_major_packed_strides =
-                GenerateColumnMajorPackedStrides(unrolled_shape);
-            return make_naive_tensor_descriptor(unrolled_shape, column_major_packed_strides);
-        }
-        else
-        {
-            const auto unrolled_strides = UnrollNestedTuple(strides);
-            static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
-                          "Size of strides and shape are not consistent.");
-            return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
-        }
+        const auto unrolled_strides = UnrollNestedTuple(strides);
+        static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
+                      "Size of strides and shape are not consistent.");
+        return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
     }

     public:
-    using NaiveDescriptorType = remove_cvref_t<decltype(MakeNaiveDescriptor(Shape{}, Strides{}))>;
+    // If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`.
+    using DeducedStrides =
+        std::conditional_t<is_same_v<Strides, Tuple<>>,
+                           remove_cvref_t<decltype(GenerateColumnMajorPackedStrides(Shape{}))>,
+                           Strides>;
+    using NaiveDescriptorType =
+        remove_cvref_t<decltype(MakeNaiveDescriptor(Shape{}, DeducedStrides{}))>;

     /**
      * \brief Layout constructor.
@@ -268,9 +244,9 @@ struct Layout
         // Construct if runtime mode
         if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime())
         {
-            // Keep only shape, strides are not need for transforms
             shape_      = shape;
-            descriptor_ = MakeNaiveDescriptor(shape, strides);
+            strides_    = strides;
+            descriptor_ = MakeNaiveDescriptor(shape_, strides_);
         }
     }
@@ -279,7 +255,8 @@ struct Layout
         if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime())
         {
             shape_      = shape;
-            descriptor_ = MakeNaiveDescriptor(shape, Strides{});
+            strides_    = GenerateColumnMajorPackedStrides(shape_);
+            descriptor_ = MakeNaiveDescriptor(shape_, strides_);
         }
     }
@@ -338,7 +315,7 @@ struct Layout
      *
      * \return Calculated size.
      */
-    __host__ __device__ constexpr index_t GetLength() const
+    __host__ __device__ constexpr index_t GetLengths() const
     {
         const auto unrolled_shape = UnrollNestedTuple(shape_);
         return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
@@ -346,80 +323,24 @@ struct Layout
     }

     /**
-     * \brief Dimension getter.
+     * \brief Shape getter.
      *
-     * \tparam IDim Dimension idx.
-     * \return Calculated size.
+     * \return Shape.
      */
-    template <index_t IDim>
-    __host__ __device__ constexpr auto Get() const
-    {
-        const auto elem = shape_.At(Number<IDim>{});
-        return elem;
-    }
+    __host__ __device__ constexpr Shape GetShape() const { return shape_; }
+
+    /**
+     * \brief Strides getter.
+     *
+     * \return Strides.
+     */
+    __host__ __device__ constexpr DeducedStrides GetStrides() const { return strides_; }

     private:
     NaiveDescriptorType descriptor_;
     Shape shape_;
+    DeducedStrides strides_;
 };

-// Layout helpers
-// Length getter (product if tuple)
-template <index_t idx, typename Shape, typename Strides>
-__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
-{
-    return layout.template GetLength<idx>();
-}
-
-// Get shape size (product of dims if tuple)
-template <typename... ShapeDims>
-__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
-{
-    using UnrolledShape = decltype(UnrollNestedTuple(shape));
-    return TupleReduce<0, UnrolledShape::Size()>([](auto x, auto y) { return x * y; },
-                                                 UnrolledShape{});
-}
-
-// Get dim size (could be returned from get function)
-template <typename T>
-__host__ __device__ T constexpr size(const T& dim)
-{
-    return dim;
-}
-
-// Get layout size (product of shapes)
-template <typename Shape, typename Strides>
-__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
-{
-    return layout.GetLength();
-}
-
-// Get shape element size
-template <index_t idx, typename... ShapeDims>
-__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
-{
-    return size(shape.At(Number<idx>{}));
-}
-
-// Dim getter (tuple if tuple)
-template <index_t idx, typename Shape, typename Strides>
-__host__ __device__ constexpr auto get(const Layout<Shape, Strides>& layout)
-{
-    return layout.template Get<idx>();
-}
-
-template <typename Shape, typename Strides>
-__host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& shape,
-                                                                 const Strides& strides)
-{
-    return Layout<Shape, Strides>(shape, strides);
-}
-
-template <typename Shape>
-__host__ __device__ constexpr Layout<Shape> make_layout(const Shape& shape)
-{
-    return Layout<Shape>(shape);
-}
-
-} // namespace tensor_transform_wrapper
+} // namespace wrapper
 } // namespace ck
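The merge step documented above ("Example shape: 2, (2, 2), 2, (2, 2) / Merged shape: 2, 4, 2, 4") collapses each nested group of the shape into a single length equal to the product of its entries. A standalone sketch of that rule on std::tuple (illustrative only, not the CK implementation):

    #include <tuple>

    // Scalar dimensions pass through unchanged.
    constexpr int merged_length(int dim) { return dim; }

    // A nested group (d0, d1, ...) merges into a single length d0 * d1 * ...
    template <typename... Ts>
    constexpr int merged_length(const std::tuple<Ts...>& group)
    {
        return std::apply([](auto... dims) { return (merged_length(dims) * ... * 1); }, group);
    }

    // Matches the comment above: 2 stays 2, a (2, 2) group merges to 4.
    static_assert(merged_length(2) == 2);
    static_assert(merged_length(std::tuple{2, 2}) == 4);
    static_assert(merged_length(std::tuple{2, std::tuple{2, 2}}) == 8);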
include/ck/wrapper/layout_utils.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/sequence_helper.hpp"
#include "ck/utility/is_detected.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"

namespace ck {
namespace wrapper {

// Disable from doxygen docs generation
/// @cond
// forward declaration
template <typename Shape, typename Strides = Tuple<>>
struct Layout;

template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
/// @endcond

// make_*
/**
 * \brief Make layout function.
 *
 * \tparam Shape Shape for layout.
 * \tparam Strides Strides for layout.
 * \return Constructed layout.
 */
template <typename Shape, typename Strides>
__host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& shape,
                                                                 const Strides& strides)
{
    return Layout<Shape, Strides>(shape, strides);
}

/**
 * \brief Make layout function with packed strides (column-major).
 *
 * \tparam Shape Shape for layout.
 * \return Constructed layout.
 */
template <typename Shape>
__host__ __device__ constexpr Layout<Shape> make_layout(const Shape& shape)
{
    return Layout<Shape>(shape);
}

// Layout helpers
// get
/**
 * \brief Get element from tuple (Shape/Strides/Idxs).
 *
 * \tparam idx Index to lookup.
 * \param tuple Tuple to lookup.
 * \return Requested element.
 */
template <index_t idx, typename... Dims>
__host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
{
    return tuple.At(Number<idx>{});
}

/**
 * \brief Get sub layout.
 *
 * \tparam idx Index to lookup.
 * \param layout Layout to create sub layout.
 * \return Requested sub layout.
 */
template <index_t idx, typename Shape, typename Strides>
__host__ __device__ constexpr auto get(const Layout<Shape, Strides>& layout)
{
    const auto new_shape = get<idx>(layout.GetShape());
    static_assert(is_detected<is_tuple, decltype(new_shape)>::value,
                  "Shape of sub layout must be tuple");
    if constexpr(is_same_v<Strides, Tuple<>>)
    {
        // If stride not passed, create without strides
        return make_layout(new_shape);
    }
    else
    {
        const auto new_strides = get<idx>(layout.GetStrides());
        static_assert(is_detected<is_tuple, decltype(new_strides)>::value,
                      "Strides of sub layout must be tuple");
        return make_layout(new_shape, new_strides);
    }
}

/**
 * \brief Hierarchical get.
 *
 * \tparam Idxs Indexes to lookup.
 * \param elem Element to lookup.
 * \return Requested element.
 */
template <index_t Idx, index_t... Idxs, typename T>
__host__ __device__ constexpr auto get(const T& elem)
{
    return get<Idxs...>(get<Idx>(elem));
}

// size
/**
 * \brief Length get (product if tuple).
 *
 * \tparam idx Index to lookup.
 * \param layout Layout to get Shape.
 * \return Requested length.
 */
template <index_t idx, typename Shape, typename Strides>
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
{
    return layout.template GetLength<idx>();
}

/**
 * \brief Shape size (product of dims).
 *
 * \param shape Shape to lookup.
 * \return Requested size.
 */
template <typename... ShapeDims>
__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
{
    const auto unrolled_shape = UnrollNestedTuple(shape);
    return TupleReduce<0, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
                                                 unrolled_shape);
}

// Get dim size (could be returned from get function)
/**
 * \private
 */
template <typename T>
__host__ __device__ T constexpr size(const T& dim)
{
    return dim;
}

/**
 * \brief Layout size (product of dims).
 *
 * \param layout Layout to calculate shape size.
 * \return Requested size.
 */
template <typename Shape, typename Strides>
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
{
    return layout.GetLengths();
}

/**
 * \brief Length get from tuple (product if tuple).
 *
 * \tparam idx Index to lookup.
 * \param tuple Tuple to lookup.
 * \return Requested length.
 */
template <index_t idx, typename... Ts>
__host__ __device__ constexpr index_t size(const Tuple<Ts...>& tuple)
{
    return size(tuple.At(Number<idx>{}));
}

/**
 * \brief Hierarchical size.
 *
 * \tparam Idxs Indexes to lookup.
 * \param elem Element to lookup.
 * \return Requested element.
 */
template <index_t... Idxs, typename T>
__host__ __device__ constexpr auto size(const T& elem)
{
    return size(get<Idxs...>(elem));
}

// rank
/**
 * \brief Get layout rank (num elements in shape).
 *
 * \param layout Layout to calculate rank.
 * \return Requested rank.
 */
template <typename Shape, typename Strides>
__host__ __device__ constexpr auto rank([[maybe_unused]] const Layout<Shape, Strides>& layout)
{
    return Shape::Size();
}

/**
 * \brief Get tuple rank (num elements in tuple).
 * Return 1 if scalar passed.
 *
 * \param tuple Tuple to calculate rank.
 * \return Requested rank.
 */
template <typename... Dims>
__host__ __device__ constexpr auto rank([[maybe_unused]] const Tuple<Dims...>& tuple)
{
    return Tuple<Dims...>::Size();
}

/**
 * \private
 */
template <index_t IDim>
__host__ __device__ constexpr index_t rank(const Number<IDim>&)
{
    return 1;
}

/**
 * \private
 */
__host__ __device__ constexpr index_t rank(const index_t&)
{
    return 1;
}

/**
 * \brief Hierarchical rank.
 *
 * \tparam Idxs Indexes to lookup.
 * \param elem Element to lookup.
 * \return Requested rank.
 */
template <index_t... Idxs, typename T>
__host__ __device__ constexpr auto rank(const T& elem)
{
    return rank(get<Idxs...>(elem));
}

// depth
/**
 * \brief Get depth of the layout shape (return 0 if scalar).
 *
 * \param layout Layout to calculate depth.
 * \return Requested depth.
 */
template <typename Shape, typename Strides>
__host__ __device__ constexpr auto depth(const Layout<Shape, Strides>& layout)
{
    return TupleDepth(layout.GetShape());
}

/**
 * \brief Get depth of the tuple (return 0 if scalar).
 *
 * \param tuple Tuple to calculate depth.
 * \return Requested depth.
 */
template <typename... Dims>
__host__ __device__ constexpr auto depth(const Tuple<Dims...>& tuple)
{
    return TupleDepth(tuple);
}

/**
 * \private
 */
template <index_t IDim>
__host__ __device__ constexpr index_t depth(const Number<IDim>&)
{
    return 0;
}

/**
 * \private
 */
__host__ __device__ constexpr index_t depth(const index_t&)
{
    return 0;
}

/**
 * \brief Hierarchical depth.
 *
 * \tparam Idxs Indexes to lookup.
 * \param elem Element to lookup.
 * \return Requested depth.
 */
template <index_t... Idxs, typename T>
__host__ __device__ constexpr auto depth(const T& elem)
{
    return depth(get<Idxs...>(elem));
}

/**
 * \brief Get Layout strides.
 *
 * \param layout Layout to get strides.
 * \return Requested strides.
 */
template <typename Shape, typename Strides>
__host__ __device__ constexpr auto stride(const Layout<Shape, Strides>& layout)
{
    return layout.GetStrides();
}

/**
 * \brief Get Layout shape.
 *
 * \param layout Layout to get shape.
 * \return Requested shape.
 */
template <typename Shape, typename Strides>
__host__ __device__ constexpr auto shape(const Layout<Shape, Strides>& layout)
{
    return layout.GetShape();
}

} // namespace wrapper
} // namespace ck
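A minimal host-side usage sketch of these helpers, assuming the header is on the include path and using the dims 4,(2,4) / strides 2,(1,8) layout from the example file above. The expected values follow from the doc comments: rank counts top-level shape entries, depth counts nesting levels, size multiplies all leaf dims:

    #include "ck/wrapper/layout.hpp"

    int main()
    {
        const auto shape   = ck::make_tuple(4, ck::make_tuple(2, 4));
        const auto strides = ck::make_tuple(2, ck::make_tuple(1, 8));
        const auto layout  = ck::wrapper::make_layout(shape, strides);

        const auto r = ck::wrapper::rank(layout);  // 2: two top-level dimensions
        const auto d = ck::wrapper::depth(layout); // 2: one tuple nested inside the shape
        const auto s = ck::wrapper::size(layout);  // 32 = 4 * 2 * 4
        return (r == 2 && d == 2 && s == 32) ? 0 : 1;
    }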
test/CMakeLists.txt

@@ -148,6 +148,7 @@ add_subdirectory(batched_gemm_multi_d)
 add_subdirectory(grouped_convnd_bwd_data)
 add_subdirectory(conv_tensor_rearrange)
 add_subdirectory(transpose)
+add_subdirectory(wrapper)
 if(GPU_TARGETS MATCHES "gfx11")
     add_subdirectory(wmma_op)
 endif()
test/wrapper/CMakeLists.txt (new file, 0 → 100644)

+add_gtest_executable(test_layout test_layout.cpp)
+target_link_libraries(test_layout PRIVATE utility)