Unverified Commit ac76519a authored by Adam Osewski's avatar Adam Osewski Committed by GitHub
Browse files

Merge branch 'develop' into aosewski/gemm_tile_loop

parents a70c6283 578142db
...@@ -89,6 +89,14 @@ else() ...@@ -89,6 +89,14 @@ else()
message("Building CK for the following targets: ${AMDGPU_TARGETS}") message("Building CK for the following targets: ${AMDGPU_TARGETS}")
endif() endif()
find_package(hip) find_package(hip)
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
# SWDEV-413293 and https://reviews.llvm.org/D155213
math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR}) * 100000 + ${hip_VERSION_PATCH}")
message("hip_version_flat=${hip_VERSION_FLAT}")
if(${hip_VERSION_FLAT} GREATER 500723302)
message("Adding the fno-offload-uniform-block compiler flag")
add_compile_options(-fno-offload-uniform-block)
endif()
option(USE_BITINT_EXTENSION_INT4, "Whether to enable clang's BitInt extension to provide int4 data type." OFF) option(USE_BITINT_EXTENSION_INT4, "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
option(USE_OPT_NAVI3X, "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF) option(USE_OPT_NAVI3X, "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF)
......
...@@ -6,9 +6,11 @@ This is the list of developers and contributors to Composable Kernel library ...@@ -6,9 +6,11 @@ This is the list of developers and contributors to Composable Kernel library
## Developers ## Developers
[Chao Liu](https://github.com/asroy), [Jing Zhang](https://github.com/zjing14), 2018-2023 [Chao Liu](https://github.com/asroy), [Jing Zhang](https://github.com/zjing14), 2018-2023
[Letao Qin](https://github.com/ltqin), [Qianfeng Zhang](https://github.com/qianfengz), [Liang Huang](https://github.com/carlushuang), [Shaojie Wang](https://github.com/shaojiewang), 2019-2022 [Letao Qin](https://github.com/ltqin), [Qianfeng Zhang](https://github.com/qianfengz), [Liang Huang](https://github.com/carlushuang), [Shaojie Wang](https://github.com/shaojiewang), 2019-2023
[Anthony Chang](https://github.com/rosenrodt), [Chunyu Lai](https://github.com/rocking5566), [Illia Silin](https://github.com/illsilin), [Adam Osewski](https://github.com/aosewski), [Poyen Chen](https://github.com/poyenc), [Rosty Geyyer](https://github.com/geyyer), 2022 [Anthony Chang](https://github.com/rosenrodt), [Chunyu Lai](https://github.com/rocking5566), [Illia Silin](https://github.com/illsilin), [Adam Osewski](https://github.com/aosewski), [Poyen Chen](https://github.com/poyenc), [Rosty Geyyer](https://github.com/geyyer), [Astha Rai](https://github.com/arai713), [Shi YanXing](https://github.com/Yanxing-Shi), 2022-2023
[Hari Sadasivan](https://github.com/hsadasiv), [Bartlomiej Kocot](https://github.com/bartekxk), [Bartlomiej Wroblewski](https://github.com/bwroblew), 2023
Hanwen Chang, 2019-2021, Hanwen Chang, 2019-2021,
......
...@@ -710,8 +710,8 @@ pipeline { ...@@ -710,8 +710,8 @@ pipeline {
} }
agent{ label rocmnode("gfx908 || gfx90a") } agent{ label rocmnode("gfx908 || gfx90a") }
environment{ environment{
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a;gfx940" """ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941" """
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a;gfx940" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
} }
steps{ steps{
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
......
...@@ -32,63 +32,49 @@ struct SimpleDeviceMem ...@@ -32,63 +32,49 @@ struct SimpleDeviceMem
}; };
template <ck::index_t NumDimSpatial> template <ck::index_t NumDimSpatial>
std::size_t GetFlops(ck::index_t G, std::size_t GetFlops(const std::array<ck::index_t, NumDimSpatial>& output_lengths,
ck::index_t N, const std::array<ck::index_t, NumDimSpatial>& filter_lengths)
ck::index_t K,
ck::index_t C,
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths)
{ {
constexpr ck::index_t spatial_offset = 3;
const auto C = filter_lengths[2];
// 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product> // 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
return static_cast<std::size_t>(2) * G * N * K * C * return static_cast<std::size_t>(2) * C *
std::accumulate(std::begin(output_spatial_lengths), std::accumulate(std::begin(output_lengths),
std::end(output_spatial_lengths), std::end(output_lengths),
static_cast<std::size_t>(1), static_cast<std::size_t>(1),
std::multiplies<>()) * std::multiplies<>()) *
std::accumulate(std::begin(filter_spatial_lengths), std::accumulate(std::begin(filter_lengths) + spatial_offset,
std::end(filter_spatial_lengths), std::end(filter_lengths),
static_cast<std::size_t>(1), static_cast<std::size_t>(1),
std::multiplies<>()); std::multiplies<>());
} }
template <typename InDataType, ck::index_t NumDimSpatial> template <typename InDataType, ck::index_t NumDimSpatial>
std::size_t GetInputByte(ck::index_t G, std::size_t GetInputByte(const std::array<ck::index_t, NumDimSpatial>& input_lengths)
ck::index_t N,
ck::index_t C,
const std::array<ck::index_t, NumDimSpatial>& input_spatial_lengths)
{ {
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) + // sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
return sizeof(InDataType) * (G * N * C * return sizeof(InDataType) * (std::accumulate(std::begin(input_lengths),
std::accumulate(std::begin(input_spatial_lengths), std::end(input_lengths),
std::end(input_spatial_lengths),
static_cast<std::size_t>(1), static_cast<std::size_t>(1),
std::multiplies<>())); std::multiplies<>()));
} }
template <typename WeiDataType, ck::index_t NumDimSpatial> template <typename WeiDataType, ck::index_t NumDimSpatial>
std::size_t GetWeightByte(ck::index_t G, std::size_t GetWeightByte(const std::array<ck::index_t, NumDimSpatial>& filter_lengths)
ck::index_t K,
ck::index_t C,
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths)
{ {
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) + // sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
return sizeof(WeiDataType) * (G * K * C * return sizeof(WeiDataType) * (std::accumulate(std::begin(filter_lengths),
std::accumulate(std::begin(filter_spatial_lengths), std::end(filter_lengths),
std::end(filter_spatial_lengths),
static_cast<std::size_t>(1), static_cast<std::size_t>(1),
std::multiplies<>())); std::multiplies<>()));
} }
template <typename OutDataType, ck::index_t NumDimSpatial> template <typename OutDataType, ck::index_t NumDimSpatial>
std::size_t GetOutputByte(ck::index_t G, std::size_t GetOutputByte(const std::array<ck::index_t, NumDimSpatial>& output_lengths)
ck::index_t N,
ck::index_t K,
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths)
{ {
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>); // sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return sizeof(OutDataType) * (G * N * K * return sizeof(OutDataType) * (std::accumulate(std::begin(output_lengths),
std::accumulate(std::begin(output_spatial_lengths), std::end(output_lengths),
std::end(output_spatial_lengths),
static_cast<std::size_t>(1), static_cast<std::size_t>(1),
std::multiplies<std::size_t>())); std::multiplies<std::size_t>()));
} }
...@@ -101,14 +87,11 @@ template <ck::index_t NumDimSpatial, ...@@ -101,14 +87,11 @@ template <ck::index_t NumDimSpatial,
typename WeiLayout, typename WeiLayout,
typename OutLayout> typename OutLayout>
bool run_grouped_conv_bwd_weight( bool run_grouped_conv_bwd_weight(
const ck::index_t G, const std::array<ck::index_t, NumDimSpatial + 3>& input_lengths,
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NumDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial + 3>& input_strides, const std::array<ck::index_t, NumDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NumDimSpatial + 3>& filter_lengths,
const std::array<ck::index_t, NumDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NumDimSpatial + 3>& output_lengths,
const std::array<ck::index_t, NumDimSpatial + 3>& output_strides, const std::array<ck::index_t, NumDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NumDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NumDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NumDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NumDimSpatial>& conv_filter_dilations,
...@@ -117,9 +100,9 @@ bool run_grouped_conv_bwd_weight( ...@@ -117,9 +100,9 @@ bool run_grouped_conv_bwd_weight(
{ {
ck::index_t split_k = 2; ck::index_t split_k = 2;
SimpleDeviceMem in(GetInputByte<InDataType, NumDimSpatial>(G, N, C, input_spatial_lengths)); SimpleDeviceMem in(GetInputByte<InDataType, NumDimSpatial + 3>(input_lengths));
SimpleDeviceMem wei(GetWeightByte<WeiDataType, NumDimSpatial>(G, K, C, filter_spatial_lengths)); SimpleDeviceMem wei(GetWeightByte<WeiDataType, NumDimSpatial + 3>(filter_lengths));
SimpleDeviceMem out(GetOutputByte<OutDataType, NumDimSpatial>(G, N, K, output_spatial_lengths)); SimpleDeviceMem out(GetOutputByte<OutDataType, NumDimSpatial + 3>(output_lengths));
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<NumDimSpatial, using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<NumDimSpatial,
InLayout, InLayout,
...@@ -143,6 +126,10 @@ bool run_grouped_conv_bwd_weight( ...@@ -143,6 +126,10 @@ bool run_grouped_conv_bwd_weight(
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
float best_tflops = 0; float best_tflops = 0;
std::array<ck::index_t, NumDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<ck::index_t, NumDimSpatial + 3> a_g_n_c_wis_strides{};
std::array<ck::index_t, NumDimSpatial + 3> b_g_k_c_xs_lengths{};
// profile device operation instances // profile device operation instances
std::cout << "Run all instances and do timing" << std::endl; std::cout << "Run all instances and do timing" << std::endl;
...@@ -152,14 +139,11 @@ bool run_grouped_conv_bwd_weight( ...@@ -152,14 +139,11 @@ bool run_grouped_conv_bwd_weight(
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
wei.GetDeviceBuffer(), wei.GetDeviceBuffer(),
out.GetDeviceBuffer(), out.GetDeviceBuffer(),
G, input_lengths,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_strides, input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
...@@ -176,12 +160,10 @@ bool run_grouped_conv_bwd_weight( ...@@ -176,12 +160,10 @@ bool run_grouped_conv_bwd_weight(
{ {
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = std::size_t flop = GetFlops<NumDimSpatial + 3>(output_lengths, filter_lengths);
GetFlops<NumDimSpatial>(G, N, K, C, output_spatial_lengths, filter_spatial_lengths); std::size_t num_bytes = GetInputByte<InDataType, NumDimSpatial + 3>(input_lengths) +
std::size_t num_bytes = GetWeightByte<WeiDataType, NumDimSpatial + 3>(filter_lengths) +
GetInputByte<InDataType, NumDimSpatial>(G, N, C, input_spatial_lengths) + GetOutputByte<OutDataType, NumDimSpatial + 3>(output_lengths);
GetWeightByte<WeiDataType, NumDimSpatial>(G, K, C, filter_spatial_lengths) +
GetOutputByte<OutDataType, NumDimSpatial>(G, N, K, output_spatial_lengths);
float tflops = static_cast<float>(flop) / 1.E9 / avg_time; float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_bytes / 1.E6 / avg_time; float gb_per_sec = num_bytes / 1.E6 / avg_time;
...@@ -221,14 +203,11 @@ bool run_grouped_conv_bwd_weight( ...@@ -221,14 +203,11 @@ bool run_grouped_conv_bwd_weight(
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
wei.GetDeviceBuffer(), wei.GetDeviceBuffer(),
out.GetDeviceBuffer(), out.GetDeviceBuffer(),
G, input_lengths,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_strides, input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
...@@ -22,11 +22,12 @@ static constexpr ck::index_t C = 192; ...@@ -22,11 +22,12 @@ static constexpr ck::index_t C = 192;
static constexpr ck::index_t X = 3; static constexpr ck::index_t X = 3;
static constexpr ck::index_t Wi = 28; static constexpr ck::index_t Wi = 28;
static constexpr ck::index_t Wo = 28; static constexpr ck::index_t Wo = 28;
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Wi}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_lengths{G, N, C, Wi};
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{X}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> filter_lengths{G, K, C, X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Wo}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_lengths{G, N, K, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{N * Wi * C, Wi* C, C, 1}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{N * Wi * C, Wi* C, 1, C};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{N * Wo * K, Wo* K, K, 1}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{K * X * C, X* C, 1, C};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{N * Wo * K, Wo* K, 1, K};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1}; static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1};
...@@ -40,14 +41,11 @@ int main() ...@@ -40,14 +41,11 @@ int main()
OutDataType, OutDataType,
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout>(G, OutLayout>(input_lengths,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_strides, input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
...@@ -25,13 +25,15 @@ static constexpr ck::index_t Hi = 28; ...@@ -25,13 +25,15 @@ static constexpr ck::index_t Hi = 28;
static constexpr ck::index_t Wi = 28; static constexpr ck::index_t Wi = 28;
static constexpr ck::index_t Ho = 28; static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 28; static constexpr ck::index_t Wo = 28;
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Hi, Wi}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_lengths{G, N, C, Hi, Wi};
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Y, X}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> filter_lengths{G, K, C, Y, X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Ho, Wo}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_lengths{G, N, K, Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
N * Hi * Wi * C, Hi* Wi* C, Wi* C, C, 1}; N * Hi * Wi * C, Hi* Wi* C, 1, Wi* C, C};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{
K * Y * X * C, Y* X* C, 1, X* C, C};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
N * Ho * Wo * K, Ho* Wo* K, Wo* K, K, 1}; N * Ho * Wo * K, Ho* Wo* K, 1, Wo* K, K};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1}; static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1};
...@@ -45,14 +47,11 @@ int main() ...@@ -45,14 +47,11 @@ int main()
OutDataType, OutDataType,
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout>(G, OutLayout>(input_lengths,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_strides, input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
...@@ -28,13 +28,15 @@ static constexpr ck::index_t Wi = 3; ...@@ -28,13 +28,15 @@ static constexpr ck::index_t Wi = 3;
static constexpr ck::index_t Do = 28; static constexpr ck::index_t Do = 28;
static constexpr ck::index_t Ho = 28; static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 3; static constexpr ck::index_t Wo = 3;
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Di, Hi, Wi}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_lengths{G, N, C, Di, Hi, Wi};
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z, Y, X}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> filter_lengths{G, K, C, Z, Y, X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_lengths{G, N, K, Do, Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1}; N * Di * Hi * Wi * C, Di* Hi* Wi* C, 1, Hi* Wi* C, Wi* C, C};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{
K * Z * Y * X * C, Z* Y* X* C, 1, Y* X* C, X* C, C};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1}; N * Do * Ho * Wo * K, Do* Ho* Wo* K, 1, Ho* Wo* K, Wo* K, K};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1, 1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1}; static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1};
...@@ -48,14 +50,11 @@ int main() ...@@ -48,14 +50,11 @@ int main()
OutDataType, OutDataType,
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout>(G, OutLayout>(input_lengths,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_strides, input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides, output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
......
...@@ -28,13 +28,15 @@ static constexpr ck::index_t Wi = 3; ...@@ -28,13 +28,15 @@ static constexpr ck::index_t Wi = 3;
static constexpr ck::index_t Do = 28; static constexpr ck::index_t Do = 28;
static constexpr ck::index_t Ho = 28; static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 3; static constexpr ck::index_t Wo = 3;
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Di, Hi, Wi}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_lengths{G, N, C, Di, Hi, Wi};
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z, Y, X}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> filter_lengths{G, K, C, Z, Y, X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo}; static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_lengths{G, N, K, Do, Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1}; N * Di * Hi * Wi * C, Di* Hi* Wi* C, 1, Hi* Wi* C, Wi* C, C};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{
K * Z * Y * X * C, Z* Y* X* C, 1, Y* X* C, X* C, C};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{ static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1}; N * Do * Ho * Wo * K, Do* Ho* Wo* K, 1, Ho* Wo* K, Wo* K, K};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1, 1}; static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1}; static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1};
...@@ -48,20 +50,16 @@ int main() ...@@ -48,20 +50,16 @@ int main()
OutDataType, OutDataType,
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout>( OutLayout>(input_lengths,
G, input_strides,
N, filter_lengths,
K, weights_strides,
C, output_lengths,
{Di, Hi, Wi}, output_strides,
{Z, Y, X}, conv_filter_strides,
{Do, Ho, Wo}, conv_filter_dilations,
{N * Di * Hi * Wi * C, Di * Hi * Wi * C, Hi * Wi * C, Wi * C, C, 1}, input_left_pads,
{N * Do * Ho * Wo * K, Do * Ho * Wo * K, Ho * Wo * K, Wo * K, K, 1}, input_right_pads)
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{1, 1, 1})
? EXIT_SUCCESS ? EXIT_SUCCESS
: EXIT_FAILURE; : EXIT_FAILURE;
} }
add_custom_target(example_gemm_dl) if(DL_KERNELS)
add_custom_target(example_gemm_dl)
add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp) add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp)
add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) add_dependencies(example_gemm_dl example_gemm_dl_fp32)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_dependencies(example_gemm_dl example_gemm_dl_fp32) add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp)
add_dependencies(example_gemm_dl example_gemm_dl_fp16) add_dependencies(example_gemm_dl example_gemm_dl_fp16)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp)
add_dependencies(example_gemm_dl example_gemm_dl_int8) add_dependencies(example_gemm_dl example_gemm_dl_int8)
endif() endif()
if(USE_BITINT_EXTENSION_INT4) if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp) add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp)
add_dependencies(example_gemm_dl example_gemm_dl_int4) add_dependencies(example_gemm_dl example_gemm_dl_int4)
endif(USE_BITINT_EXTENSION_INT4) endif(USE_BITINT_EXTENSION_INT4)
endif()
add_custom_target(example_gemm_xdl) add_custom_target(example_gemm_xdl)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16)
add_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
add_custom_target(example_gemm_wmma)
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
add_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
endif()
add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) endif()
add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp)
add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16) if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_dependencies(example_gemm_xdl example_gemm_xdl_bf16) add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) add_dependencies(example_gemm_xdl example_gemm_xdl_bf16)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
...@@ -37,22 +50,20 @@ if(USE_BITINT_EXTENSION_INT4) ...@@ -37,22 +50,20 @@ if(USE_BITINT_EXTENSION_INT4)
add_dependencies(example_gemm_xdl example_gemm_xdl_int4) add_dependencies(example_gemm_xdl example_gemm_xdl_int4)
endif(USE_BITINT_EXTENSION_INT4) endif(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp) if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
add_custom_target(example_gemm_wmma)
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
add_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
endif() endif()
add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp) add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)
if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942") if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES)
if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942")
add_example_executable(example_gemm_xdl_f8 gemm_xdl_f8.cpp) add_example_executable(example_gemm_xdl_f8 gemm_xdl_f8.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_f8) add_dependencies(example_gemm_xdl example_gemm_xdl_f8)
endif()
endif() endif()
add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
using ADataType = ck::f8_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = ck::half_t;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto LoopSched = ck::make_default_loop_scheduler();
static constexpr auto PipelineVer = ck::PipelineVersion::v1;
using ComputeType = ck::half_t;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Loop| Pipeline| ComputeType|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| |
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopSched, PipelineVer, ComputeType>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102) list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102)
list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
...@@ -15,3 +16,4 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -15,3 +16,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
...@@ -6,3 +7,4 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -6,3 +7,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
endif()
...@@ -3,22 +3,26 @@ set(target 0) ...@@ -3,22 +3,26 @@ set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_gemm_add_add_fastgelu_xdl) add_custom_target(example_gemm_add_add_fastgelu_xdl)
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp)
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp)
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp)
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp)
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp)
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32)
endif()
if(USE_BITINT_EXTENSION_INT4) if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp)
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
endif(USE_BITINT_EXTENSION_INT4) endif(USE_BITINT_EXTENSION_INT4)
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp)
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8) add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8)
endif()
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
\ No newline at end of file
...@@ -2,16 +2,34 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) ...@@ -2,16 +2,34 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
endif()
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
endif()
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp) if(DL_KERNELS)
add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp) add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp)
endif()
endif()
...@@ -3,14 +3,22 @@ set(target 0) ...@@ -3,14 +3,22 @@ set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_convnd_fwd_reduce_xdl) add_custom_target(example_convnd_fwd_reduce_xdl)
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp)
add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp)
add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp)
add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp)
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp)
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16) add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp)
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp)
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32)
endif()
if(USE_BITINT_EXTENSION_INT4) if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp) add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp)
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4) add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4)
......
add_example_executable(example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp) add_example_executable(example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
# dlops # dlops
add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp) if(DL_KERNELS)
add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp)
endif()
# xdlops # xdlops
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
......
add_custom_target(example_grouped_gemm_xdl) add_custom_target(example_grouped_gemm_xdl)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp) add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp)
add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp) add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32)
add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp) endif()
add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp) add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp) add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp)
add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp)
add_dependencies(example_grouped_gemm_xdl
add_dependencies(example_grouped_gemm_xdl
example_grouped_gemm_xdl_fp32
example_grouped_gemm_xdl_fp16 example_grouped_gemm_xdl_fp16
example_grouped_gemm_xdl_bfp16
example_grouped_gemm_xdl_int8
example_grouped_gemm_multiple_d_dl_fp16 example_grouped_gemm_multiple_d_dl_fp16
example_grouped_gemm_xdl_splitk_fp16) example_grouped_gemm_xdl_splitk_fp16)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bfp16)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8)
endif()
if(USE_BITINT_EXTENSION_INT4) if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4)
......
...@@ -6,32 +6,32 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -6,32 +6,32 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_custom_target(example_gemm_reduce_xdl_max) add_custom_target(example_gemm_reduce_xdl_max)
add_custom_target(example_gemm_reduce_xdl_mean_meansquare) add_custom_target(example_gemm_reduce_xdl_mean_meansquare)
add_custom_target(example_gemm_add_add_mean_meansquare_xdl) add_custom_target(example_gemm_add_add_mean_meansquare_xdl)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp)
add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp)
add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp)
add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp)
add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp)
add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp)
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16)
add_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16)
add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp)
add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp)
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8)
add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp)
add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp)
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32)
add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp)
add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp)
add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16)
add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16)
add_dependencies(example_gemm_reduce_xdl_max endif()
example_gemm_max_xdl_bf16
example_gemm_max_xdl_fp16
example_gemm_max_xdl_fp32
example_gemm_max_xdl_int8)
add_dependencies(example_gemm_reduce_xdl_mean_meansquare
example_gemm_mean_meansquare_xdl_fp16
example_gemm_mean_meansquare_xdl_fp32
example_gemm_mean_meansquare_xdl_bf16
example_gemm_add_addsquare_xdl_int8)
add_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16)
add_dependencies(example_gemm_reduce_xdl add_dependencies(example_gemm_reduce_xdl
example_gemm_reduce_xdl_mean_meansquare example_gemm_reduce_xdl_mean_meansquare
......
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
...@@ -7,5 +8,8 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -7,5 +8,8 @@ foreach(gpu IN LISTS GPU_TARGETS)
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp) if(DL_KERNELS)
target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility) add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp)
target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility)
endif()
endif()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment