"example/vscode:/vscode.git/clone" did not exist on "d4d1147f0ac473b48c2e3ca4a2a21087f1962ede"
Commit 56de337f authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'amd-develop' into amd-master

parents 41b920e2 687d2b7e
index.rst:

@@ -12,27 +12,26 @@ The Composable Kernel (CK) library provides a programming model for writing perf
 The CK documentation is structured as follows:
 
-.. card:: Conceptual
-
-  * :ref:`what-is-ck`
-
-.. card:: Installation
-
-  * :ref:`docker-hub`
-
-.. card:: Tutorial
-
-  * :ref:`hello-world`
-
-.. card:: API reference
-
-  * :ref:`supported-primitives`
-  * :ref:`api-reference`
-  * :ref:`wrapper`
-
-.. card:: Contributing to CK
-
-  * :ref:`contributing-to`
+.. grid:: 2
+   :gutter: 3
+
+   .. grid-item-card:: Installation
+
+      * :ref:`docker-hub`
+
+   .. grid-item-card:: Conceptual
+
+      * :ref:`what-is-ck`
+
+   .. grid-item-card:: API reference
+
+      * :ref:`supported-primitives`
+      * :ref:`api-reference`
+      * :ref:`wrapper`
+
+   .. grid-item-card:: Tutorial
+
+      * :ref:`hello-world`
 
 To contribute to the documentation refer to `Contributing to ROCm <https://rocm.docs.amd.com/en/latest/contribute/index.html>`_.
...
dockerhub.rst:

@@ -36,7 +36,7 @@ What is inside the image?
 The docker images have everything you need for running CK including:
 
-* `ROCm <https://www.amd.com/en/graphics/servers-solutions-rocm>`_
+* `ROCm <https://rocm.docs.amd.com/en/latest/index.html>`_
 * `CMake <https://cmake.org/getting-started/>`_
 * `Compiler <https://github.com/ROCm/llvm-project>`_
 * `Composable Kernel library <https://github.com/ROCm/composable_kernel>`_
...
license.md → license.rst:

-```{include} ../LICENSE.md
-```
+.. meta::
+   :description: Composable Kernel documentation and API reference library
+   :keywords: composable kernel, CK, ROCm, API, documentation
+
+.. _license:
+
+********************************************************************
+License
+********************************************************************
+
+.. include:: ../LICENSE
\ No newline at end of file
wrapper.rst:

@@ -64,31 +64,31 @@ Advanced examples:
 Layout
 -------------------------------------
 
-.. doxygenstruct:: ck::wrapper::Layout
+.. doxygenstruct:: Layout
 
 -------------------------------------
 Layout helpers
 -------------------------------------
 
-.. doxygenfile:: layout_utils.hpp
+.. doxygenfile:: include/ck/wrapper/utils/layout_utils.hpp
 
 -------------------------------------
 Tensor
 -------------------------------------
 
-.. doxygenstruct:: ck::wrapper::Tensor
+.. doxygenstruct:: Tensor
 
 -------------------------------------
 Tensor helpers
 -------------------------------------
 
-.. doxygenfile:: tensor_utils.hpp
-.. doxygenfile:: tensor_partition.hpp
+.. doxygenfile:: include/ck/wrapper/utils/tensor_utils.hpp
+.. doxygenfile:: include/ck/wrapper/utils/tensor_partition.hpp
 
 -------------------------------------
 Operations
 -------------------------------------
 
-.. doxygenfile:: copy.hpp
-.. doxygenfile:: gemm.hpp
+.. doxygenfile:: include/ck/wrapper/operations/copy.hpp
+.. doxygenfile:: include/ck/wrapper/operations/gemm.hpp
_toc.yml:

@@ -2,20 +2,35 @@ defaults:
   numbered: False
 root: index
 subtrees:
-- entries:
-  - file: what-is-ck.rst
+- caption: Conceptual
+  entries:
+  - file: conceptual/what-is-ck.rst
     title: What is Composable Kernel?
-  - file: dockerhub.rst
+- caption: Install
+  entries:
+  - file: install/dockerhub.rst
     title: Docker Hub
-  - file: tutorial_hello_world.rst
-    title: Hello World Tutorial
-  - file: Supported_Primitives_Guide.rst
+- caption: CK API Reference
+  entries:
+  - file: reference/Supported_Primitives_Guide.rst
     title: Supported Primitives
-  - file: API_Reference_Guide.rst
+  - file: reference/API_Reference_Guide.rst
     title: API Reference
-  - file: wrapper.rst
+  - file: reference/wrapper.rst
     title: Wrapper
+- caption: Tutorial
+  entries:
+  - file: tutorial/tutorial_hello_world.rst
+    title: Hello World Tutorial
+- caption: About
+  entries:
   - file: Contributors_Guide.rst
     title: Contributing to CK
-  - file: license.md
+  - file: license.rst
     title: License
\ No newline at end of file
requirements.in:

-rocm-docs-core==0.35.0
+rocm-docs-core==0.36.0
 sphinxcontrib-bibtex==2.6.2

requirements.txt:

@@ -113,7 +113,7 @@ requests==2.31.0
     # via
     #   pygithub
     #   sphinx
-rocm-docs-core==0.35.0
+rocm-docs-core==0.36.0
     # via -r requirements.in
 six==1.16.0
     # via
...
CMakeLists.txt (GEMM examples):

@@ -27,7 +27,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
 add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
 
-if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
+if(GPU_TARGETS MATCHES "gfx11")
     add_custom_target(example_gemm_wmma)
     add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
     add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)

@@ -53,12 +53,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
 add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)
 
-add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp)
-add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
-
-add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
-add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
-
 list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)

@@ -72,5 +66,12 @@ foreach(gpu IN LISTS GPU_TARGETS)
     endif()
 endforeach()
 
+add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp)
+add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
+
+add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
+add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
+
 add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
...
gemm_wmma_fp16.cpp:

@@ -19,15 +19,50 @@ using AElementOp = PassThrough;
 using BElementOp = PassThrough;
 using CElementOp = PassThrough;
 
-static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
 
 // clang-format off
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
-// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| ... (four-line parameter banner elided)
-    < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmMNKPadding, 256, 128, 256, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, 1>;
+    < ALayout,
+      BLayout,
+      CLayout,
+      ADataType,
+      BDataType,
+      CDataType,
+      AccDataType,
+      CShuffleDataType,
+      AElementOp,
+      BElementOp,
+      CElementOp,
+      GemmDefault,
+      1,   // Prefetch stage
+      128, // BlockSize
+      64,  // MPerBlock
+      128, // NPerBlock
+      64,  // KPerBlock
+      8,   // K1
+      16,  // MPerWmma
+      16,  // NPerWmma
+      2,   // M-Repeat // M-PerWmma / M-Repeat = M-Wave
+      4,   // N-Repeat // N-PerWmma / N-Repeat = N-Wave
+      S<4, 32, 1>,
+      S<1, 0, 2>,
+      S<1, 0, 2>,
+      2,
+      8,
+      8,
+      true,
+      S<4, 32, 1>,
+      S<1, 0, 2>,
+      S<1, 0, 2>,
+      2,
+      8,
+      8,
+      true,
+      1, // C shuffle (M Repeat) Per store
+      1, // C shuffle (N Repeat) Per store
+      S<1, 32, 1, 4>,
+      8>;
 // clang-format on
 
 using ReferenceGemmInstance = ck::tensor_operation::host::
...
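A quick consistency check on the retuned WMMA tile (a sketch, not part of the diff; it assumes 32-lane waves on gfx11 and the wave decomposition hinted at in the comments above):

// Sanity arithmetic for the 128-thread WMMA config above (assumption-based sketch).
constexpr int BlockSize = 128, MPerBlock = 64, NPerBlock = 128;
constexpr int MPerWmma = 16, NPerWmma = 16, MRepeat = 2, NRepeat = 4;
constexpr int WaveSize = 32; // gfx11 wavefront width (assumption)

constexpr int MWave = MPerBlock / (MRepeat * MPerWmma); // 64 / 32  = 2
constexpr int NWave = NPerBlock / (NRepeat * NPerWmma); // 128 / 64 = 2

static_assert(MWave * NWave * WaveSize == BlockSize,
              "2 x 2 waves of 32 lanes fill the 128-thread workgroup");

int main() { return 0; } // compile-time check only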
another GEMM example (XDL, file name not shown in this view):

@@ -33,8 +33,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
     < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopSched, PipelineVer, ComputeType>;
 // clang-format on
 
-using ReferenceGemmInstance = ck::tensor_operation::host::
-    ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
+using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
+                                                                        BDataType,
+                                                                        CDataType,
+                                                                        AccDataType,
+                                                                        AElementOp,
+                                                                        BElementOp,
+                                                                        CElementOp,
+                                                                        ComputeType>;
 
 #include "run_gemm_example.inc"
...
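For context, the host reference is driven the same way throughout the CK examples (a sketch of the common pattern; the tensor names below are the ones the example harness usually declares and are illustrative here):

// Typical verification flow for ReferenceGemmInstance (sketch).
auto ref_gemm    = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
// a_m_k, b_k_n and c_m_n_host_result are host tensors owned by the harness.
auto ref_argument = ref_gemm.MakeArgument(
    a_m_k, b_k_n, c_m_n_host_result, AElementOp{}, BElementOp{}, CElementOp{});
ref_invoker.Run(ref_argument); // fills c_m_n_host_result for check_err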
run_gemm_example.inc:

@@ -5,6 +5,88 @@
 #include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp"
 
+template <typename DataType>
+inline __host__ __device__ constexpr double get_rtol()
+{
+    if constexpr(std::is_same_v<DataType, float>)
+    {
+        return 1e-3;
+    }
+    else if constexpr(std::is_same_v<DataType, double>)
+    {
+        return 1e-6;
+    }
+    else if constexpr(std::is_same_v<DataType, ck::half_t>)
+    {
+        return 1e-3;
+    }
+    else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
+    {
+        return 5e-2;
+    }
+    else if constexpr(std::is_same_v<DataType, int32_t>)
+    {
+        return 1e-1;
+    }
+    else if constexpr(std::is_same_v<DataType, int8_t>)
+    {
+        return 1e-1;
+    }
+    else if constexpr(std::is_same_v<DataType, ck::f8_t>)
+    {
+        return 1e-1; // 240 and 224 are acceptable
+    }
+    else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
+    {
+        return 1.5e-1; // 57344 and 49152 are acceptable
+    }
+    else
+    {
+        return 1e-3;
+    }
+}
+
+template <typename DataType>
+inline __host__ __device__ constexpr double get_atol()
+{
+    if constexpr(std::is_same_v<DataType, float>)
+    {
+        return 1e-3;
+    }
+    else if constexpr(std::is_same_v<DataType, double>)
+    {
+        return 1e-6;
+    }
+    else if constexpr(std::is_same_v<DataType, ck::half_t>)
+    {
+        return 1e-3;
+    }
+    else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
+    {
+        return 5e-2;
+    }
+    else if constexpr(std::is_same_v<DataType, int32_t>)
+    {
+        return 1e-1;
+    }
+    else if constexpr(std::is_same_v<DataType, int8_t>)
+    {
+        return 1e-1;
+    }
+    else if constexpr(std::is_same_v<DataType, ck::f8_t>)
+    {
+        return 16.1; // 240 and 224 are acceptable
+    }
+    else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
+    {
+        return 8192.1; // 57344 and 49152 are acceptable
+    }
+    else
+    {
+        return 1e-3;
+    }
+}
+
 template <typename ProblemType>
 bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
 {

@@ -68,6 +150,22 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
         ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
         ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
         break;
+    case 2:
+        ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
+        ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
+        break;
+    case 3:
+        ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
+        ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
+        break;
+    case 4:
+        ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
+        ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
+        break;
+    case 5:
+        ck::utils::FillUniformDistributionIntegerValue<ADataType>{-2.f, 2.f}(a_m_k);
+        ck::utils::FillUniformDistributionIntegerValue<BDataType>{-2.f, 2.f}(b_k_n);
+        break;
     default:
         ck::utils::FillUniformDistribution<ADataType>{-0.1f, 0.1f}(a_m_k);
         ck::utils::FillUniformDistribution<BDataType>{-0.1f, 0.1f}(b_k_n);

@@ -240,8 +338,11 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
 #else
     c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
 
-    return ck::utils::check_err(
-        c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", 1e-1, 1e-1);
+    return ck::utils::check_err(c_m_n_device_result,
+                                c_m_n_host_result,
+                                "Error: Incorrect results!",
+                                get_rtol<CDataType>(),
+                                get_atol<CDataType>());
 #endif
 }
...
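The fp8 tolerances encode representable-value spacing rather than arbitrary slack: near the top of the f8 range, adjacent values such as 224 and 240 sit 16 apart, so atol = 16.1 accepts a one-step rounding difference, and likewise atol = 8192.1 for bf8, where 49152 and 57344 are 8192 apart. Below is a per-element view of how the helpers plug into the check, assuming the usual mixed criterion |a - b| <= atol + rtol * |b| (approx_equal is a hypothetical name for illustration; ck::utils::check_err is the real entry point):

#include <cmath>

// Hypothetical per-element form of the tolerance test (sketch only;
// get_rtol/get_atol are the helpers added in this commit).
template <typename CDataType>
bool approx_equal(double device_val, double host_val)
{
    const double rtol = get_rtol<CDataType>(); // e.g. 1e-1 for ck::f8_t
    const double atol = get_atol<CDataType>(); // e.g. 16.1 for ck::f8_t
    return std::abs(device_val - host_val) <= atol + rtol * std::abs(host_val);
}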
a DeviceGemmMultipleD WMMA example (file name not shown in this view):

@@ -65,48 +65,49 @@ using CDEElementOp = AlphaBetaAdd;
 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
 
-using DeviceOpInstance =
-    ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle<
-        ALayout, BLayout, ck::Tuple<DLayout>, ELayout,
-        ADataType, BDataType, ck::Tuple<DDataType>, EDataType, AccDataType, CShuffleDataType,
-        AElementOp, BElementOp, CDEElementOp, GemmSpec,
-        256, 128, 256, 8, 8, 16, 16, 4, 4,
-        S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
-        S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
-        1, 1, S<1, 32, 1, 8>, 8>;
+using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle<
+    ALayout,
+    BLayout,
+    ck::Tuple<DLayout>,
+    ELayout,
+    ADataType,
+    BDataType,
+    AccDataType,
+    CShuffleDataType,
+    ck::Tuple<DDataType>,
+    EDataType,
+    AElementOp,
+    BElementOp,
+    CDEElementOp,
+    GemmSpec,
+    2,   // Prefetch stage
+    128, // BlockSize
+    128, // MPerBlock
+    64,  // NPerBlock
+    64,  // KPerBlock
+    8,   // K1
+    16,  // MPerWmma
+    16,  // NPerWmma
+    4,   // M-Repeat // M-PerWmma / M-Repeat = M-Wave
+    2,   // N-Repeat // N-PerWmma / N-Repeat = N-Wave
+    S<4, 32, 1>,
+    S<1, 0, 2>,
+    S<1, 0, 2>,
+    2,
+    8,
+    8,
+    true,
+    S<4, 32, 1>,
+    S<1, 0, 2>,
+    S<1, 0, 2>,
+    2,
+    8,
+    8,
+    true,
+    1, // C shuffle (M Repeat) Per store
+    1, // C shuffle (N Repeat) Per store
+    S<1, 32, 1, 4>,
+    8>;
 
 int main(int argc, char* argv[])
 {

@@ -264,7 +265,7 @@ int main(int argc, char* argv[])
     float gb_per_sec = num_btype / 1.E6 / ave_time;
 
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
-              << std::endl;
+              << device_op.GetTypeString() << std::endl;
 
     e_device_buf.FromDevice(e_m_n_device_result.mData.data());
...
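The printed metrics follow the standard GEMM accounting of 2·M·N·K flops per problem. A self-contained sketch with illustrative sizes (the example itself reads M, N, K from the command line; num_btype here models 2-byte A, B, D and E tensors):

#include <cstdio>

int main()
{
    // Illustrative problem size and timing; both are placeholders.
    const double M = 3840, N = 4096, K = 4096;
    const double ave_time  = 1.0; // measured kernel time in ms
    const double num_btype = 2 * (M * K + K * N + M * N + M * N);

    const double tflops     = 2.0 * M * N * K / 1.e9 / ave_time; // flops / 1e9 / ms == TFLOPS
    const double gb_per_sec = num_btype / 1.e6 / ave_time;       // bytes / 1e6 / ms == GB/s

    std::printf("Perf: %f ms, %f TFlops, %f GB/s\n", ave_time, tflops, gb_per_sec);
    return 0;
}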
another DeviceGemmMultipleD WMMA example (int8):

@@ -55,7 +55,7 @@ using DDataType = I8;
 using EDataType = I8;
 
 using ALayout = Row;
-using BLayout = Row;
+using BLayout = Col;
 using DLayout = Row;
 using ELayout = Row;

@@ -65,48 +65,49 @@ using CDEElementOp = AlphaBetaAdd;
 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
 
-using DeviceOpInstance =
-    ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle<
-        ALayout, BLayout, ck::Tuple<DLayout>, ELayout,
-        ADataType, BDataType, ck::Tuple<DDataType>, EDataType, AccDataType, CShuffleDataType,
-        AElementOp, BElementOp, CDEElementOp, GemmSpec,
-        32, 16, 16, 4, 16, 16, 16, 1, 1,
-        S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
-        S<4, 1, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 2, 1,
-        1, 1, S<1, 16, 1, 2>, 8>;
+using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle<
+    ALayout,
+    BLayout,
+    ck::Tuple<DLayout>,
+    ELayout,
+    ADataType,
+    BDataType,
+    AccDataType,
+    CShuffleDataType,
+    ck::Tuple<DDataType>,
+    EDataType,
+    AElementOp,
+    BElementOp,
+    CDEElementOp,
+    GemmSpec,
+    2,   // Prefetch stage
+    128, // BlockSize
+    128, // MPerBlock
+    64,  // NPerBlock
+    64,  // KPerBlock
+    8,   // K1
+    16,  // MPerWmma
+    16,  // NPerWmma
+    4,   // M-Repeat // M-PerWmma / M-Repeat = M-Wave
+    2,   // N-Repeat // N-PerWmma / N-Repeat = N-Wave
+    S<4, 32, 1>,
+    S<1, 0, 2>,
+    S<1, 0, 2>,
+    2,
+    8,
+    8,
+    true,
+    S<4, 32, 1>,
+    S<1, 0, 2>,
+    S<1, 0, 2>,
+    2,
+    8,
+    8,
+    true,
+    1, // C shuffle (M Repeat) Per store
+    1, // C shuffle (N Repeat) Per store
+    S<1, 32, 1, 4>,
+    8>;
 
 int main(int argc, char* argv[])
 {
...
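Both DeviceGemmMultipleD examples use AlphaBetaAdd as the CDE elementwise op; assuming the classic epilogue e = alpha * c + beta * d applied per element, a simplified stand-in looks like the following (a sketch, not CK's actual functor):

// Simplified model of ck::tensor_operation::element_wise::AlphaBetaAdd
// (assumption: it blends the GEMM accumulator c with the extra D tensor).
struct AlphaBetaAddSketch
{
    float alpha_;
    float beta_;

    template <typename E, typename C, typename D>
    void operator()(E& e, const C& c, const D& d) const
    {
        e = static_cast<E>(alpha_ * static_cast<float>(c) +
                           beta_ * static_cast<float>(d));
    }
};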
CMakeLists.txt (convnd_fwd examples):

@@ -6,6 +6,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
         add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
         add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp)
         add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
+        add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp)
         # FIXME: re-enable this example as test when SWDEV-335738 is fixed
         add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
         set(target 1)
...
convnd_fwd_common.hpp:

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #include <cstdlib>
 #include <iostream>

@@ -27,6 +27,88 @@ void print_helper_msg()
               << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
 }
 
+template <typename DataType>
+inline __host__ __device__ constexpr double get_rtol()
+{
+    if constexpr(std::is_same_v<DataType, float>)
+    {
+        return 1e-3;
+    }
+    else if constexpr(std::is_same_v<DataType, double>)
+    {
+        return 1e-6;
+    }
+    else if constexpr(std::is_same_v<DataType, ck::half_t>)
+    {
+        return 1e-3;
+    }
+    else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
+    {
+        return 5e-2;
+    }
+    else if constexpr(std::is_same_v<DataType, int32_t>)
+    {
+        return 1e-1;
+    }
+    else if constexpr(std::is_same_v<DataType, int8_t>)
+    {
+        return 1e-1;
+    }
+    else if constexpr(std::is_same_v<DataType, ck::f8_t>)
+    {
+        return 1e-1; // 240 and 224 are acceptable
+    }
+    else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
+    {
+        return 1.5e-1; // 57344 and 49152 are acceptable
+    }
+    else
+    {
+        return 1e-3;
+    }
+}
+
+template <typename DataType>
+inline __host__ __device__ constexpr double get_atol()
+{
+    if constexpr(std::is_same_v<DataType, float>)
+    {
+        return 1e-3;
+    }
+    else if constexpr(std::is_same_v<DataType, double>)
+    {
+        return 1e-6;
+    }
+    else if constexpr(std::is_same_v<DataType, ck::half_t>)
+    {
+        return 1e-3;
+    }
+    else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
+    {
+        return 5e-2;
+    }
+    else if constexpr(std::is_same_v<DataType, int32_t>)
+    {
+        return 1e-1;
+    }
+    else if constexpr(std::is_same_v<DataType, int8_t>)
+    {
+        return 1e-1;
+    }
+    else if constexpr(std::is_same_v<DataType, ck::f8_t>)
+    {
+        return 16.1; // 240 and 224 are acceptable
+    }
+    else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
+    {
+        return 8192.1; // 57344 and 49152 are acceptable
+    }
+    else
+    {
+        return 1e-3;
+    }
+}
+
 template <ck::index_t NDimSpatial,
           typename InDataType,
           typename WeiDataType,

@@ -164,8 +246,11 @@ bool run_grouped_conv_fwd(bool do_verification,
         out_device_buf.FromDevice(out_device.mData.data());
 
-        return ck::utils::check_err(
-            out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
+        return ck::utils::check_err(out_device,
+                                    out_host,
+                                    "Error: incorrect results!",
+                                    get_rtol<OutDataType>(),
+                                    get_atol<OutDataType>());
     }
 
     return true;
...
convnd_fwd_xdl_fp8.cpp (new file):

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
using InDataType = ck::f8_t;
using WeiDataType = ck::f8_t;
using AccDataType = float;
using CShuffleDataType = ck::f8_t;
using OutDataType = ck::f8_t;
using ComputeDataType = ck::f8_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataType,
WeiDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<>,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, //
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1,
1,
S<1, 32, 1, 8>,
8,
ComputeDataType>;
#include "run_convnd_fwd_example.inc"
int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; }
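As with the WMMA sketch earlier, the new file's tile shape can be sanity-checked (a sketch, not part of the diff; assumes 64-lane waves on the CDNA targets this example builds for):

// Sanity arithmetic for the 256-thread XDL config above (assumption-based sketch).
constexpr int BlockSize = 256, MPerBlock = 128, NPerBlock = 256;
constexpr int MPerXdl = 32, NPerXdl = 32, MXdlPerWave = 2, NXdlPerWave = 4;
constexpr int WaveSize = 64; // CDNA wavefront width (assumption)

constexpr int MWave = MPerBlock / (MXdlPerWave * MPerXdl); // 128 / 64  = 2
constexpr int NWave = NPerBlock / (NXdlPerWave * NPerXdl); // 256 / 128 = 2

static_assert(MWave * NWave * WaveSize == BlockSize,
              "2 x 2 waves of 64 lanes fill the 256-thread workgroup");

int main() { return 0; } // compile-time check only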