Unverified commit a11cf2c6, authored by arai713, committed by GitHub

Merge branch 'develop' into codegen_hiprtc

parents a72e9efa 64d5c4d6
@@ -7,6 +7,7 @@ Please describe the motivation behind the pull request, whether it enables a new
Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
- [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally
- [ ] I have added the test to the REGRESSION_TESTS list defined at the top of tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run.
- [ ] I have added inline documentation which enables the maintainers to understand the motivation
- [ ] I have removed stale documentation which is no longer relevant after this pull request
- [ ] (If this change is user-facing) I have added release notes which provide end users with a brief summary of the improvement from this pull request
...
@@ -106,6 +106,13 @@ if(CK_USE_CODEGEN)
add_definitions(-DCK_USE_CODEGEN)
endif()
option(CK_TIME_KERNEL "Enable kernel time tracking" ON)
if(CK_TIME_KERNEL)
add_definitions(-DCK_TIME_KERNEL=1)
else()
add_definitions(-DCK_TIME_KERNEL=0)
endif()
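The option merely forwards a 0/1 preprocessor definition to the compiler. As a minimal illustrative sketch (a hypothetical helper, not CK's actual internals), such a flag is typically consumed like this:

```cpp
// Hypothetical sketch of consuming a CK_TIME_KERNEL-style 0/1 definition.
// Real GPU timing would use hip events/synchronization; this times a host callable.
#include <chrono>
#include <iostream>

template <typename Kernel>
float run_maybe_timed(Kernel&& kernel)
{
#if CK_TIME_KERNEL
    const auto t0 = std::chrono::steady_clock::now();
    kernel();
    const auto t1 = std::chrono::steady_clock::now();
    const float ms = std::chrono::duration<float, std::milli>(t1 - t0).count();
    std::cout << "kernel time: " << ms << " ms\n";
    return ms;
#else
    kernel(); // configured with -DCK_TIME_KERNEL=0: timing compiled out
    return 0.f;
#endif
}
```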
include(getopt)

# CK version file to record release version as well as git commit hash
@@ -533,7 +540,13 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERS
add_compile_options(-fdiagnostics-color=always)
endif()

# make check runs the entire set of examples and tests
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
# make smoke runs the tests and examples that run within 30 seconds on gfx90a
add_custom_target(smoke COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "SMOKE_TEST")
# make regression runs the tests and examples that run for more than 30 seconds on gfx90a
add_custom_target(regression COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "REGRESSION_TEST")

file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp")
file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*)
...
@@ -512,7 +512,7 @@ def Build_CK(Map conf=[:]){
arch_type = 5
}
cmake_build(conf)
-if ( arch_type == 1 ){
+if ( !params.BUILD_LEGACY_OS && arch_type == 1 ){
echo "Run inductor codegen tests"
sh """
pip install --verbose .
...
@@ -121,6 +121,15 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa
You can find instructions for running each individual example in [example](/example).
* Build and run smoke/regression examples and tests:
```bash
make -j smoke # tests and examples that run for < 30 seconds each
```
```bash
make -j regression # tests and examples that run for >= 30 seconds each
```
* Build ckProfiler:
```bash
...
@@ -22,4 +22,7 @@ if(GPU_TARGETS MATCHES "gfx9")
add_executable(client_grouped_conv3d_fwd_bf8_fp8 grouped_conv3d_fwd_bf8_fp8.cpp)
target_link_libraries(client_grouped_conv3d_fwd_bf8_fp8 PRIVATE composable_kernel::device_conv_operations)
endif()
add_executable(grouped_conv2d_fwd_ngchw grouped_conv2d_fwd_ngchw.cpp)
target_link_libraries(grouped_conv2d_fwd_ngchw PRIVATE composable_kernel::device_conv_operations)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <array>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <limits>
#include <numeric>
#include <string>
#include <tuple>
#include <vector>
#include "ck/utility/data_type.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
using InDataType = ck::half_t;
using WeiDataType = ck::half_t;
using OutDataType = ck::half_t;
// Use std tuple instead of ck tuple to avoid clang
// implicit instantiation of undefined template error.
using DDataTypes = std::tuple<ck::half_t>;
using InLayout = ck::tensor_layout::convolution::NGCHW;
using WeiLayout = ck::tensor_layout::convolution::GKYXC;
using OutLayout = ck::tensor_layout::convolution::NGKHW;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr ck::index_t NumDimSpatial = 2;
static constexpr ck::index_t G = 32;
static constexpr ck::index_t N = 64; // batch size
static constexpr ck::index_t K = 64; // output channel
static constexpr ck::index_t C = 32; // input channel (per group)
static constexpr ck::index_t Y = 3; // filter H
static constexpr ck::index_t X = 3; // filter W
static constexpr ck::index_t Hi = 14; // input H
static constexpr ck::index_t Wi = 14; // input W
static constexpr ck::index_t Ho = 14; // output H
static constexpr ck::index_t Wo = 14; // output W
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
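// NOTE: allocation success is deliberately not checked in this small sample;
// the (void) casts simply discard the hip return codes.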
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
int execute_conv_fwd()
{
std::array<ck::index_t, 5> in_lengths{G, N, C, Hi, Wi};
std::array<ck::index_t, 5> in_strides{C * Hi * Wi, G * C * Hi * Wi, Hi * Wi, Wi, 1};
std::array<ck::index_t, 5> wei_lengths{G, K, C, Y, X};
std::array<ck::index_t, 5> wei_strides{K * Y * X * C, Y * X * C, 1, X * C, C};
std::array<ck::index_t, 5> out_lengths{G, N, K, Ho, Wo};
std::array<ck::index_t, 5> out_strides{K * Ho * Wo, G * K * Ho * Wo, Ho * Wo, Wo, 1};
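// Lengths are ordered {G, N, C/K, H, W}; the stride values above encode the
// physical NGCHW input / NGKHW output memory layouts declared at the top.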
std::array<ck::index_t, NumDimSpatial> filter_strides{1, 1};
std::array<ck::index_t, NumDimSpatial> filter_dilations{1, 1};
std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1};
std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1};
SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C);
SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C);
SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K);
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataType,
WeiDataType,
ck::Tuple<>,
OutDataType,
PassThrough,
PassThrough,
PassThrough>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
int best_op_id = -1;
float best_avg_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
float best_tflops = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
wei.GetDeviceBuffer(),
{},
out.GetDeviceBuffer(),
in_lengths,
in_strides,
wei_lengths,
wei_strides,
{},
{},
out_lengths,
out_strides,
filter_strides,
filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
PassThrough{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
// workspace_sz will be 0 for layouts other than NGCHW
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
SimpleDeviceMem workspace_dev(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
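// 2 * G * N * K * C * Ho * Wo * Y * X counts a multiply and an add per MAC;
// the extra 3 * N * Ho * Wo * G * K term presumably covers per-output elementwise work.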
std::size_t flop =
std::size_t(2) * G * N * K * C * Ho * Wo * Y * X + 3 * N * Ho * Wo * G * K;
std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C +
sizeof(WeiDataType) * G * K * Y * X * C +
sizeof(OutDataType) * 2 * N * Ho * Wo * G * K;
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_bytes / 1.E6 / avg_time;
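// avg_time is in ms, so flop / 1e9 / ms == TFLOPS and num_bytes / 1e6 / ms == GB/s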
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_op_id = i;
best_op_name = op_name;
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
best_tflops = tflops;
}
}
else
{
std::cerr << op_name << " does not support this problem" << std::endl;
}
}
if(best_op_id < 0)
{
std::cerr << "no suitable instance" << std::endl;
return EXIT_FAILURE;
}
std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops
<< " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
// run the best instance
{
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
wei.GetDeviceBuffer(),
{},
out.GetDeviceBuffer(),
in_lengths,
in_strides,
wei_lengths,
wei_strides,
{},
{},
out_lengths,
out_strides,
filter_strides,
filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
PassThrough{});
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
SimpleDeviceMem workspace_dev(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return 0;
}
int main() { return execute_conv_fwd(); }
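As a quick cross-check of the hard-coded `Ho`/`Wo` constants in this example, the standard convolution output-size formula with the padding, dilation, stride, and 3x3 filter used above does give 14. A minimal standalone sketch (illustrative, not part of the diff):

```cpp
// Sanity check for the example's spatial sizes (illustrative only).
constexpr int conv_out_size(int in, int pad_l, int pad_r, int filter, int dilation, int stride)
{
    // floor((in + pad_l + pad_r - dilation * (filter - 1) - 1) / stride) + 1
    return (in + pad_l + pad_r - dilation * (filter - 1) - 1) / stride + 1;
}

static_assert(conv_out_size(14, 1, 1, 3, 1, 1) == 14, "matches Ho/Wo above");
```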
-rocm-docs-core==1.13.0
+rocm-docs-core==1.14.1
sphinxcontrib-bibtex==2.6.3
@@ -8,6 +8,13 @@ accessible-pygments==0.0.5
# via pydata-sphinx-theme
alabaster==0.7.16
# via sphinx
asttokens==3.0.0
# via stack-data
attrs==24.3.0
# via
# jsonschema
# jupyter-cache
# referencing
babel==2.15.0
# via
# pydata-sphinx-theme
@@ -25,9 +32,17 @@ cffi==1.16.0
charset-normalizer==3.3.2
# via requests
click==8.1.7
# via
# jupyter-cache
# sphinx-external-toc
comm==0.2.2
# via ipykernel
cryptography==43.0.0
# via pyjwt
debugpy==1.8.12
# via ipykernel
decorator==5.1.1
# via ipython
deprecated==1.2.14
# via pygithub
docutils==0.21.2
@@ -38,20 +53,56 @@ docutils==0.21.2
# pydata-sphinx-theme
# sphinx
# sphinxcontrib-bibtex
exceptiongroup==1.2.2
# via ipython
executing==2.1.0
# via stack-data
fastjsonschema==2.20.0
# via
# nbformat
# rocm-docs-core
gitdb==4.0.11
# via gitpython
gitpython==3.1.43
# via rocm-docs-core
greenlet==3.1.1
# via sqlalchemy
idna==3.7
# via requests
imagesize==1.4.1
# via sphinx
importlib-metadata==8.6.1
# via
# jupyter-cache
# myst-nb
ipykernel==6.29.5
# via myst-nb
ipython==8.31.0
# via
# ipykernel
# myst-nb
jedi==0.19.2
# via ipython
jinja2==3.1.4
# via
# myst-parser
# sphinx
jsonschema==4.23.0
# via nbformat
jsonschema-specifications==2024.10.1
# via jsonschema
jupyter-cache==1.0.1
# via myst-nb
jupyter-client==8.6.3
# via
# ipykernel
# nbclient
jupyter-core==5.7.2
# via
# ipykernel
# jupyter-client
# nbclient
# nbformat
latexcodec==3.0.0
# via pybtex
markdown-it-py==3.0.0
@@ -60,16 +111,48 @@ markdown-it-py==3.0.0
# myst-parser
markupsafe==2.1.5
# via jinja2
matplotlib-inline==0.1.7
# via
# ipykernel
# ipython
mdit-py-plugins==0.4.1
# via myst-parser
mdurl==0.1.2
# via markdown-it-py
myst-nb==1.1.2
# via rocm-docs-core
myst-parser==3.0.1
# via myst-nb
nbclient==0.10.2
# via
# jupyter-cache
# myst-nb
nbformat==5.10.4
# via
# jupyter-cache
# myst-nb
# nbclient
nest-asyncio==1.6.0
# via ipykernel
packaging==24.1
# via
# ipykernel
# pydata-sphinx-theme
# sphinx
parso==0.8.4
# via jedi
pexpect==4.9.0
# via ipython
platformdirs==4.3.6
# via jupyter-core
prompt-toolkit==3.0.50
# via ipython
psutil==6.1.1
# via ipykernel
ptyprocess==0.7.0
# via pexpect
pure-eval==0.2.3
# via stack-data
pybtex==0.24.0
# via
# pybtex-docutils
@@ -87,26 +170,45 @@ pygithub==2.3.0
pygments==2.18.0
# via
# accessible-pygments
# ipython
# pydata-sphinx-theme
# sphinx
pyjwt[crypto]==2.8.0
# via pygithub
pynacl==1.5.0
# via pygithub
python-dateutil==2.9.0.post0
# via jupyter-client
pyyaml==6.0.1
# via
# jupyter-cache
# myst-nb
# myst-parser
# pybtex
# rocm-docs-core
# sphinx-external-toc
pyzmq==26.2.0
# via
# ipykernel
# jupyter-client
referencing==0.36.1
# via
# jsonschema
# jsonschema-specifications
requests==2.32.3
# via
# pygithub
# sphinx
-rocm-docs-core==1.13.0
+rocm-docs-core==1.14.1
# via -r requirements.in
rpds-py==0.22.3
# via
# jsonschema
# referencing
six==1.16.0
# via
# pybtex
# python-dateutil
smmap==5.0.1
# via gitdb
snowballstemmer==2.2.0
@@ -116,6 +218,7 @@ soupsieve==2.5
sphinx==7.4.7
# via
# breathe
# myst-nb
# myst-parser
# pydata-sphinx-theme
# rocm-docs-core
@@ -149,15 +252,43 @@ sphinxcontrib-qthelp==2.0.0
# via sphinx
sphinxcontrib-serializinghtml==2.0.0
# via sphinx
sqlalchemy==2.0.37
# via jupyter-cache
stack-data==0.6.3
# via ipython
tabulate==0.9.0
# via jupyter-cache
tomli==2.0.1
# via sphinx
tornado==6.4.2
# via
# ipykernel
# jupyter-client
traitlets==5.14.3
# via
# comm
# ipykernel
# ipython
# jupyter-client
# jupyter-core
# matplotlib-inline
# nbclient
# nbformat
typing-extensions==4.12.2
# via
# ipython
# myst-nb
# pydata-sphinx-theme
# pygithub
# referencing
# sqlalchemy
urllib3==2.2.2
# via
# pygithub
# requests
wcwidth==0.2.13
# via prompt-toolkit
wrapt==1.16.0
# via deprecated
zipp==3.21.0
# via importlib-metadata
@@ -48,9 +48,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16)
-add_example_executable(example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp)
-add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_rtn)
add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_int8)
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
using ADataType = ck::bhalf_t;
using BDataType = ck::bhalf_t;
using CDataType = ck::bhalf_t;
using AccDataType = float;
using CShuffleDataType = float;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = ck::tensor_operation::element_wise::ConvertBF16RTN;
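// (RTN = round-to-nearest: this C element-wise op converts the fp32 accumulator
// to bf16 using round-to-nearest rounding when writing the output.)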
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
using ReferenceComputeType = float;
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp,
ReferenceComputeType,
ReferenceComputeType>;
#include "run_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
@@ -5,6 +5,14 @@ include_directories(BEFORE
add_custom_target(examples)
# list of examples that are labelled as REGRESSION_EXAMPLE for make regression (runtime more than 30 seconds)
# all other tests are labelled as SMOKE_EXAMPLE
set(REGRESSION_EXAMPLES
example_sparse_embedding3_forward_layernorm
)
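# (Examples listed above are labelled REGRESSION_TEST and attached to the
# `regression` target by add_example_executable below; all other examples are
# labelled SMOKE_TEST and attached to `smoke`.)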
function(add_example_dependencies EXAMPLE_NAME FILE_NAME)
if(FILE_NAME)
add_dependencies(${EXAMPLE_NAME} ${FILE_NAME})
@@ -107,6 +115,15 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
set(result 0)
endif()
#message("add_example returns ${result}")
if(result EQUAL 0 AND NOT "${EXAMPLE_NAME}" IN_LIST REGRESSION_EXAMPLES)
#message("adding to SMOKE EXAMPLE FILTER ${EXAMPLE_NAME}")
set_tests_properties(${EXAMPLE_NAME} PROPERTIES LABELS "SMOKE_TEST")
add_dependencies(smoke ${EXAMPLE_NAME})
elseif(result EQUAL 0 AND "${EXAMPLE_NAME}" IN_LIST REGRESSION_EXAMPLES)
#message("Adding to REGRESSION EXAMPLE FILTER ${EXAMPLE_NAME}")
set_tests_properties(${EXAMPLE_NAME} PROPERTIES LABELS "REGRESSION_TEST")
add_dependencies(regression ${EXAMPLE_NAME})
endif()
set(result ${result} PARENT_SCOPE)
endfunction(add_example_executable EXAMPLE_NAME)
@@ -188,8 +205,10 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
set(result 0)
endif()
#message("add_example returns ${result}")
set(result ${result} PARENT_SCOPE)
endfunction(add_example_executable_no_testing EXAMPLE_NAME)
# add all example subdir
...
@@ -33,7 +33,7 @@ target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS})
set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let the sources compile without explicitly declared function specializations
-list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
+list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal --offload-compress)
target_compile_options(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS})
...
@@ -59,7 +59,7 @@ args:
-kname print kernel name or not (default:1)
-prec_i input precision (default:fp16)
-prec_o output precision, set auto will be the same as input (default:auto)
-prec_sm input smooth-quant scale type, set auto will be the same as input. used when fquant=1 (default:auto)
-prec_sy output quant scale type, set auto will be the same as input. used when fquant=1 or 2 (default:auto)
-fadd fused-add, 0:no fused add, 1:preadd+store, 2:preadd only (default:0)
-fquant fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0)
@@ -69,7 +69,7 @@ args:
```

## limitations
Note that `fquant=2`, `fadd=2`, and `prec_sm`/`prec_sy` types other than `fp32` are not generated by default, even though the kernel template supports them (TBD: add a flag in generate.py to generate those instances on demand). Besides, the `N>8192` case uses the two-pass pipeline by default, which does not yet support `-fquant=1/2`. If you need `N>8192` together with fused residual add and store, you can use this example together with `12_smoothquant` to construct layernorm+residual and smoothquant as two separate kernels.
```
# some case
...
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

# generate kernel instances to speed up compilation
import argparse
@@ -39,7 +39,8 @@ FUSED_FUSED_SWEEP_STR_MAP = [
DATA_TYPE_MAP = {'fp32' : 'float',
                 'fp16' : 'ck_tile::fp16_t',
                 'bf16' : 'ck_tile::bf16_t',
                 'int8' : 'ck_tile::int8_t',
                 'fp8'  : 'ck_tile::fp8_t'}

def BOOL_MAP(b_) -> str:
    if b_:
@@ -52,7 +53,7 @@ class layernorm_fwd_codegen:
// this is used to pattern-match the internal kernel implementation, not to instantiate the kernel
template <typename XDataType_,
          typename YDataType_,
          typename SmoothScaleDataType_,
          typename YScaleDataType_,
          ck_tile::index_t Repeat_M_, // each thread repeat along M
          ck_tile::index_t Repeat_N_, // each thread repeat along N
@@ -71,7 +72,7 @@ struct layernorm2d_fwd_traits_
{
    using XDataType = ck_tile::remove_cvref_t<XDataType_>;
    using YDataType = ck_tile::remove_cvref_t<YDataType_>;
    using SmoothScaleDataType = ck_tile::remove_cvref_t<SmoothScaleDataType_>;
    using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;

    static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
@@ -135,7 +136,7 @@ struct layernorm2d_fwd_traits_
template <typename XDataType_,
          typename YDataType_,
          typename SmoothScaleDataType_,
          typename YScaleDataType_,
          ck_tile::index_t Repeat_M_, // each thread repeat along M
          ck_tile::index_t Repeat_N_, // each thread repeat along N
@@ -152,7 +153,7 @@ template <typename XDataType_,
          int kFusedQuant_>
using traits_ = layernorm2d_fwd_traits_<XDataType_,
                                        YDataType_,
                                        SmoothScaleDataType_,
                                        YScaleDataType_,
                                        Repeat_M_,
                                        Repeat_N_,
@@ -170,7 +171,7 @@ using traits_ = layernorm2d_fwd_traits_<XDataType_,
"""

    API_COMMON_HEADER = """
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include <ck_tile/core.hpp>
#include "layernorm2d_fwd.hpp"
@@ -189,9 +190,9 @@ float layernorm2d_fwd_(const S& s, A a)
{{
    using XDataType = typename Traits_::XDataType;
    using YDataType = typename Traits_::YDataType;
    using SmoothScaleDataType = typename Traits_::SmoothScaleDataType;
    using YScaleDataType = typename Traits_::YScaleDataType;
    using ComputeDataType = typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::ComputeDataType;

    using PipelineTraits = ck_tile::Layernorm2dFwdTraits<Traits_::kPadN,
        Traits_::kSaveMeanInvStd,
@@ -202,16 +203,16 @@ float layernorm2d_fwd_(const S& s, A a)
        static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd),
        static_cast<ck_tile::Layernorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
    using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<
        typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::XDataType,
        typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::XBiasDataType,
        typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::GammaDataType,
        typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::BetaDataType,
        typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::ComputeDataType,
        typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::YDataType,
        typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::MeanDataType,
        typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::InvStdDataType,
        typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::SmoothScaleDataType,
        typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::YScaleDataType,
        typename Traits_::Shape,
        PipelineTraits>;
@@ -224,7 +225,7 @@ float layernorm2d_fwd_(const S& s, A a)
    static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1;
    static constexpr bool UseRawStore = sizeof(YDataType) == 4;
    using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, SmoothScaleDataType, YScaleDataType, YDataType, typename Traits_::Shape,
        ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, UseRawStore, true/*max3*/>>;
    using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
@@ -249,7 +250,7 @@ float layernorm2d_fwd_(const S& s, A a)
"""

    API_BASE = """
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include <ck_tile/core.hpp>
#include "layernorm2d_fwd.hpp"
@@ -285,7 +286,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
"""

    INSTANCE_BASE = """
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "layernorm2d_fwd_api_common.hpp"
@@ -374,7 +375,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
    class h_traits:
        F_XDataType : str
        F_YDataType : str
        F_SmoothScaleDataType : str
        F_YScaleDataType : str
        F_Repeat_M : int
        F_Repeat_N : int
@@ -392,7 +393,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
        @property
        def trait_name(self) ->str:
            t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
            t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}, {BOOL_MAP(self.F_kWelford_):5}'
            t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kXbias:4}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
            return t_
@@ -477,8 +478,8 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
        if ins.F_kFusedQuant == 0:
            _sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant)
        elif ins.F_kFusedQuant == 1:
            _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sm == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format(
                f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_SmoothScaleDataType, f_sy_type=ins.F_YScaleDataType)
        elif ins.F_kFusedQuant == 2:
            _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format(
                f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType)
@@ -504,12 +505,13 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
    h_traits = layernorm_fwd_codegen.h_traits
    h_instance = layernorm_fwd_codegen.h_instance

    dynamic_quant_out_dtype = ['int8', 'fp8']
    # some predefined support range
    # (prec_i,prec_o) for simplicity this string will be used as key for dict
    scale_list = [('fp32,fp32')]
    dtype_list = [('fp16,fp16'), ('bf16,bf16'),
                  ('fp16,int8'), ('bf16,int8'),
                  ('fp16,fp8'), ('bf16,fp8')] # NOTE: only fused-dynamic-quant uses int8 or fp8 out
    types_8bit = ('int8', 'fp8')
    types_16bit = ('int16', 'fp16', 'bf16')
    #fused_add_list = [0, 1, 2]
@@ -572,7 +574,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
    current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N
    for dtype, scale_type, xbias, fused_add, fused_quant in itertools.product(dtype_list, scale_list, xbias_list, fused_add_list, fused_sweep_list):
        prec_i, prec_o = dtype.split(',')
        scale_sm, scale_y = scale_type.split(',')
        if prec_o in dynamic_quant_out_dtype and fused_quant != 1:
            continue # skip non dynamic quant case
        if fused_quant == 1 and hs_key == 'big':
@@ -582,8 +584,8 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
        h_ = copy.copy(chs_) # copy the base instance out
        h_.F_XDataType = prec_i
        h_.F_YDataType = prec_o
        h_.F_SmoothScaleDataType = scale_sm
        h_.F_YScaleDataType = scale_y
        h_.F_kXbias = xbias
        h_.F_kFusedAdd = fused_add
        h_.F_kFusedQuant = fused_quant
...
@@ -20,6 +20,14 @@ auto get_elimit<ck_tile::bf16_t>()
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::int8_t>()
{
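// quantized int8 outputs may legitimately differ by one step, hence atol = 1.0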
double rtol = 1e-2;
double atol = 1.0;
return ck_tile::make_tuple(rtol, atol);
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
@@ -35,7 +43,7 @@ auto create_args(int argc, char* argv[])
.insert("kname", "1", "print kernel name or not")
.insert("prec_i", "fp16", "input precision")
.insert("prec_o", "auto", "output precision, set auto will be the same as input")
.insert("prec_sm",
        "auto",
        "input smooth-quant scale type, set auto will use fp32. used when fquant=1")
.insert("prec_sy",
@@ -53,7 +61,7 @@ auto create_args(int argc, char* argv[])
template <typename InDataType,
          typename OutDataType,
          typename SmoothScaleDataType,
          typename YScaleDataType,
          bool SaveMeanVar>
bool run(const ck_tile::ArgParser& arg_parser)
@@ -75,15 +83,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
float epsilon = arg_parser.get_float("e");
std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_o = arg_parser.get_str("prec_o");
std::string prec_sm = arg_parser.get_str("prec_sm");
std::string prec_sy = arg_parser.get_str("prec_sy");

if(prec_o == "auto")
{
    prec_o = prec_i;
}
if(prec_sm == "auto")
{
    prec_sm = "fp32";
}
if(prec_sy == "auto")
{
@@ -97,15 +105,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
int xbias = arg_parser.get_int("xbias");
int fused_add = arg_parser.get_int("fadd");
int fused_quant = arg_parser.get_int("fquant");
if(fused_quant == 1 && prec_o != "int8" && prec_o != "fp8")
{
    std::cout
        << "if fused_quant is 1 or 2, only support \"-prec_o=int8\" or \"-prec_o=fp8\" cases."
        << std::endl;
    return false;
}

assert(x_stride >= n);

using TypeConfig =
    LayerNormTypeConfig<InDataType, OutDataType, SmoothScaleDataType, YScaleDataType>;

using XDataType = typename TypeConfig::XDataType;
using YDataType = typename TypeConfig::YDataType;
@@ -139,12 +150,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<YScaleDataType> y_scale_host_ref({m});
ck_tile::HostTensor<YScaleDataType> y_scale_host_dev({m});

ck_tile::HostTensor<SmoothScaleDataType> sm_scale_host({n});
ck_tile::HostTensor<SmoothScaleDataType> sm_scale_host_dev({n});

ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host);
ck_tile::FillUniformDistribution<SmoothScaleDataType>{-1.f, 1.f}(sm_scale_host);
ck_tile::FillUniformDistribution<XBiasDataType>{-.5f, .5f}(x_bias_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);
@@ -155,7 +166,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem sm_scale_buf(sm_scale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes());
@@ -165,7 +176,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
gamma_buf.ToDevice(gamma_host.data());
beta_buf.ToDevice(beta_host.data());
x_residual_buf.ToDevice(x_residual_host.data());
sm_scale_buf.ToDevice(sm_scale_host.data());

auto prec_str = [&]() {
auto base_str = prec_i;
@@ -186,11 +197,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
          << ", yr_stride:" << yr_stride << std::flush;

layernorm2d_fwd_traits traits{
    prec_i, prec_o, prec_sm, prec_sy, SaveMeanVar, xbias, fused_add, fused_quant};

layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
                          fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,
                          fused_quant == 1 ? sm_scale_buf.GetDeviceBuffer() : nullptr,
                          x_bias_buf.GetDeviceBuffer(),
                          gamma_buf.GetDeviceBuffer(),
                          beta_buf.GetDeviceBuffer(),
@@ -279,8 +290,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
for(int n_ = 0; n_ < N_; n_++)
{
    // input smooth outlier
    acc_(m_, n_) = acc_(m_, n_) *
                   ck_tile::type_convert<ComputeDataType>(sm_scale_host(n_));
}
}
ComputeDataType absmax = static_cast<ComputeDataType>(0);
@@ -290,7 +301,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
    absmax = a > absmax ? a : absmax;
}
// printf("cpu:absmax:%f\n", absmax);
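// assumption: 240 is the largest finite value of the fp8 (e4m3fnuz) variant
// used here; 127 is the int8 maximum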
ComputeDataType y_scale = absmax / static_cast<ComputeDataType>(127.0); constexpr ComputeDataType kMaxY =
std::is_same<YDataType, ck_tile::fp8_t>::value ? 240.0
: std::is_same<YDataType, ck_tile::int8_t>::value ? 127.0
: 0.0;
ComputeDataType y_scale = absmax / kMaxY;
y_scale_host_ref(m_) = ck_tile::type_convert<YScaleDataType>(y_scale);
for(int n_ = 0; n_ < N_; n_++)
{
@@ -333,7 +348,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
    y_residual_buf.FromDevice(y_residual_host_dev.data());
}

auto [rtol, atol] = get_elimit<OutDataType>();

if(x_stride == n)
{
@@ -402,16 +417,16 @@ int main(int argc, char* argv[])
std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_o = arg_parser.get_str("prec_o");
std::string prec_sm = arg_parser.get_str("prec_sm");
std::string prec_sy = arg_parser.get_str("prec_sy");
if(prec_o == "auto")
{
    prec_o = prec_i;
}
if(prec_sm == "auto")
{
    prec_sm = "fp32";
}
if(prec_sy == "auto")
{
@@ -420,37 +435,47 @@ int main(int argc, char* argv[])
int save_mv = arg_parser.get_int("save_mv");

// no dynamic quant case
if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" && save_mv)
{
    return run<ck_tile::half_t, ck_tile::half_t, float, float, true>(arg_parser) ? 0 : -2;
}
else if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" &&
        !save_mv)
{
    return run<ck_tile::half_t, ck_tile::half_t, float, float, false>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" &&
        save_mv)
{
    return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" &&
        !save_mv)
{
    return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, false>(arg_parser) ? 0 : -2;
}
// dynamic quant case, only in inference
else if(prec_i == "fp16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
        !save_mv)
{
    return run<ck_tile::half_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
        !save_mv)
{
    return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
}
else if(prec_i == "fp16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
        !save_mv)
{
    return run<ck_tile::half_t, ck_tile::fp8_t, float, float, false>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
        !save_mv)
{
    return run<ck_tile::bf16_t, ck_tile::fp8_t, float, float, false>(arg_parser) ? 0 : -2;
}
return -3;
}
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -8,37 +8,40 @@ ...@@ -8,37 +8,40 @@
#include "ck_tile/ops/layernorm2d.hpp" #include "ck_tile/ops/layernorm2d.hpp"
#include <string> #include <string>
template <typename InType, typename OutType, typename XScaleDataType_, typename YScaleDataType_> template <typename InType,
typename OutType,
typename SmoothSScaleDataType_,
typename YScaleDataType_>
struct LayerNormTypeConfig; struct LayerNormTypeConfig;
template <typename OutType, typename XScaleDataType_, typename YScaleDataType_> template <typename OutType, typename SmoothScaleDataType_, typename YScaleDataType_>
struct LayerNormTypeConfig<ck_tile::half_t, OutType, XScaleDataType_, YScaleDataType_> struct LayerNormTypeConfig<ck_tile::half_t, OutType, SmoothScaleDataType_, YScaleDataType_>
{ {
using XDataType = ck_tile::half_t; using XDataType = ck_tile::half_t;
using YDataType = OutType; using YDataType = OutType;
using XBiasDataType = ck_tile::half_t; using XBiasDataType = ck_tile::half_t;
using GammaDataType = ck_tile::half_t; using GammaDataType = ck_tile::half_t;
using BetaDataType = ck_tile::half_t; using BetaDataType = ck_tile::half_t;
using MeanDataType = ck_tile::half_t; using MeanDataType = ck_tile::half_t;
using InvStdDataType = ck_tile::half_t; using InvStdDataType = ck_tile::half_t;
using ComputeDataType = float; using ComputeDataType = float;
using XScaleDataType = XScaleDataType_; using SmoothScaleDataType = SmoothScaleDataType_;
using YScaleDataType = YScaleDataType_; using YScaleDataType = YScaleDataType_;
}; };
template <typename OutType, typename XScaleDataType_, typename YScaleDataType_> template <typename OutType, typename SmoothScaleDataType_, typename YScaleDataType_>
struct LayerNormTypeConfig<ck_tile::bf16_t, OutType, XScaleDataType_, YScaleDataType_> struct LayerNormTypeConfig<ck_tile::bf16_t, OutType, SmoothScaleDataType_, YScaleDataType_>
{ {
using XDataType = ck_tile::bf16_t; using XDataType = ck_tile::bf16_t;
using YDataType = OutType; using YDataType = OutType;
using XBiasDataType = ck_tile::bf16_t; using XBiasDataType = ck_tile::bf16_t;
using GammaDataType = ck_tile::bf16_t; using GammaDataType = ck_tile::bf16_t;
using BetaDataType = ck_tile::bf16_t; using BetaDataType = ck_tile::bf16_t;
using MeanDataType = ck_tile::bf16_t; using MeanDataType = ck_tile::bf16_t;
using InvStdDataType = ck_tile::bf16_t; using InvStdDataType = ck_tile::bf16_t;
using ComputeDataType = float; using ComputeDataType = float;
using XScaleDataType = XScaleDataType_; using SmoothScaleDataType = SmoothScaleDataType_;
using YScaleDataType = YScaleDataType_; using YScaleDataType = YScaleDataType_;
}; };
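Both specializations above fix the layernorm I/O types per input precision while keeping ComputeDataType at fp32 and leaving the two scale types caller-provided. Below is a reduced sketch of how such a config struct is typically consumed downstream; TypeConfig and KernelTraits are simplified stand-ins, not ck_tile types.

#include <type_traits>

struct half_t {}; // stand-in for ck_tile::half_t

// Simplified analogue of LayerNormTypeConfig: the input precision pins most
// types, the smooth-scale and y-scale types stay open.
template <typename InType, typename OutType, typename SmoothScaleType, typename YScaleType>
struct TypeConfig
{
    using XDataType           = InType;
    using YDataType           = OutType;
    using SmoothScaleDataType = SmoothScaleType;
    using YScaleDataType      = YScaleType;
    using ComputeDataType     = float;
};

// Downstream code reads the nested typedefs instead of raw template parameters.
template <typename Config>
struct KernelTraits
{
    using Acc = typename Config::ComputeDataType;
    static_assert(std::is_same_v<Acc, float>, "accumulation stays in fp32");
};

int main()
{
    using Cfg = TypeConfig<half_t, half_t, float, float>;
    KernelTraits<Cfg>{}; // compile-time check only
    return 0;
}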
// runtime args // runtime args
...@@ -52,10 +55,10 @@ struct layernorm2d_fwd_traits ...@@ -52,10 +55,10 @@ struct layernorm2d_fwd_traits
std::string prec_i; // input precision std::string prec_i; // input precision
std::string prec_o; // output precision std::string prec_o; // output precision
// if fused_quant == 1, prec_sx/prec_sy must be set to proper strings; otherwise they // if fused_quant == 1, prec_sm/prec_sy must be set to proper strings; otherwise they
// can be set arbitrarily (the check is skipped); if fused_quant == 2, prec_sy must be // can be set arbitrarily (the check is skipped); if fused_quant == 2, prec_sy must be
// set to a proper string; otherwise it can be set arbitrarily (the check is skipped) // set to a proper string; otherwise it can be set arbitrarily (the check is skipped)
std::string prec_sx; // x-scale, used for [1*N] input smooth quant std::string prec_sm; // smooth-scale, used for [1*N] input smooth quant
std::string prec_sy; // y-scale, used for [M*1] output for next layer std::string prec_sy; // y-scale, used for [M*1] output for next layer
bool save_mean_var; // bool save_mean_var; //
......
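The comment in the traits above implies a mode-dependent validation of the scale precision strings. Here is a hedged sketch of that rule as plain code; scales_are_valid is a hypothetical helper written from the comment, not taken from the source.

#include <string>

// fused_quant == 1: both smooth-scale and y-scale precisions are checked;
// fused_quant == 2: only the y-scale precision is checked;
// anything else: the scales are unused and the check is skipped.
bool scales_are_valid(int fused_quant, const std::string& prec_sm, const std::string& prec_sy)
{
    if(fused_quant == 1)
        return prec_sm == "fp32" && prec_sy == "fp32";
    if(fused_quant == 2)
        return prec_sy == "fp32";
    return true;
}

int main() { return scales_are_valid(1, "fp32", "fp32") ? 0 : 1; }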
#!/bin/sh #!/bin/sh
EXE="$(find . -name tile_example_layernorm2d_fwd -type f | head -n 1)" EXE="$(find . -name tile_example_layernorm2d_fwd -type f | head -n 1)"
for fquant in "" "-fquant=1 -prec_o=int8"; do for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=1 -prec_o=fp8"; do
for pr_i in "fp16" "bf16" ; do for pr_i in "fp16" "bf16" ; do
for fadd in "0" "1"; do for fadd in "0" "1"; do
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
...@@ -49,7 +49,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -49,7 +49,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
ck_tile::sequence<M_Warp, N_Warp, K_Warp>, ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>; ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>; using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;
using GemmEpilogue = std::conditional_t< using GemmEpilogue = std::conditional_t<
CShuffleEpilogue, CShuffleEpilogue,
...@@ -61,8 +61,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -61,8 +61,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
kOutputRank, kOutputRank,
1, 1,
0, 0,
TilePartitioner::kM, TilePartitioner::MPerBlock,
TilePartitioner::kN>>, TilePartitioner::NPerBlock>>,
ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>; ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
......
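The hunk above keeps the epilogue choice as a compile-time std::conditional_t between a CShuffle epilogue and the default 2D epilogue. Below is a reduced sketch of that selection idiom; the two epilogue structs here are empty placeholders, not ck_tile's.

#include <type_traits>

struct CShuffleEpilogueImpl {};  // placeholder for the CShuffle epilogue
struct Default2DEpilogueImpl {}; // placeholder for ck_tile::Default2DEpilogue

// Compile-time selection: the flag picks the epilogue type with zero runtime cost.
template <bool UseCShuffle>
using GemmEpilogue = std::conditional_t<UseCShuffle, CShuffleEpilogueImpl, Default2DEpilogueImpl>;

static_assert(std::is_same_v<GemmEpilogue<true>, CShuffleEpilogueImpl>);
static_assert(std::is_same_v<GemmEpilogue<false>, Default2DEpilogueImpl>);

int main() { return 0; }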
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
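is_row_major above returns a bool_constant rather than a plain bool, so the layout can steer if constexpr branches and template selection. Below is a self-contained sketch of the same idiom with simplified layout tags, not ck_tile's.

#include <cstddef>
#include <type_traits>

struct RowMajor {};
struct ColumnMajor {};

template <typename Layout>
constexpr auto is_row_major(Layout)
{
    return std::bool_constant<std::is_same_v<Layout, RowMajor>>{};
}

// The compile-time result picks the packed default stride without a runtime branch.
template <typename Layout>
std::size_t default_stride(std::size_t rows, std::size_t cols, Layout layout)
{
    if constexpr(decltype(is_row_major(layout))::value)
        return cols; // row-major: a row is contiguous, stride is the column count
    else
        return rows; // column-major: a column is contiguous, stride is the row count
}

int main() { return default_stride(4, 8, RowMajor{}) == 8 ? 0 : 1; }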
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
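calculate_rtol_atol above combines two error sources: rounding over ceil(K/kbatch) accumulations inside one split, and rounding over kbatch partial-sum additions in the output type, then keeps the looser bound of each pair. The sketch below illustrates only that "max of two models" structure; the threshold formulas and unit-roundoff constants are made up for illustration and are not ck_tile's.

#include <algorithm>
#include <cmath>
#include <utility>

// Crude model: relative error grows linearly with the number of accumulations n.
double relative_threshold(int n, double unit_roundoff) { return n * unit_roundoff; }

double absolute_threshold(double max_value, int n, double unit_roundoff)
{
    return std::abs(max_value) * n * unit_roundoff;
}

std::pair<double, double> rtol_atol(int K, int kbatch, double max_accumulated_value)
{
    const int per_split = (K + kbatch - 1) / kbatch; // accumulations inside one split
    // Error from accumulating within a split (fp32-like roundoff, illustrative)...
    const double rtol = relative_threshold(per_split, 0x1p-24);
    const double atol = absolute_threshold(max_accumulated_value / kbatch, per_split, 0x1p-24);
    // ...and from summing kbatch partial results in the output type (fp16-like).
    const double rtol_split_k = relative_threshold(kbatch, 0x1p-11);
    const double atol_split_k = absolute_threshold(max_accumulated_value, kbatch, 0x1p-11);
    // Keep the looser bound of each pair, as the helper above does.
    return {std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)};
}

int main() { return rtol_atol(4096, 4, 100.0).first > 0.0 ? 0 : 1; }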
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf,
...@@ -68,48 +95,16 @@ int run_gemm_example_with_layouts(int argc, ...@@ -68,48 +95,16 @@ int run_gemm_example_with_layouts(int argc,
int n_warmup = arg_parser.get_int("warmup"); int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat"); int n_repeat = arg_parser.get_int("repeat");
using namespace ck_tile::literals; stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
auto f_host_tensor_descriptor = stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>) ck_tile::HostTensor<ADataType> a_m_k(
{ ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); ck_tile::HostTensor<BDataType> b_k_n(
} ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
else
{
return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
auto f_get_default_stride = [](std::size_t row,
std::size_t col,
std::size_t stride,
auto layout) {
if(stride == 0)
{
// if stride is zero, fall back to a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return col;
}
else
{
return row;
}
}
else
return stride;
};
stride_A = f_get_default_stride(M, K, stride_A, a_layout);
stride_B = f_get_default_stride(K, N, stride_B, b_layout);
stride_C = f_get_default_stride(M, N, stride_C, CLayout{});
ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, a_layout));
ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, b_layout));
ck_tile::HostTensor<CDataType> c_m_n_dev_result( ck_tile::HostTensor<CDataType> c_m_n_dev_result(
f_host_tensor_descriptor(M, N, stride_C, CLayout{})); ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
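The refactor above swaps the local f_host_tensor_descriptor / f_get_default_stride lambdas for the shared ck_tile::host_tensor_descriptor and ck_tile::get_default_stride helpers. Below is a minimal sketch of what the row-/column-major descriptor boils down to, using a simplified Descriptor struct instead of ck_tile::HostTensorDescriptor.

#include <array>
#include <cstddef>

struct Descriptor
{
    std::array<std::size_t, 2> lengths;
    std::array<std::size_t, 2> strides;
};

// Row-major: elements within a row are contiguous (inner stride 1);
// column-major: elements within a column are contiguous.
Descriptor make_descriptor(std::size_t rows, std::size_t cols, std::size_t stride, bool row_major)
{
    if(row_major)
        return {{rows, cols}, {stride, 1}};
    return {{rows, cols}, {1, stride}};
}

int main()
{
    const Descriptor d = make_descriptor(99, 13, 13, /*row_major=*/true);
    return d.strides[1] == 1 ? 0 : 1;
}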
// TODO: add different init types // TODO: add different init types
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k); ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
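FillUniformDistribution above seeds the host tensors with uniform values in [-5, 5] before the device run. A plain-C++ equivalent of that initialization step follows; fill_uniform is an illustrative helper, not ck_tile's functor.

#include <random>
#include <vector>

// Fill a buffer with uniformly distributed values in [lo, hi).
template <typename T>
void fill_uniform(std::vector<T>& data, float lo, float hi, unsigned seed = 0)
{
    std::mt19937 gen(seed);
    std::uniform_real_distribution<float> dist(lo, hi);
    for(auto& v : data)
        v = static_cast<T>(dist(gen));
}

int main()
{
    std::vector<float> a_m_k(99 * 13); // matches the -m=99 -n=13 smoke shape above
    fill_uniform(a_m_k, -5.f, 5.f);
    return 0;
}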
...@@ -143,20 +138,29 @@ int run_gemm_example_with_layouts(int argc, ...@@ -143,20 +138,29 @@ int run_gemm_example_with_layouts(int argc,
if(arg_parser.get_int("v") == 1) if(arg_parser.get_int("v") == 1)
{ {
ck_tile::HostTensor<CDataType> c_m_n_host_ref( ck_tile::HostTensor<CDataType> c_m_n_host_ref(
f_host_tensor_descriptor(M, N, stride_C, CLayout{})); ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_m_n_host_ref.SetZero(); c_m_n_host_ref.SetZero();
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>( ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_k_n, c_m_n_host_ref); a_m_k, b_k_n, c_m_n_host_ref);
const float max_accumulated_value =
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref); *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
} }
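check_err above now receives the computed rtol/atol pair instead of default tolerances. Below is a generic sketch of such a combined relative/absolute comparison; this is the standard |out - ref| <= atol + rtol * |ref| formulation, not necessarily ck_tile's exact implementation.

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// Pass if every element satisfies |out - ref| <= atol + rtol * |ref|.
bool check_err(const std::vector<float>& out,
               const std::vector<float>& ref,
               double rtol,
               double atol)
{
    if(out.size() != ref.size())
        return false;
    for(std::size_t i = 0; i < out.size(); ++i)
        if(std::abs(out[i] - ref[i]) > atol + rtol * std::abs(ref[i]))
            return false;
    return true;
}

int main()
{
    const std::vector<float> ref{1.0f, 2.0f};
    const std::vector<float> out{1.0001f, 2.0f};
    std::cout << (check_err(out, ref, 1e-3, 1e-5) ? "correct" : "fail") << std::endl;
}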
else if(arg_parser.get_int("v") == 2) else if(arg_parser.get_int("v") == 2)
{ {
ck_tile::HostTensor<CDataType> c_m_n_gpu_ref( ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
f_host_tensor_descriptor(M, N, stride_C, CLayout{})); ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
c_m_n_gpu_ref.SetZero(); c_m_n_gpu_ref.SetZero();
c_m_n_gpu_buf_ref.SetZero(); c_m_n_gpu_buf_ref.SetZero();
...@@ -196,8 +200,18 @@ int run_gemm_example_with_layouts(int argc, ...@@ -196,8 +200,18 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::hip_check_error(hipFree(d_C)); ck_tile::hip_check_error(hipFree(d_C));
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref); const float max_accumulated_value =
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_gpu_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl; std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
} }
......
...@@ -56,7 +56,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -56,7 +56,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>, ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>, ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>; ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>; using TilePartitioner = ck_tile::GemmTile2DPartitioner<GemmShape>;
using GemmEpilogue = ck_tile::Default2DEpilogue< using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>; ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
......
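Both GEMM examples now instantiate GemmTile2DPartitioner instead of GemmTilePartitioner. In spirit, a 2D tile partitioner maps a linear workgroup index onto (m, n) tile coordinates over a ceil-divided grid; the sketch below is illustrative only and does not reproduce ck_tile's implementation.

#include <cstdint>
#include <utility>

// Map a linear block id onto 2D tile coordinates over ceil(N / NPerBlock) columns.
std::pair<std::uint32_t, std::uint32_t>
tile_coords(std::uint32_t block_id, std::uint32_t N, std::uint32_t NPerBlock)
{
    const std::uint32_t tiles_n = (N + NPerBlock - 1) / NPerBlock;
    return {block_id / tiles_n, block_id % tiles_n}; // (m-tile index, n-tile index)
}

int main()
{
    const auto coords = tile_coords(5, 256, 64); // 4 n-tiles, block 5 -> tile (1, 1)
    return (coords.first == 1 && coords.second == 1) ? 0 : 1;
}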