Commit e5ebcc41 authored by Artur Wojcik's avatar Artur Wojcik
Browse files

Merge branch 'develop' into uif2-migraphx

parents 57cdd70b abac8b07
rocm-docs-core>=0.20.0 rocm-docs-core==0.34.2
sphinxcontrib-bibtex==2.5.0 sphinxcontrib-bibtex==2.6.2
...@@ -16,7 +16,7 @@ beautifulsoup4==4.11.2 ...@@ -16,7 +16,7 @@ beautifulsoup4==4.11.2
# via pydata-sphinx-theme # via pydata-sphinx-theme
breathe==4.34.0 breathe==4.34.0
# via rocm-docs-core # via rocm-docs-core
certifi==2022.12.7 certifi==2023.7.22
# via requests # via requests
cffi==1.15.1 cffi==1.15.1
# via # via
...@@ -26,7 +26,7 @@ charset-normalizer==3.1.0 ...@@ -26,7 +26,7 @@ charset-normalizer==3.1.0
# via requests # via requests
click==8.1.3 click==8.1.3
# via sphinx-external-toc # via sphinx-external-toc
cryptography==40.0.2 cryptography==41.0.6
# via pyjwt # via pyjwt
deprecated==1.2.13 deprecated==1.2.13
# via pygithub # via pygithub
...@@ -42,12 +42,18 @@ fastjsonschema==2.18.0 ...@@ -42,12 +42,18 @@ fastjsonschema==2.18.0
# via rocm-docs-core # via rocm-docs-core
gitdb==4.0.10 gitdb==4.0.10
# via gitpython # via gitpython
gitpython==3.1.31 gitpython==3.1.37
# via rocm-docs-core # via rocm-docs-core
idna==3.4 idna==3.4
# via requests # via requests
imagesize==1.4.1 imagesize==1.4.1
# via sphinx # via sphinx
importlib-metadata==6.8.0
# via
# sphinx
# sphinxcontrib-bibtex
importlib-resources==6.1.0
# via rocm-docs-core
jinja2==3.1.2 jinja2==3.1.2
# via # via
# myst-parser # myst-parser
...@@ -82,28 +88,32 @@ pydata-sphinx-theme==0.13.3 ...@@ -82,28 +88,32 @@ pydata-sphinx-theme==0.13.3
# via # via
# rocm-docs-core # rocm-docs-core
# sphinx-book-theme # sphinx-book-theme
pygithub==1.58.2 pygithub==1.58.1
# via rocm-docs-core # via rocm-docs-core
pygments==2.14.0 pygments==2.15.0
# via # via
# accessible-pygments # accessible-pygments
# pydata-sphinx-theme # pydata-sphinx-theme
# sphinx # sphinx
pyjwt[crypto]==2.6.0 pyjwt[crypto]==2.6.0
# via pygithub # via
# pygithub
# pyjwt
pynacl==1.5.0 pynacl==1.5.0
# via pygithub # via pygithub
pytz==2023.3.post1
# via babel
pyyaml==6.0 pyyaml==6.0
# via # via
# myst-parser # myst-parser
# pybtex # pybtex
# rocm-docs-core # rocm-docs-core
# sphinx-external-toc # sphinx-external-toc
requests==2.28.2 requests==2.31.0
# via # via
# pygithub # pygithub
# sphinx # sphinx
rocm-docs-core>=0.20.0 rocm-docs-core==0.34.2
# via -r requirements.in # via -r requirements.in
six==1.16.0 six==1.16.0
# via # via
...@@ -131,7 +141,7 @@ sphinx-book-theme==1.0.1 ...@@ -131,7 +141,7 @@ sphinx-book-theme==1.0.1
# via rocm-docs-core # via rocm-docs-core
sphinx-copybutton==0.5.1 sphinx-copybutton==0.5.1
# via rocm-docs-core # via rocm-docs-core
sphinx-design==0.3.0 sphinx-design==0.4.1
# via rocm-docs-core # via rocm-docs-core
sphinx-external-toc==0.3.1 sphinx-external-toc==0.3.1
# via rocm-docs-core # via rocm-docs-core
...@@ -139,7 +149,7 @@ sphinx-notfound-page==0.8.3 ...@@ -139,7 +149,7 @@ sphinx-notfound-page==0.8.3
# via rocm-docs-core # via rocm-docs-core
sphinxcontrib-applehelp==1.0.4 sphinxcontrib-applehelp==1.0.4
# via sphinx # via sphinx
sphinxcontrib-bibtex==2.5.0 sphinxcontrib-bibtex==2.6.2
# via -r requirements.in # via -r requirements.in
sphinxcontrib-devhelp==1.0.2 sphinxcontrib-devhelp==1.0.2
# via sphinx # via sphinx
...@@ -153,7 +163,11 @@ sphinxcontrib-serializinghtml==1.1.5 ...@@ -153,7 +163,11 @@ sphinxcontrib-serializinghtml==1.1.5
# via sphinx # via sphinx
typing-extensions==4.5.0 typing-extensions==4.5.0
# via pydata-sphinx-theme # via pydata-sphinx-theme
urllib3==1.26.15 urllib3==1.26.18
# via requests # via requests
wrapt==1.15.0 wrapt==1.15.0
# via deprecated # via deprecated
zipp==3.17.0
# via
# importlib-metadata
# importlib-resources
=============== .. meta::
CK Hello world :description: Composable Kernel documentation and API reference library
=============== :keywords: composable kernel, CK, ROCm, API, documentation
------------------------------------- .. _hello-world:
Motivation
-------------------------------------
This tutorial is aimed at engineers dealing with artificial intelligence and machine learning who ********************************************************************
would like to optimize their pipelines and squeeze every performance drop by adding Composable Hello World Tutorial
Kernel (CK) library to their projects. We would like to make the CK library approachable so ********************************************************************
the tutorial is not based on the latest release and doesn't have all the bleeding edge features,
but it will be reproducible now and forever.
During this tutorial we will have an introduction to the CK library, we will build it and run some This tutorial is for engineers dealing with artificial intelligence and machine learning who
examples and tests, so to say we will run a "Hello world" example. In future tutorials we will go would like to optimize pipelines and improve performance using the Composable
in depth and breadth and get familiar with other tools and ways to integrate CK into your project. Kernel (CK) library. This tutorial provides an introduction to the CK library. You will build the library and run some examples using a "Hello World" example.
-------------------------------------
Description Description
------------------------------------- ===========
Modern AI technology solves more and more problems in all imaginable fields, but crafting fast and Modern AI technology solves more and more problems in a variety of fields, but crafting fast and
efficient workflows is still challenging. CK is one of the tools to make AI heavy lifting as fast efficient workflows is still challenging. CK can make the AI workflow fast
and efficient as possible. CK is a collection of optimized AI operator kernels and tools to create and efficient. CK is a collection of optimized AI operator kernels with tools to create
new ones. The library has components required for majority of modern neural networks architectures new kernels. The library has components required for modern neural network architectures
including matrix multiplication, convolution, contraction, reduction, attention modules, variety of including matrix multiplication, convolution, contraction, reduction, attention modules, a variety of activation functions, and fused operators.
activation functions, fused operators and many more.
So how do we (almost) reach the speed of light? CK acceleration abilities are based on: CK library acceleration features are based on:
* Layered structure. * Layered structure
* Tile-based computation model. * Tile-based computation model
* Tensor coordinate transformation. * Tensor coordinate transformation
* Hardware acceleration use. * Hardware acceleration use
* Support of low precision data types including fp16, bf16, int8 and int4. * Support of low precision data types including fp16, bf16, int8 and int4
If you are excited and need more technical details and benchmarking results - read this awesome If you need more technical details and benchmarking results read the following
`blog post <https://community.amd.com/t5/instinct-accelerators/amd-composable-kernel-library-efficient-fused-kernels-for-ai/ba-p/553224>`_. `blog post <https://community.amd.com/t5/instinct-accelerators/amd-composable-kernel-library-efficient-fused-kernels-for-ai/ba-p/553224>`_.
For more details visit our `github repository <https://github.com/ROCmSoftwarePlatform/composable_kernel>`_. To download the library visit the `composable_kernel repository <https://github.com/ROCm/composable_kernel>`_.
-------------------------------------
Hardware targets Hardware targets
------------------------------------- ================
CK library fully supports `gfx908` and `gfx90a` GPU architectures and only some operators are CK library fully supports `gfx908` and `gfx90a` GPU architectures, while only some operators are
supported for `gfx1030`. Let's check the hardware you have at hand and decide on the target supported for `gfx1030` devices. Check your hardware to determine the target GPU architecture.
GPU architecture.
========== ========= ========== =========
GPU Target AMD GPU GPU Target AMD GPU
...@@ -59,47 +51,24 @@ gfx1030 Radeon PRO V620, W6800, W6800X, W6800X Duo, W6900X, RX 6800, RX 6 ...@@ -59,47 +51,24 @@ gfx1030 Radeon PRO V620, W6800, W6800X, W6800X Duo, W6900X, RX 6800, RX 6
There are also `cloud options <https://aws.amazon.com/ec2/instance-types/g4/>`_ you can find if There are also `cloud options <https://aws.amazon.com/ec2/instance-types/g4/>`_ you can find if
you don't have an AMD GPU at hand. you don't have an AMD GPU at hand.
-------------------------------------
Build the library Build the library
------------------------------------- =================
First let's clone the library and rebase to the tested version:: This tutorial is based on the use of docker images as explained in :ref:`docker-hub`. Download a docker image suitable for your OS and ROCm release, run or start the docker container, and then resume the tutorial from this point.
git clone https://github.com/ROCmSoftwarePlatform/composable_kernel.git .. note::
cd composable_kernel/
git checkout tutorial_hello_world
To make our lives easier we prepared
`docker images <https://hub.docker.com/r/rocm/composable_kernel>`_ with all the necessary
dependencies. Pick the right image and create a container. In this tutorial we use
``rocm/composable_kernel:ck_ub20.04_rocm5.6`` image, it is based on Ubuntu 20.04 and
ROCm v5.6.
If your current folder is ``${HOME}``, start the docker container with::
docker run \
-it \
--privileged \
--group-add sudo \
-w /root/workspace \
-v ${HOME}:/root/workspace \
rocm/composable_kernel:ck_ub20.04_rocm5.6 \
/bin/bash
If your current folder is different from ``${HOME}``, adjust the line ``-v ${HOME}:/root/workspace`` You can also `install ROCm <https://rocm.docs.amd.com/projects/install-on-linux/en/latest/>`_ on your system, clone the `Composable Kernel repository <https://github.com/ROCm/composable_kernel.git>`_ on GitHub, and use that to build and run the examples using the commands described below.
to fit your folder structure.
Inside the docker container current folder is ``~/workspace``, library path is Both the docker container and GitHub repository include the Composable Kernel library. Navigate to the library::
``~/workspace/composable_kernel``, navigate to the library::
cd composable_kernel/ cd composable_kernel/
Create and go to the ``build`` directory:: Create and change to a ``build`` directory::
mkdir build && cd build mkdir build && cd build
In the previous section we talked about target GPU architecture. Once you decide which one is right The previous section discussed supported GPU architecture. Once you decide which hardware targets are needed, run CMake using the ``GPU_TARGETS`` flag::
for you, run CMake using the right ``GPU_TARGETS`` flag::
cmake \ cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_PREFIX_PATH=/opt/rocm \
...@@ -109,26 +78,25 @@ for you, run CMake using the right ``GPU_TARGETS`` flag:: ...@@ -109,26 +78,25 @@ for you, run CMake using the right ``GPU_TARGETS`` flag::
-D BUILD_DEV=OFF \ -D BUILD_DEV=OFF \
-D GPU_TARGETS="gfx908;gfx90a;gfx1030" .. -D GPU_TARGETS="gfx908;gfx90a;gfx1030" ..
If everything went well the CMake run will end up with:: If everything goes well the CMake command will return::
-- Configuring done -- Configuring done
-- Generating done -- Generating done
-- Build files have been written to: "/root/workspace/composable_kernel/build" -- Build files have been written to: "/root/workspace/composable_kernel/build"
Finally, we can build examples and tests:: Finally, you can build examples and tests::
make -j examples tests make -j examples tests
If everything is smooth, you'll see:: When complete you should see::
Scanning dependencies of target tests Scanning dependencies of target tests
[100%] Built target tests [100%] Built target tests
---------------------------
Run examples and tests Run examples and tests
--------------------------- ======================
Examples are listed as test cases as well, so we can run all examples and tests with:: Examples are listed as test cases as well, so you can run all examples and tests with::
ctest ctest
...@@ -136,38 +104,32 @@ You can check the list of all tests by running:: ...@@ -136,38 +104,32 @@ You can check the list of all tests by running::
ctest -N ctest -N
We can also run them separately, here is a separate example execution:: You can also run examples separately as shown in the following example execution::
./bin/example_gemm_xdl_fp16 1 1 1 ./bin/example_gemm_xdl_fp16 1 1 1
The arguments ``1 1 1`` mean that we want to run this example in the mode: verify results with CPU, The arguments ``1 1 1`` mean that you want to run this example in the mode: verify results with CPU, initialize matrices with integers, and benchmark the kernel execution. You can play around with these parameters and see how output and execution results change.
initialize matrices with integers and benchmark the kernel execution. You can play around with
these parameters and see how output and execution results change.
If everything goes well and you have a device based on `gfx908` or `gfx90a` architecture you should see If you have a device based on `gfx908` or `gfx90a` architecture, and if the example runs as expected, you should see something like::
something like::
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1}
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} Perf: 1.08153 ms, 119.136 TFlops, 89.1972 GB/s, DeviceGemm_Xdl_CShuffle<Default, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, 8, 4, 1, 2> LoopScheduler: Interwave, PipelineVersion: v1
Warm up 1 time
Start running 10 times...
Perf: 1.10017 ms, 117.117 TFlops, 87.6854 GB/s, DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1
Meanwhile, running it on a `gfx1030` device should result in:: However, running it on a `gfx1030` device should result in the following::
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 does not support this problem DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 does not support this problem
But don't panic, some of the operators are supported on `gfx1030` architecture, so you can run a Don't worry, some operators are supported on `gfx1030` architecture, so you can run a
separate example like:: separate example like::
./bin/example_gemm_dl_fp16 1 1 1 ./bin/example_gemm_dl_fp16 1 1 1
and it should result in something nice similar to:: and it should return something like::
a_m_k: dim 2, lengths {3840, 4096}, strides {1, 4096} a_m_k: dim 2, lengths {3840, 4096}, strides {1, 4096}
b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1} b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1}
...@@ -182,12 +144,9 @@ and it should result in something nice similar to:: ...@@ -182,12 +144,9 @@ and it should result in something nice similar to::
.. note:: .. note::
There was a new CMake flag ``DL_KERNELS`` added in the latest versions of CK. If you use one of A new CMake flag ``DL_KERNELS`` has been added to the latest versions of CK. If you do not see the above results when running ``example_gemm_dl_fp16``, you might need to add ``-D DL_KERNELS=ON`` to your CMake command to build the operators supported on the `gfx1030` architecture.
the newest versions of the library and do not see the above results when running
``example_gemm_dl_fp16``, it might be necessary to add ``-D DL_KERNELS=ON`` to your CMake command
in order to build the operators supported on the `gfx1030` architecture.
We can also run a separate test:: You can also run a separate test::
ctest -R test_gemm_fp16 ctest -R test_gemm_fp16
...@@ -198,13 +157,9 @@ If everything goes well you should see something like:: ...@@ -198,13 +157,9 @@ If everything goes well you should see something like::
100% tests passed, 0 tests failed out of 1 100% tests passed, 0 tests failed out of 1
-----------
Summary Summary
----------- =======
In this tutorial we took the first look at the Composable Kernel library, built it on your system In this tutorial you took the first look at the Composable Kernel library, built it on your system and ran some examples and tests. In the next tutorial you will run kernels with different configurations to find out the best one for your hardware and task.
and ran some examples and tests. Stay tuned, in the next tutorial we will run kernels with different
configs to find out the best one for your hardware and task.
P.S.: Don't forget to switch off the cloud instance if you have launched one, you can find better P.S.: If you are running on a cloud instance, don't forget to switch off the cloud instance.
ways to spend your money for sure!
.. meta::
:description: Composable Kernel documentation and API reference library
:keywords: composable kernel, CK, ROCm, API, documentation
.. _what-is-ck:
********************************************************************
What is the Composable Kernel library
********************************************************************
Methodology
===========
The Composable Kernel (CK) library provides a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs and CPUs, through general purpose kernel languages like HIP C++.
CK utilizes two concepts to achieve performance portability and code maintainability:
* A tile-based programming model
* Algorithm complexity reduction for complex ML operators using an innovative technique called
"Tensor Coordinate Transformation".
.. image:: data/ck_component.png
:alt: CK Components
Code Structure
==============
The CK library is structured into 4 layers:
* "Templated Tile Operators" layer
* "Templated Kernel and Invoker" layer
* "Instantiated Kernel and Invoker" layer
* "Client API" layer
It also includes a simple wrapper component used to perform tensor transform operations more easily and with fewer lines of code.
.. image:: data/ck_layer.png
:alt: CK Layers
\ No newline at end of file
.. meta::
:description: Composable Kernel documentation and API reference library
:keywords: composable kernel, CK, ROCm, API, documentation
.. _wrapper:
********************************************************************
Wrapper
********************************************************************
-------------------------------------
Description
-------------------------------------
The CK library provides a lightweight wrapper for more complex operations implemented in
the library.
Example:
.. code-block:: c
const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4));
const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8));
const auto layout = ck::wrapper::make_layout(shape_4x2x4, strides_s2x1x8);
std::array<ck::index_t, 32> data;
auto tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Generic>(&data[0], layout);
for(ck::index_t w = 0; w < size(tensor); w++) {
tensor(w) = w;
}
// slice() == slice(0, -1) (whole dimension)
auto tensor_slice = tensor(ck::wrapper::slice(1, 3), ck::make_tuple(ck::wrapper::slice(), ck::wrapper::slice()));
std::cout << "dims:2,(2,4) strides:2,(1,8)" << std::endl;
for(ck::index_t h = 0; h < ck::wrapper::size<0>(tensor_slice); h++)
{
for(ck::index_t w = 0; w < ck::wrapper::size<1>(tensor_slice); w++)
{
std::cout << tensor_slice(h, w) << " ";
}
std::cout << std::endl;
}
Output::
dims:2,(2,4) strides:2,(1,8)
1 5 9 13 17 21 25 29
2 6 10 14 18 22 26 30
Tutorials:
* `GEMM tutorial <https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/README.md>`_
Advanced examples:
* `Image to column <https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_img2col.cpp>`_
* `Basic gemm <https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_basic_gemm.cpp>`_
* `Optimized gemm <https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_optimized_gemm.cpp>`_
-------------------------------------
Layout
-------------------------------------
.. doxygenstruct:: ck::wrapper::Layout
-------------------------------------
Layout helpers
-------------------------------------
.. doxygenfile:: layout_utils.hpp
-------------------------------------
Tensor
-------------------------------------
.. doxygenstruct:: ck::wrapper::Tensor
-------------------------------------
Tensor helpers
-------------------------------------
.. doxygenfile:: tensor_utils.hpp
.. doxygenfile:: tensor_partition.hpp
-------------------------------------
Operations
-------------------------------------
.. doxygenfile:: copy.hpp
.. doxygenfile:: gemm.hpp
if(DL_KERNELS) add_custom_target(example_gemm_dl)
add_custom_target(example_gemm_dl)
add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp)
add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp) add_example_dependencies(example_gemm_dl example_gemm_dl_fp32)
add_dependencies(example_gemm_dl example_gemm_dl_fp32)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp)
add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) add_example_dependencies(example_gemm_dl example_gemm_dl_fp16)
add_dependencies(example_gemm_dl example_gemm_dl_fp16)
add_example_executable(example_gemm_dl_dpp8_fp16 gemm_dl_dpp8_fp16.cpp) add_example_executable(example_gemm_dpp_fp16 gemm_dpp_fp16.cpp)
add_dependencies(example_gemm_dl example_gemm_dl_dpp8_fp16)
endif() add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp)
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) add_example_dependencies(example_gemm_dl example_gemm_dl_int8)
add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) if(USE_BITINT_EXTENSION_INT4)
add_dependencies(example_gemm_dl example_gemm_dl_int8)
endif()
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp) add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp)
add_dependencies(example_gemm_dl example_gemm_dl_int4) add_example_dependencies(example_gemm_dl example_gemm_dl_int4)
endif(USE_BITINT_EXTENSION_INT4) endif(USE_BITINT_EXTENSION_INT4)
endif()
add_custom_target(example_gemm_xdl) add_custom_target(example_gemm_xdl)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16)
add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16) add_example_executable(example_gemm_xdl_fp16_v2 gemm_xdl_fp16_v2.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v2)
add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
add_custom_target(example_gemm_wmma) add_custom_target(example_gemm_wmma)
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
add_dependencies(example_gemm_wmma example_gemm_wmma_fp16) add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
endif()
endif() endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16)
add_dependencies(example_gemm_xdl example_gemm_xdl_bf16)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) add_example_executable(example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp)
add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_rtn)
add_dependencies(example_gemm_xdl example_gemm_xdl_int8)
endif() add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_int8)
if(USE_BITINT_EXTENSION_INT4) if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_xdl_int4 gemm_xdl_int4.cpp) add_example_executable(example_gemm_xdl_int4 gemm_xdl_int4.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_int4) add_example_dependencies(example_gemm_xdl example_gemm_xdl_int4)
endif(USE_BITINT_EXTENSION_INT4) endif(USE_BITINT_EXTENSION_INT4)
if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) # FIXME: re-enable this example as test when SWDEV-335738 is fixed
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
endif()
add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp) add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)
if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES) add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp)
if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942") add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
add_example_executable(example_gemm_xdl_f8 gemm_xdl_f8.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_f8) add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
endif() add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
endif()
list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_gemm_xdl_lds_direct_load_fp32 gemm_xdl_lds_direct_load_fp32.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32)
add_example_executable(example_gemm_xdl_lds_direct_load_fp16 gemm_xdl_lds_direct_load_fp16.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp16)
set(target 1)
endif()
endforeach()
add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp) add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#endif
#include "common.hpp" #include "common.hpp"
...@@ -43,3 +41,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host:: ...@@ -43,3 +41,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host::
#include "run_gemm_example.inc" #include "run_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
#endif
\ No newline at end of file
...@@ -3,31 +3,33 @@ ...@@ -3,31 +3,33 @@
#include "common.hpp" #include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp"
using ADataType = ck::half_t; using ADataType = ck::half_t;
using BDataType = ck::half_t; using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float; using AccDataType = float;
using CDataType = ck::half_t;
using F16 = ck::half_t;
using ALayout = Col; using ALayout = Row;
using BLayout = Row; using BLayout = Col;
using CLayout = Row; using CLayout = Row;
using AElementOp = PassThrough; using AElementOp = PassThrough;
using BElementOp = PassThrough; using BElementOp = PassThrough;
using CElementOp = PassThrough; using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// clang-format off // clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDlDpp8 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDpp
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MDpp| NDpp| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | | Dpp| Dpp| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 16, 2, 1, 8, 8, S<8, 8>, S<4, 1>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 128, 64, 64, 64, 8, 2, 32, 8, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 5, 1>;
// clang-format on // // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>; ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
using ADataType = ck::bhalf_t;
using BDataType = ck::bhalf_t;
using CDataType = ck::bhalf_t;
using AccDataType = float;
using CShuffleDataType = float;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = ck::tensor_operation::element_wise::ConvertBF16RTN;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
...@@ -9,13 +9,13 @@ ...@@ -9,13 +9,13 @@
using ADataType = ck::half_t; using ADataType = ck::half_t;
using BDataType = ck::half_t; using BDataType = ck::half_t;
using AccDataType = float; using AccDataType = float;
using CShuffleDataType = float; using CShuffleDataType = ck::half_t;
using CDataType = ck::half_t; using CDataType = ck::half_t;
using F16 = ck::half_t; using F16 = ck::half_t;
using ALayout = Row; using ALayout = Row;
using BLayout = Col; using BLayout = Row;
using CLayout = Row; using CLayout = Row;
using AElementOp = PassThrough; using AElementOp = PassThrough;
...@@ -30,7 +30,7 @@ using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl ...@@ -30,7 +30,7 @@ using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| // ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| // ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>; < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>;
// // clang-format on // // clang-format on
// clang-format off // clang-format off
...@@ -39,7 +39,7 @@ using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffl ...@@ -39,7 +39,7 @@ using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffl
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 2, S<1, 16, 1, 16>, 8, ck::LoopScheduler::Interwave, ck::PipelineVersion::v1>;
// clang-format on // clang-format on
using DeviceGemmInstance = DeviceGemmInstance1; using DeviceGemmInstance = DeviceGemmInstance1;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = ck::half_t;
using CDataType = ck::half_t;
using F16 = ck::half_t;
using F32 = float;
using ALayout = Row;
using BLayout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV2<
ALayout, BLayout, CLayout,
F16, F16, F16, F32, F16,
PassThrough, PassThrough, PassThrough, GemmDefault,
2, 256,
256, 256,
32, 8, 4,
32, 32,
4, 4,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 8, 4, 0,
1, 1, S<1, 32, 1, 8>, 8,
ck::LoopScheduler::Default, ck::PipelineVersion::v1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
...@@ -7,9 +7,9 @@ ...@@ -7,9 +7,9 @@
using ADataType = ck::f8_t; using ADataType = ck::f8_t;
using BDataType = ck::f8_t; using BDataType = ck::f8_t;
using CDataType = ck::f8_t; using CDataType = ck::half_t;
using AccDataType = float; using AccDataType = float;
using CShuffleDataType = ck::f8_t; using CShuffleDataType = float;
using ALayout = Row; using ALayout = Row;
using BLayout = Col; using BLayout = Col;
...@@ -27,7 +27,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle ...@@ -27,7 +27,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16>; < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
using ADataType = ck::f8_t;
using BDataType = ck::bf8_t;
using CDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = float;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto LoopSched = ck::make_default_loop_scheduler();
static constexpr auto PipelineVer = ck::PipelineVersion::v1;
using ComputeTypeA = ck::f8_t;
using ComputeTypeB = ck::bf8_t;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp,
ComputeTypeA,
ComputeTypeB>;
#include "run_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#endif
#include "common.hpp" #include "common.hpp"
...@@ -44,3 +42,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host:: ...@@ -44,3 +42,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host::
#include "run_gemm_example.inc" #include "run_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
#endif
\ No newline at end of file
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "common.hpp"
#define USING_DIRECT_LOADS 1
#if USING_DIRECT_LOADS
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp"
#else
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#endif
using F16 = ck::half_t;
using F32 = float;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
#if USING_DIRECT_LOADS
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_LdsDirectLoad
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>;
// clang-format on
#else
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 8, 1, 8>, 4>;
// clang-format on
#endif
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "common.hpp"
#define USING_DIRECT_LOADS 1
#if USING_DIRECT_LOADS
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp"
#else
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#endif
using F32 = float;
using ADataType = F32;
using BDataType = F32;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F32;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
#if USING_DIRECT_LOADS
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_LdsDirectLoad
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>;
// clang-format on
#else
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 8, 1, 8>, 4>;
// clang-format on
#endif
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102) list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102)
list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list1 AND target EQUAL 0) if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp) add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp)
add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp)
endif()
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
...@@ -16,4 +18,3 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -16,4 +18,3 @@ foreach(gpu IN LISTS GPU_TARGETS)
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
struct AlphaBetaAdd
{
AlphaBetaAdd(int alpha, int beta) : alpha_(alpha), beta_(beta){};
template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
template <>
__host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>(
std::int8_t& e, const std::int32_t& c, const std::int8_t& d) const
{
e = ck::type_convert<std::int8_t>(alpha_ * c + beta_ * ck::type_convert<std::int32_t>(d));
};
int alpha_;
int beta_;
};
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using I8 = std::int8_t;
using I32 = std::int32_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = I8;
using BDataType = I8;
using AccDataType = I32;
using CShuffleDataType = I32;
using DDataType = I8;
using EDataType = I8;
using ALayout = Row;
using BLayout = Row;
using DLayout = Row;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = AlphaBetaAdd;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceOpInstance =
ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle<ALayout,
BLayout,
ck::Tuple<DLayout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<DDataType>,
EDataType,
AccDataType,
CShuffleDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
32,
16,
16,
4,
16,
16,
16,
1,
1,
S<2, 16, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
16,
16,
1,
S<4, 1, 8>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
16,
2,
1,
1,
1,
S<1, 16, 1, 2>,
8>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideD = 4096;
ck::index_t StrideE = 4096;
int alpha = 1;
int beta = 1;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 6)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
alpha = std::stof(argv[4]);
beta = std::stof(argv[5]);
}
else if(argc == 13)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]);
alpha = std::stof(argv[11]);
beta = std::stof(argv[12]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, "
"beta\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
}
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
d_device_buf.ToDevice(d_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{alpha, beta};
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, 1>{StrideD},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}
return 0;
}
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
...@@ -7,4 +6,3 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -7,4 +6,3 @@ foreach(gpu IN LISTS GPU_TARGETS)
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
endif()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment