"test/vscode:/vscode.git/clone" did not exist on "ae9560dabb462708fadcb9d5387c655e32a5399a"
Unverified commit 9de63596 authored by Illia Silin, committed by GitHub

Merge pull request #48 from ROCm/merge_from_public

Merge from public
parents e60c5aea 6ac1d6a2
@@ -17,7 +17,7 @@ Getting started
`Composable Kernel User Guide <https://rocm.docs.amd.com/projects/composable_kernel/en/latest/>`_.
It provides insight into the core concepts, environment configuration, and steps to obtain or
build the library. You can also find some of this information in the
-`README file <https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/develop/README.md>`_
+`README file <https://github.com/ROCm/composable_kernel/blob/develop/README.md>`_
on the project's GitHub page.
#. **Additional reading:** The blog post `AMD Composable Kernel library: efficient fused kernels for AI apps with just a few lines of code <https://community.amd.com/t5/instinct-accelerators/amd-composable-kernel-library-efficient-fused-kernels-for-ai/ba-p/553224>`_ provides a deeper understanding of the CK library and showcases its performance capabilities.
@@ -33,7 +33,7 @@ You can make an impact by reporting issues or proposing code enhancements throug
Reporting issues
----------------
-Use `Github issues <https://github.com/ROCmSoftwarePlatform/composable_kernel/issues>`_
+Use `Github issues <https://github.com/ROCm/composable_kernel/issues>`_
to track public bugs and enhancement requests.
If you encounter an issue with the library, please check if the problem has already been
@@ -68,7 +68,7 @@ Creating Pull Requests
----------------------
You can submit `Pull Requests (PR) on GitHub
-<https://github.com/ROCmSoftwarePlatform/composable_kernel/pulls>`_.
+<https://github.com/ROCm/composable_kernel/pulls>`_.
All contributors are required to develop their changes on a separate branch and then create a
pull request to merge their changes into the `develop` branch, which is the default
@@ -89,7 +89,7 @@ When submitting a Pull Request you should:
the project's root directory. We leverage `pre-commit` to run `clang-format` automatically. We
highly recommend contributors utilize this method to maintain consistent code formatting.
Instructions on setting up `pre-commit` can be found in the project's
-`README file <https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/develop/README.md>`_
+`README file <https://github.com/ROCm/composable_kernel/blob/develop/README.md>`_
* Link your PR to any related issues:
...
@@ -45,3 +45,5 @@ for sphinx_var in ROCmDocs.SPHINX_VARS:
extensions += ['sphinxcontrib.bibtex']
bibtex_bibfiles = ['refs.bib']
+cpp_id_attributes = ["__global__", "__device__", "__host__"]
@@ -36,9 +36,9 @@ What is inside the image?
The docker images have everything you need for running CK including:
-* `ROCm <https://www.amd.com/en/graphics/servers-solutions-rocm>`_
+* `ROCm <https://rocm.docs.amd.com/en/latest/index.html>`_
* `CMake <https://cmake.org/getting-started/>`_
-* `Compiler <https://github.com/RadeonOpenCompute/llvm-project>`_
+* `Compiler <https://github.com/ROCm/llvm-project>`_
* `Composable Kernel library <https://github.com/ROCm/composable_kernel>`_
Running the docker container
@@ -97,5 +97,5 @@ Editing the docker image
=======================
If you want to customize the docker image, edit the
-`Dockerfile <https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/develop/Dockerfile>`_
+`Dockerfile <https://github.com/ROCm/composable_kernel/blob/develop/Dockerfile>`_
from the GitHub repository to suit your needs.
-rocm-docs-core==0.33.2
+rocm-docs-core==0.35.1
sphinxcontrib-bibtex==2.6.2
@@ -113,7 +113,7 @@ requests==2.31.0
# via
#   pygithub
#   sphinx
-rocm-docs-core==0.33.2
+rocm-docs-core==0.35.1
# via -r requirements.in
six==1.16.0
# via
...
@@ -32,7 +32,7 @@ CK library acceleration features are based on:
If you need more technical details and benchmarking results read the following
`blog post <https://community.amd.com/t5/instinct-accelerators/amd-composable-kernel-library-efficient-fused-kernels-for-ai/ba-p/553224>`_.
-To download the library visit the `composable_kernel repository <https://github.com/ROCmSoftwarePlatform/composable_kernel>`_.
+To download the library visit the `composable_kernel repository <https://github.com/ROCm/composable_kernel>`_.
Hardware targets
================
@@ -58,7 +58,7 @@ This tutorial is based on the use of docker images as explained in :ref:`docker-
.. note::
-You can also `install ROCm <https://rocm.docs.amd.com/projects/install-on-linux/en/latest/>`_ on your system, clone the `Composable Kernel repository <https://github.com/ROCmSoftwarePlatform/composable_kernel.git>`_ on GitHub, and use that to build and run the examples using the commands described below.
+You can also `install ROCm <https://rocm.docs.amd.com/projects/install-on-linux/en/latest/>`_ on your system, clone the `Composable Kernel repository <https://github.com/ROCm/composable_kernel.git>`_ on GitHub, and use that to build and run the examples using the commands described below.
Both the docker container and GitHub repository include the Composable Kernel library. Navigate to the library::
...
@@ -12,10 +12,6 @@ Wrapper
Description
-------------------------------------
-.. note::
-The wrapper is under development and its functionality is limited.
The CK library provides a lightweight wrapper for more complex operations implemented in
the library.
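The sections below document the wrapper's two core abstractions, Layout and Tensor. As a rough conceptual sketch of what a layout does, written in plain C++ rather than the CK wrapper API (SimpleLayout is a made-up name used only for illustration, not a CK type), a layout pairs a shape with strides and maps a multi-dimensional coordinate to a linear memory offset::

   // Conceptual sketch only (not the CK wrapper interface): a "layout" pairs a
   // shape with strides and turns a multi-dimensional coordinate into an offset.
   #include <array>
   #include <cstddef>
   #include <iostream>

   template <std::size_t Rank>
   struct SimpleLayout
   {
       std::array<std::size_t, Rank> shape;
       std::array<std::size_t, Rank> stride;

       std::size_t offset(const std::array<std::size_t, Rank>& coord) const
       {
           std::size_t off = 0;
           for(std::size_t d = 0; d < Rank; ++d)
               off += coord[d] * stride[d]; // dot product of coordinate and strides
           return off;
       }
   };

   int main()
   {
       // 4 x 8 row-major matrix: element (i, j) lives at offset i * 8 + j.
       SimpleLayout<2> layout{{4, 8}, {8, 1}};
       std::cout << layout.offset({2, 3}) << '\n'; // prints 19
   }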
@@ -54,39 +50,45 @@ Output::
2 6 10 14 18 22 26 30
+Tutorials:
+* `GEMM tutorial <https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/README.md>`_
Advanced examples:
* `Image to column <https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_img2col.cpp>`_
+* `Basic gemm <https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_basic_gemm.cpp>`_
+* `Optimized gemm <https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_optimized_gemm.cpp>`_
-------------------------------------
Layout
-------------------------------------
-.. doxygenstruct:: ck::wrapper::Layout
+.. doxygenstruct:: Layout
-------------------------------------
Layout helpers
-------------------------------------
-.. doxygenfile:: layout_utils.hpp
+.. doxygenfile:: include/ck/wrapper/utils/layout_utils.hpp
-------------------------------------
Tensor
-------------------------------------
-.. doxygenstruct:: ck::wrapper::Tensor
+.. doxygenstruct:: Tensor
-------------------------------------
Tensor helpers
-------------------------------------
-.. doxygenfile:: tensor_utils.hpp
+.. doxygenfile:: include/ck/wrapper/utils/tensor_utils.hpp
-.. doxygenfile:: tensor_partition.hpp
+.. doxygenfile:: include/ck/wrapper/utils/tensor_partition.hpp
-------------------------------------
Operations
-------------------------------------
-.. doxygenfile:: copy.hpp
+.. doxygenfile:: include/ck/wrapper/operations/copy.hpp
-.. doxygenfile:: gemm.hpp
+.. doxygenfile:: include/ck/wrapper/operations/gemm.hpp
@@ -27,7 +27,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
-if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
+if(GPU_TARGETS MATCHES "gfx11")
add_custom_target(example_gemm_wmma)
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
@@ -53,12 +53,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)
-add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp)
-add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
-add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
-add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942 gfx950)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
@@ -72,5 +66,12 @@ foreach(gpu IN LISTS GPU_TARGETS)
endif()
endforeach()
+add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp)
+add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
+add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
+add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
@@ -49,7 +49,7 @@ struct ProblemSizeStreamK final
struct ExecutionConfig final
{
bool do_verification = true;
-int init_method = 1;
+int init_method = 2;
bool time_kernel = false;
};
...
@@ -19,15 +19,50 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
-static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
-// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
-// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector|
-// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma|
-// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmMNKPadding, 256, 128, 256, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, 1>;
+< ALayout,
+BLayout,
+CLayout,
+ADataType,
+BDataType,
+CDataType,
+AccDataType,
+CShuffleDataType,
+AElementOp,
+BElementOp,
+CElementOp,
+GemmDefault,
+1, // Prefetch stage
+128, // BlockSize
+64, // MPerBlock
+128, // NPerBlock
+64, // KPerBlock
+8, // K1
+16, // MPerWmma
+16, // NPerWmma
+2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
+4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
+S<4, 32, 1>,
+S<1, 0, 2>,
+S<1, 0, 2>,
+2,
+8,
+8,
+true,
+S<4, 32, 1>,
+S<1, 0, 2>,
+S<1, 0, 2>,
+2,
+8,
+8,
+true,
+1, // C shuffle (M Repeat) Per store
+1, // C shuffle (N Repeat) Per store
+S<1, 32, 1, 4>,
+8>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
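The replacement instance moves from a 256-thread block computing a 128x256 tile to a 128-thread block computing a 64x128 tile. A small standalone check of that bookkeeping, assuming the usual relations MWave = MPerBlock / (MRepeat * MPerWmma) and NWave = NPerBlock / (NRepeat * NPerWmma) and a wave size of 32 for gfx11 WMMA kernels (both are assumptions made for illustration, not taken from this diff)::

   // Sketch: verify that the DeviceGemmWmma_CShuffle tile parameters above are
   // self-consistent. The MWave/NWave relations and the wave size of 32 are
   // assumptions used only for this illustration.
   constexpr int BlockSize = 128, MPerBlock = 64, NPerBlock = 128;
   constexpr int MPerWmma = 16, NPerWmma = 16, MRepeat = 2, NRepeat = 4;
   constexpr int WaveSize = 32; // assumed gfx11 wavefront size

   constexpr int MWave = MPerBlock / (MRepeat * MPerWmma); // 64 / (2 * 16) = 2
   constexpr int NWave = NPerBlock / (NRepeat * NPerWmma); // 128 / (4 * 16) = 2

   static_assert(MWave * NWave * WaveSize == BlockSize,
                 "waves per block times wave size must equal BlockSize");

   int main() { return 0; } // compiling this file performs the check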
...
@@ -33,8 +33,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopSched, PipelineVer, ComputeType>;
// clang-format on
-using ReferenceGemmInstance = ck::tensor_operation::host::
-ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
+using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
+BDataType,
+CDataType,
+AccDataType,
+AElementOp,
+BElementOp,
+CElementOp,
+ComputeType>;
#include "run_gemm_example.inc"
...
@@ -20,14 +20,18 @@ using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto LoopSched = ck::make_default_loop_scheduler();
+static constexpr auto PipelineVer = ck::PipelineVersion::v1;
+using ComputeTypeA = ck::f8_t;
+using ComputeTypeB = ck::f8_t;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
-// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
-// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
-// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
-// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
+// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Loop| Pipeline| Compute| Compute|
+// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| TypeA| TypeB|
+// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | |
+// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
...
@@ -27,10 +27,10 @@ using ComputeTypeB = ck::bf8_t;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
-// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
-// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
-// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
-// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Loop| Pipeline| Compute| Compute|
+// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| TypeA| TypeB|
+// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | |
+// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>;
// clang-format on
...
@@ -5,6 +5,88 @@
#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp"
+template <typename DataType>
+inline __host__ __device__ constexpr double get_rtol()
+{
+if constexpr(std::is_same_v<DataType, float>)
+{
+return 1e-3;
+}
+else if constexpr(std::is_same_v<DataType, double>)
+{
+return 1e-6;
+}
+else if constexpr(std::is_same_v<DataType, ck::half_t>)
+{
+return 1e-3;
+}
+else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
+{
+return 5e-2;
+}
+else if constexpr(std::is_same_v<DataType, int32_t>)
+{
+return 1e-1;
+}
+else if constexpr(std::is_same_v<DataType, int8_t>)
+{
+return 1e-1;
+}
+else if constexpr(std::is_same_v<DataType, ck::f8_t>)
+{
+return 1e-1; // 240 and 224 are acceptable
+}
+else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
+{
+return 1.5e-1; // 57344 and 49152 are acceptable
+}
+else
+{
+return 1e-3;
+}
+}
+template <typename DataType>
+inline __host__ __device__ constexpr double get_atol()
+{
+if constexpr(std::is_same_v<DataType, float>)
+{
+return 1e-3;
+}
+else if constexpr(std::is_same_v<DataType, double>)
+{
+return 1e-6;
+}
+else if constexpr(std::is_same_v<DataType, ck::half_t>)
+{
+return 1e-3;
+}
+else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
+{
+return 5e-2;
+}
+else if constexpr(std::is_same_v<DataType, int32_t>)
+{
+return 1e-1;
+}
+else if constexpr(std::is_same_v<DataType, int8_t>)
+{
+return 1e-1;
+}
+else if constexpr(std::is_same_v<DataType, ck::f8_t>)
+{
+return 16.1; // 240 and 224 are acceptable
+}
+else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
+{
+return 8192.1; // 57344 and 49152 are acceptable
+}
+else
+{
+return 1e-3;
+}
+}
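The new get_rtol/get_atol helpers are passed to ck::utils::check_err in the hunk below, so low-precision types such as ck::f8_t and ck::bf8_t get looser bounds. As a rough standalone illustration of the kind of mixed absolute/relative comparison such tolerances support (nearly_equal and the atol + rtol * |ref| criterion are assumptions for illustration; the exact test used by check_err is not shown in this diff)::

   // Illustrative only: accept a value when it is within an absolute tolerance
   // plus a relative tolerance of the reference. check_err's real criterion may
   // differ from this sketch.
   #include <cmath>
   #include <cstdio>

   bool nearly_equal(double out, double ref, double rtol, double atol)
   {
       return std::abs(out - ref) <= atol + rtol * std::abs(ref);
   }

   int main()
   {
       // bf8 example values from the comments above: 49152 vs 57344 only pass
       // with the large slack granted to ck::bf8_t.
       std::printf("%d\n", nearly_equal(49152.0, 57344.0, 1.5e-1, 8192.1)); // prints 1
   }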
template <typename ProblemType>
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{
@@ -68,9 +150,25 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
-default:
+case 2:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
+break;
+case 3:
+ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
+ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
+break;
+case 4:
+ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
+ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
+break;
+case 5:
+ck::utils::FillUniformDistributionIntegerValue<ADataType>{-2.f, 2.f}(a_m_k);
+ck::utils::FillUniformDistributionIntegerValue<BDataType>{-2.f, 2.f}(b_k_n);
+break;
+default:
+ck::utils::FillUniformDistribution<ADataType>{-0.1f, 0.1f}(a_m_k);
+ck::utils::FillUniformDistribution<BDataType>{-0.1f, 0.1f}(b_k_n);
}
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
@@ -240,7 +338,11 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
#else
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
-return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
+return ck::utils::check_err(c_m_n_device_result,
+c_m_n_host_result,
+"Error: Incorrect results!",
+get_rtol<CDataType>(),
+get_atol<CDataType>());
#endif
}
...
@@ -65,48 +65,49 @@ using CDEElementOp = AlphaBetaAdd;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
-using DeviceOpInstance =
-ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle<ALayout,
-BLayout,
-ck::Tuple<DLayout>,
-ELayout,
-ADataType,
-BDataType,
-ck::Tuple<DDataType>,
-EDataType,
-AccDataType,
-CShuffleDataType,
-AElementOp,
-BElementOp,
-CDEElementOp,
-GemmSpec,
-256,
-128,
-256,
-8,
-8,
-16,
-16,
-4,
-4,
-S<4, 64, 1>,
-S<1, 0, 2>,
-S<1, 0, 2>,
-2,
-8,
-8,
-true,
-S<4, 64, 1>,
-S<1, 0, 2>,
-S<1, 0, 2>,
-2,
-8,
-8,
-true,
-1,
-1,
-S<1, 32, 1, 8>,
-8>;
+using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle<
+ALayout,
+BLayout,
+ck::Tuple<DLayout>,
+ELayout,
+ADataType,
+BDataType,
+AccDataType,
+CShuffleDataType,
+ck::Tuple<DDataType>,
+EDataType,
+AElementOp,
+BElementOp,
+CDEElementOp,
+GemmSpec,
+2, // Prefetch stage
+128, // BlockSize
+128, // MPerBlock
+64, // NPerBlock
+64, // KPerBlock
+8, // K1
+16, // MPerWmma
+16, // NPerWmma
+4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
+2, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
+S<4, 32, 1>,
+S<1, 0, 2>,
+S<1, 0, 2>,
+2,
+8,
+8,
+true,
+S<4, 32, 1>,
+S<1, 0, 2>,
+S<1, 0, 2>,
+2,
+8,
+8,
+true,
+1, // C shuffle (M Repeat) Per store
+1, // C shuffle (N Repeat) Per store
+S<1, 32, 1, 4>,
+8>;
int main(int argc, char* argv[])
{
@@ -264,7 +265,7 @@ int main(int argc, char* argv[])
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
-<< std::endl;
+<< device_op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
...
@@ -55,7 +55,7 @@ using DDataType = I8;
using EDataType = I8;
using ALayout = Row;
-using BLayout = Row;
+using BLayout = Col;
using DLayout = Row;
using ELayout = Row;
@@ -65,48 +65,49 @@ using CDEElementOp = AlphaBetaAdd;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
-using DeviceOpInstance =
-ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle<ALayout,
-BLayout,
-ck::Tuple<DLayout>,
-ELayout,
-ADataType,
-BDataType,
-ck::Tuple<DDataType>,
-EDataType,
-AccDataType,
-CShuffleDataType,
-AElementOp,
-BElementOp,
-CDEElementOp,
-GemmSpec,
-32,
-16,
-16,
-4,
-16,
-16,
-16,
-1,
-1,
-S<2, 16, 1>,
-S<1, 0, 2>,
-S<1, 0, 2>,
-2,
-16,
-16,
-1,
-S<4, 1, 8>,
-S<0, 2, 1>,
-S<0, 2, 1>,
-1,
-16,
-2,
-1,
-1,
-1,
-S<1, 16, 1, 2>,
-8>;
+using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle<
+ALayout,
+BLayout,
+ck::Tuple<DLayout>,
+ELayout,
+ADataType,
+BDataType,
+AccDataType,
+CShuffleDataType,
+ck::Tuple<DDataType>,
+EDataType,
+AElementOp,
+BElementOp,
+CDEElementOp,
+GemmSpec,
+2, // Prefetch stage
+128, // BlockSize
+128, // MPerBlock
+64, // NPerBlock
+64, // KPerBlock
+8, // K1
+16, // MPerWmma
+16, // NPerWmma
+4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
+2, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
+S<4, 32, 1>,
+S<1, 0, 2>,
+S<1, 0, 2>,
+2,
+8,
+8,
+true,
+S<4, 32, 1>,
+S<1, 0, 2>,
+S<1, 0, 2>,
+2,
+8,
+8,
+true,
+1, // C shuffle (M Repeat) Per store
+1, // C shuffle (N Repeat) Per store
+S<1, 32, 1, 4>,
+8>;
int main(int argc, char* argv[])
{
...
add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp)
-if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
+if(GPU_TARGETS MATCHES "gfx11")
add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
endif()
@@ -43,9 +43,10 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Add;
-static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
-static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed;
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+static constexpr auto ASpec = ck::tensor_operation::device::TensorSpecialization::Default;
+static constexpr auto BSpec = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default;
using DeviceOpInstanceKKNN =
@@ -55,43 +56,44 @@ using DeviceOpInstanceKKNN =
NumDimK,
ADataType,
BDataType,
-DsDataType,
-EDataType,
AccDataType,
CShuffleDataType,
+DsDataType,
+EDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
-ABSpec,
-ABSpec,
+ASpec,
+BSpec,
DESpec,
-256,
-128,
-256,
-8,
-8,
-16,
-16,
-4,
-4,
-S<4, 64, 1>,
-S<1, 0, 2>,
-S<1, 0, 2>,
-2,
-8,
-8,
-true,
-S<4, 64, 1>,
-S<1, 0, 2>,
-S<1, 0, 2>,
-2,
-8,
-8,
-true,
-1,
-1,
-S<1, 32, 1, 8>,
-8>;
+1,
+128,
+64,
+64,
+64,
+4,
+16,
+16,
+1,
+4,
+S<4, 32, 1>,
+S<1, 0, 2>,
+S<1, 0, 2>,
+2,
+4,
+4,
+true,
+S<4, 32, 1>,
+S<1, 0, 2>,
+S<1, 0, 2>,
+2,
+4,
+4,
+true,
+1,
+1,
+S<1, 64, 1, 2>,
+8>;
using DeviceOpInstance = DeviceOpInstanceKKNN;
@@ -251,21 +253,6 @@ int main(int argc, char* argv[])
ck::index_t K0 = 2048;
-// A[G0, G1, M0, M1, K0]
-std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M0, M1, K0};
-std::vector<ck::index_t> a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1};
-// B[G0, G1, N0, N1, K0]
-std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, G1, N0, N1, K0};
-std::vector<ck::index_t> b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1};
-// D[G0, G1, M0, N0, M1, N1]
-std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
-std::vector<ck::index_t> d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1};
-// E[G0, G1, M0, N0, M1, N1]
-std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
-std::vector<ck::index_t> e_gs_ms_ns_strides{
-G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1};
if(argc == 1)
{
// use default case
@@ -276,13 +263,43 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
+else if(argc == 11)
+{
+do_verification = std::stoi(argv[1]);
+init_method = std::stoi(argv[2]);
+time_kernel = std::stoi(argv[3]);
+G0 = std::stoi(argv[4]);
+G1 = std::stoi(argv[5]);
+M0 = std::stoi(argv[6]);
+M1 = std::stoi(argv[7]);
+N0 = std::stoi(argv[8]);
+N1 = std::stoi(argv[9]);
+K0 = std::stoi(argv[10]);
+}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
+printf("arg4-10: G0, G1, M0, M1, N0, N1, K0\n");
exit(0);
}
+// A[G0, G1, M0, M1, K0]
+std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M0, M1, K0};
+std::vector<ck::index_t> a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1};
+// B[G0, G1, N0, N1, K0]
+std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, G1, N0, N1, K0};
+std::vector<ck::index_t> b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1};
+// D[G0, G1, M0, N0, M1, N1]
+std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
+std::vector<ck::index_t> d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1};
+// E[G0, G1, M0, N0, M1, N1]
+std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
+std::vector<ck::index_t> e_gs_ms_ns_strides{
+G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1};
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<BDataType> b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
...
@@ -42,41 +42,42 @@ using DeviceConvFwdInstance =
OutputLayout<NDimSpatial>,
InKernelDataType,
WeiKernelDataType,
-ck::Tuple<BiasKernelDataType, ResidualKernelDataType>,
-OutKernelDataType,
AccDataType,
CShuffleDataType,
+ck::Tuple<BiasKernelDataType, ResidualKernelDataType>,
+OutKernelDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
-256, // BlockSize
-128, // MPerBlock
-128, // NPerBlock
-4, // K0PerBlock
+1, // Prefetch stage
+128, // BlockSize
+64, // MPerBlock
+64, // NPerBlock
+64, // KPerBlock
8, // K1
16, // MPerWMMA
16, // NPerWMMA
4, // MRepeat
-2, // NRepeat
-S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
+1, // NRepeat
+S<4, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
true, // ABlockLdsExtraM
-S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
+S<4, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN
-4,
-2,
-S<1, 32, 1, 8>,
+1,
+1,
+S<1, 16, 1, 8>,
8>;
template <ck::index_t NDimSpatial>
@@ -277,9 +278,9 @@ bool run_grouped_conv_fwd_bias_relu_add_example(int argc, char* argv[])
switch(conv_param.num_dim_spatial_)
{
-case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param);
+// case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param);
case 2: return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param);
-case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param);
+// case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param);
}
return false;
...
+if(GPU_TARGETS MATCHES "gfx11")
+add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp)
+add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
+add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp)
+add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp)
+add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp)
+add_example_executable(example_grouped_query_attention_forward_wmma_fp16 grouped_query_attention_forward_wmma_fp16.cpp)
+endif()
add_custom_target(example_gemm_scale_softmax_gemm)
add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
@@ -20,4 +29,3 @@ add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_sc
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp)
add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16)