Commit de1afb7b authored by Rostyslav Geyyer

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into lwpck-977

parents ce562aa6 f7331c60
@@ -17,7 +17,7 @@ None
 - Support for 3D grouped convolution on RDNA 3 GPUs (#935, #950, #985)
 - Grouped convolution support for small K and C (#822 #879 #897)
 - Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804)
-- Support for bf16/f32/f16 and NHWGC (2D and 3d) grouped convolution backward data (#757 #799)
+- Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799)
 - Support for Batched Gemm DL (#732)
 ### Changes
...
@@ -32,12 +32,10 @@ if (DTYPES)
     if (DTYPES MATCHES "fp8")
         add_definitions(-DCK_ENABLE_FP8)
         set(CK_ENABLE_FP8 "ON")
-        add_compile_options(-Wno-bit-int-extension)
     endif()
     if (DTYPES MATCHES "bf8")
         add_definitions(-DCK_ENABLE_BF8)
         set(CK_ENABLE_BF8 "ON")
-        add_compile_options(-Wno-bit-int-extension)
     endif()
     if (DTYPES MATCHES "fp16")
         add_definitions(-DCK_ENABLE_FP16)
@@ -59,9 +57,11 @@ if (DTYPES)
 else()
     add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
     set(CK_ENABLE_ALL_DTYPES "ON")
-    add_compile_options(-Wno-bit-int-extension) # enable fp8 and bf8
 endif()
+#for f8/bf8_t type
+add_compile_options(-Wno-bit-int-extension)
 if(DL_KERNELS)
     add_definitions(-DDL_KERNELS)
     set(CK_ENABLE_DL_KERNELS "ON")
...
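This change moves the `-Wno-bit-int-extension` suppression out of the per-dtype branches and applies it once, unconditionally. Per the `#for f8/bf8_t type` comment, the f8/bf8 types rely on Clang's `_BitInt` extension, so any configuration that compiles their headers needs the flag, not only builds that explicitly request fp8/bf8. A minimal sketch of the resulting pattern (illustrative only, not the full file):

```cmake
# Sketch: the dtype macros stay conditional, but the warning suppression does
# not, since the f8_t/bf8_t headers may be pulled in regardless of which
# DTYPES selection is configured.
if(DTYPES)
    if(DTYPES MATCHES "fp8")
        add_definitions(-DCK_ENABLE_FP8)
    endif()
else()
    set(CK_ENABLE_ALL_DTYPES "ON")
endif()
add_compile_options(-Wno-bit-int-extension) # for f8/bf8_t type
```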
@@ -790,8 +790,8 @@ pipeline {
         }
         agent{ label rocmnode("navi32") }
         environment{
-            setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" """
-            execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
+            setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" -DDL_KERNELS=ON """
+            execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -DDL_KERNELS=ON -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
         }
         steps{
             Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
...
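The navi32 stage now passes `-DDL_KERNELS=ON` to both the library install and the client_example build, keeping the two configurations in sync. A sketch of what the flag gates, mirroring the top-level CMakeLists pattern shown earlier (the client-side wiring is an assumption):

```cmake
# If the library were installed without DL kernels, a client configured with
# -DDL_KERNELS=ON would reference instances that were never built; setting the
# flag in both Jenkins invocations avoids that mismatch.
if(DL_KERNELS)
    add_definitions(-DDL_KERNELS)
    set(CK_ENABLE_DL_KERNELS "ON")
endif()
```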
@@ -12,12 +12,14 @@
 #include "ck/library/tensor_operation_instance/gpu/normalization.hpp"
 
 using XDataType       = ck::half_t;
 using GammaDataType   = ck::half_t;
 using BetaDataType    = ck::half_t;
 using YDataType       = ck::half_t;
-using ComputeDataType = float;
+using SaveMeanInvStdDataType = float;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 
+#define SAVE_MEAN_INV_STD
+
 constexpr int Rank         = 2;
 constexpr int NumReduceDim = 1;
@@ -50,12 +52,16 @@ int main(int argc, char* argv[])
     SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * N);
     SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * N);
     SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size);
+#ifdef SAVE_MEAN_INV_STD
+    SimpleDeviceMem save_mean_device_buf(sizeof(SaveMeanInvStdDataType) * M);
+    SimpleDeviceMem save_inv_std_device_buf(sizeof(SaveMeanInvStdDataType) * M);
+#endif
 
     using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
                                                                        GammaDataType,
                                                                        BetaDataType,
-                                                                       ComputeDataType,
                                                                        YDataType,
+                                                                       SaveMeanInvStdDataType,
                                                                        PassThrough,
                                                                        Rank,
                                                                        NumReduceDim>;
@@ -84,14 +90,21 @@ int main(int argc, char* argv[])
             {0, 1},      // gammaStrides
             {0, 1},      // betaStrides
             {Stride, 1}, // yStrides
+            {1},         // save_mean Strides
+            {1},         // save_inv_std Strides
             {1},         // reduceDims
             1e-4,
             x_device_buf.GetDeviceBuffer(),
             gamma_device_buf.GetDeviceBuffer(),
             beta_device_buf.GetDeviceBuffer(),
             y_device_buf.GetDeviceBuffer(),
+#ifdef SAVE_MEAN_INV_STD
+            save_mean_device_buf.GetDeviceBuffer(),
+            save_inv_std_device_buf.GetDeviceBuffer(),
+#else
             nullptr,
             nullptr,
+#endif
             PassThrough{});
 
         auto invoker_ptr = op_ptr->MakeInvokerPointer();
@@ -109,6 +122,10 @@ int main(int argc, char* argv[])
         std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N +
                                sizeof(BetaDataType) * N + sizeof(YDataType) * M * N;
+#ifdef SAVE_MEAN_INV_STD
+        num_byte += sizeof(SaveMeanInvStdDataType) * M * 2;
+#endif
 
         float gb_per_sec = num_byte / 1.E6 / ave_time;
 
         std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
@@ -140,17 +157,24 @@ int main(int argc, char* argv[])
     auto argument_ptr = op_ptr->MakeArgumentPointer({M, N},      // lengths
                                                     {Stride, 1}, // xStrides
-                                                    {1},         // gammaStrides
-                                                    {1},         // betaStrides
+                                                    {0, 1},      // gammaStrides
+                                                    {0, 1},      // betaStrides
                                                     {Stride, 1}, // yStrides
+                                                    {1},         // save_mean Strides
+                                                    {1},         // save_inv_std Strides
                                                     {1},         // reduceDims
                                                     1e-4,
                                                     x_device_buf.GetDeviceBuffer(),
                                                     gamma_device_buf.GetDeviceBuffer(),
                                                     beta_device_buf.GetDeviceBuffer(),
                                                     y_device_buf.GetDeviceBuffer(),
+#ifdef SAVE_MEAN_INV_STD
+                                                    save_mean_device_buf.GetDeviceBuffer(),
+                                                    save_inv_std_device_buf.GetDeviceBuffer(),
+#else
                                                     nullptr,
                                                     nullptr,
+#endif
                                                     PassThrough{});
 
     auto invoker_ptr = op_ptr->MakeInvokerPointer();
...
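The `{1}` to `{0, 1}` fix for gammaStrides/betaStrides above is a correctness fix, not a cosmetic one: the device op expects one stride per tensor dimension, and a stride of 0 marks the broadcast dimension. With strides {0, 1}, the offset for the rank-2 index (m, n) is

$$
\mathrm{offset}_{\gamma}(m, n) = 0\cdot m + 1\cdot n = n,
$$

so every row reads the same length-N gamma/beta vector, matching the earlier call site in the same file.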
@@ -12,12 +12,14 @@
 #include "ck/library/tensor_operation_instance/gpu/normalization_swish.hpp"
 
 using XDataType       = ck::half_t;
 using GammaDataType   = float;
 using BetaDataType    = float;
 using YDataType       = ck::half_t;
-using ComputeDataType = float;
+using SaveMeanInvStdDataType = float;
 using Swish = ck::tensor_operation::element_wise::Swish;
 
+#define SAVE_MEAN_INV_STD
+
 constexpr int Rank         = 5;
 constexpr int NumReduceDim = 3;
@@ -49,19 +51,24 @@ int main(int argc, char* argv[])
     std::size_t xy_size         = N * H * W * G * C;
     std::size_t gamma_beta_size = G * C;
 
     std::vector<ck::index_t> xy_strides         = {H * W * G * C, W * G * C, G * C, C, 1};
     std::vector<ck::index_t> gamma_beta_strides = {0, 0, 0, C, 1};
+    std::vector<ck::index_t> save_mean_inv_std_strides = {G, 1};
 
     SimpleDeviceMem x_device_buf(sizeof(XDataType) * xy_size);
     SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_beta_size);
     SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * gamma_beta_size);
     SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size);
+#ifdef SAVE_MEAN_INV_STD
+    SimpleDeviceMem save_mean_device_buf(sizeof(SaveMeanInvStdDataType) * N * G);
+    SimpleDeviceMem save_inv_std_device_buf(sizeof(SaveMeanInvStdDataType) * N * G);
+#endif
 
     using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
                                                                        GammaDataType,
                                                                        BetaDataType,
-                                                                       ComputeDataType,
                                                                        YDataType,
+                                                                       SaveMeanInvStdDataType,
                                                                        Swish,
                                                                        Rank,
                                                                        NumReduceDim>;
@@ -75,19 +82,26 @@ int main(int argc, char* argv[])
     const auto& generic_op_ptr = op_ptrs[0];
 
     auto generic_argument_ptr =
         generic_op_ptr->MakeArgumentPointer({N, H, W, G, C},    // lengths
                                             xy_strides,         // xStrides
                                             gamma_beta_strides, // gammaStrides
                                             gamma_beta_strides, // betaStrides
                                             xy_strides,         // yStrides
+                                            save_mean_inv_std_strides, // save_mean Strides
+                                            save_mean_inv_std_strides, // save_inv_std Strides
                                             {1, 2, 4},          // reduceDims
                                             1e-6,
                                             x_device_buf.GetDeviceBuffer(),
                                             gamma_device_buf.GetDeviceBuffer(),
                                             beta_device_buf.GetDeviceBuffer(),
                                             y_device_buf.GetDeviceBuffer(),
+#ifdef SAVE_MEAN_INV_STD
+                                            save_mean_device_buf.GetDeviceBuffer(),
+                                            save_inv_std_device_buf.GetDeviceBuffer(),
+#else
                                             nullptr,
                                             nullptr,
+#endif
                                             Swish{});
 
     if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get()))
@@ -107,21 +121,29 @@ int main(int argc, char* argv[])
     for(int i = 0; i < op_ptrs.size(); ++i)
     {
         auto& op_ptr = op_ptrs[i];
-        auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C},    // lengths
-                                                        xy_strides,         // xStrides
-                                                        gamma_beta_strides, // gammaStrides
-                                                        gamma_beta_strides, // betaStrides
-                                                        xy_strides,         // yStrides
-                                                        {1, 2, 4},          // reduceDims
-                                                        1e-6,
-                                                        x_device_buf.GetDeviceBuffer(),
-                                                        gamma_device_buf.GetDeviceBuffer(),
-                                                        beta_device_buf.GetDeviceBuffer(),
-                                                        y_device_buf.GetDeviceBuffer(),
-                                                        nullptr,
-                                                        nullptr,
-                                                        Swish{});
+        auto argument_ptr =
+            op_ptr->MakeArgumentPointer({N, H, W, G, C},           // lengths
+                                        xy_strides,                // xStrides
+                                        gamma_beta_strides,        // gammaStrides
+                                        gamma_beta_strides,        // betaStrides
+                                        xy_strides,                // yStrides
+                                        save_mean_inv_std_strides, // save_mean Strides
+                                        save_mean_inv_std_strides, // save_inv_std Strides
+                                        {1, 2, 4},                 // reduceDims
+                                        1e-6,
+                                        x_device_buf.GetDeviceBuffer(),
+                                        gamma_device_buf.GetDeviceBuffer(),
+                                        beta_device_buf.GetDeviceBuffer(),
+                                        y_device_buf.GetDeviceBuffer(),
+#ifdef SAVE_MEAN_INV_STD
+                                        save_mean_device_buf.GetDeviceBuffer(),
+                                        save_inv_std_device_buf.GetDeviceBuffer(),
+#else
+                                        nullptr,
+                                        nullptr,
+#endif
+                                        Swish{});
 
         auto invoker_ptr = op_ptr->MakeInvokerPointer();
@@ -139,6 +161,10 @@ int main(int argc, char* argv[])
             sizeof(XDataType) * xy_size + sizeof(GammaDataType) * gamma_beta_size +
             sizeof(BetaDataType) * gamma_beta_size + sizeof(YDataType) * xy_size;
+#ifdef SAVE_MEAN_INV_STD
+        num_byte += sizeof(SaveMeanInvStdDataType) * N * G * 2;
+#endif
 
         float gb_per_sec = num_byte / 1.E6 / ave_time;
 
         std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
@@ -169,20 +195,28 @@ int main(int argc, char* argv[])
         std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
                   << std::endl;
 
-        auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C},    // lengths
-                                                        xy_strides,         // xStrides
-                                                        gamma_beta_strides, // gammaStrides
-                                                        gamma_beta_strides, // betaStrides
-                                                        xy_strides,         // yStrides
-                                                        {1, 2, 4},          // reduceDims
-                                                        1e-6,
-                                                        x_device_buf.GetDeviceBuffer(),
-                                                        gamma_device_buf.GetDeviceBuffer(),
-                                                        beta_device_buf.GetDeviceBuffer(),
-                                                        y_device_buf.GetDeviceBuffer(),
-                                                        nullptr,
-                                                        nullptr,
-                                                        Swish{});
+        auto argument_ptr =
+            op_ptr->MakeArgumentPointer({N, H, W, G, C},           // lengths
+                                        xy_strides,                // xStrides
+                                        gamma_beta_strides,        // gammaStrides
+                                        gamma_beta_strides,        // betaStrides
+                                        xy_strides,                // yStrides
+                                        save_mean_inv_std_strides, // save_mean Strides
+                                        save_mean_inv_std_strides, // save_inv_std Strides
+                                        {1, 2, 4},                 // reduceDims
+                                        1e-6,
+                                        x_device_buf.GetDeviceBuffer(),
+                                        gamma_device_buf.GetDeviceBuffer(),
+                                        beta_device_buf.GetDeviceBuffer(),
+                                        y_device_buf.GetDeviceBuffer(),
+#ifdef SAVE_MEAN_INV_STD
+                                        save_mean_device_buf.GetDeviceBuffer(),
+                                        save_inv_std_device_buf.GetDeviceBuffer(),
+#else
+                                        nullptr,
+                                        nullptr,
+#endif
+                                        Swish{});
 
         auto invoker_ptr = op_ptr->MakeInvokerPointer();
...
@@ -70,6 +70,7 @@ else()
         -Wno-option-ignored
         -Wsign-compare
         -Wno-extra-semi-stmt
+        -Wno-unused-template
     )
     if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang")
         list(APPEND CMAKE_COMPILER_WARNINGS
...
 add_custom_target(example_gemm_dl)
 
 add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_gemm_dl example_gemm_dl_fp32)
-endif()
+add_example_dependencies(example_gemm_dl example_gemm_dl_fp32)
 
 add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_gemm_dl example_gemm_dl_fp16)
-endif()
+add_example_dependencies(example_gemm_dl example_gemm_dl_fp16)
 
 add_example_executable(example_gemm_dpp_fp16 gemm_dpp_fp16.cpp)
 
 add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_gemm_dl example_gemm_dl_int8)
-endif()
+add_example_dependencies(example_gemm_dl example_gemm_dl_int8)
 
 if(USE_BITINT_EXTENSION_INT4)
     add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp)
-    add_dependencies(example_gemm_dl example_gemm_dl_int4)
+    add_example_dependencies(example_gemm_dl example_gemm_dl_int4)
 endif(USE_BITINT_EXTENSION_INT4)
 add_custom_target(example_gemm_xdl)
 
 add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_gemm_xdl example_gemm_xdl_fp16)
-endif()
+add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16)
 
 add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
-endif()
+add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
 
 add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
-endif()
+add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
 
 if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
     add_custom_target(example_gemm_wmma)
     add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
-    if(result EQUAL 0)
-        add_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
-    endif()
+    add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
 endif()
 
 add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_gemm_xdl example_gemm_xdl_bf16)
+add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16)
 
 add_example_executable(example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp)
-    add_dependencies(example_gemm_xdl example_gemm_xdl_bf16_rtn)
-endif()
+add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_rtn)
 
 add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_gemm_xdl example_gemm_xdl_int8)
-endif()
+add_example_dependencies(example_gemm_xdl example_gemm_xdl_int8)
 
 if(USE_BITINT_EXTENSION_INT4)
     add_example_executable(example_gemm_xdl_int4 gemm_xdl_int4.cpp)
-    add_dependencies(example_gemm_xdl example_gemm_xdl_int4)
+    add_example_dependencies(example_gemm_xdl example_gemm_xdl_int4)
 endif(USE_BITINT_EXTENSION_INT4)
 
 # FIXME: re-enable this example as test when SWDEV-335738 is fixed
 add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
-endif()
+add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
 
 add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)
 
 add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
-endif()
+add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
 
 add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
-endif()
+add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
 
 add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
-endif()
+add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
 list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list AND target EQUAL 0)
         add_custom_target(example_gemm_add_add_fastgelu_xdl)
 
         add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16)
-        endif()
+        add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16)
 
         add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16)
-        endif()
+        add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16)
 
         add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32)
-        endif()
+        add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32)
 
         if(USE_BITINT_EXTENSION_INT4)
             add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp)
-            add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
+            add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
         endif(USE_BITINT_EXTENSION_INT4)
 
         add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8)
-        endif()
+        add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8)
 
         set(target 1)
     endif()
 endforeach()
\ No newline at end of file
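Throughout these CMakeLists the repeated `if(result EQUAL 0) ... endif()` guard around `add_dependencies` is replaced by a single `add_example_dependencies` call. The diff does not show the helper's definition; a plausible sketch of what it wraps (an assumption, not the repository's actual code):

```cmake
# Hypothetical sketch: add_example_executable is assumed to set `result` to 0
# when the example target was actually created for the current GPU_TARGETS, so
# dependency edges are only added for targets that exist.
function(add_example_dependencies EXAMPLE_TARGET)
    if(result EQUAL 0)
        add_dependencies(${EXAMPLE_TARGET} ${ARGN})
    endif()
endfunction()
```

Folding the guard into the helper removes three lines of boilerplate per example and allows multi-dependency calls, as in the `example_gemm_reduce_xdl` group later in this diff.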
 list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list AND target EQUAL 0)
         add_custom_target(example_convnd_fwd_reduce_xdl)
 
         add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8)
-        endif()
+        add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8)
 
         add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16)
-        endif()
+        add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16)
 
         add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16)
-        endif()
+        add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16)
 
         add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32)
-        endif()
+        add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32)
 
         if(USE_BITINT_EXTENSION_INT4)
             add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp)
-            add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4)
+            add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4)
         endif(USE_BITINT_EXTENSION_INT4)
         set(target 1)
     endif()
 endforeach()
\ No newline at end of file
@@ -2,7 +2,7 @@
 
 ## Run ```example_reduce_blockwise```
 ```bash
-# -D <xxx> : input 3d/4d/5d tensor lengths
+# -D <xxx> : input 3D/4D/5D tensor lengths
 # -R <xxx> : reduce dimension ids
 # -v <x> : verification (0=no, 1=yes)
 #arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bf16, 6: fp64, 7: int4)
@@ -22,7 +22,7 @@ Perf: 0.238063 ms, 264.285 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSr
 
 ## Run ```example_reduce_multiblock_atomic_add```
 ```bash
-# -D <xxx> : input 3d/4d/5d tensor lengths
+# -D <xxx> : input 3D/4D/5D tensor lengths
 # -R <xxx> : reduce dimension ids
 # -v <x> : verification (0=no, 1=yes)
 #arg1: data type (0: fp32, 1: fp64)
...
 add_custom_target(example_grouped_gemm_xdl)
 
 add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32)
-endif()
+add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32)
 
 add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp16)
-endif()
+add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp16)
 
 add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multiple_d_dl_fp16)
-endif()
+add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multiple_d_dl_fp16)
 
 add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_splitk_fp16)
-endif()
+add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_splitk_fp16)
 
 add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16)
-endif()
+add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16)
 
 add_example_executable(example_grouped_gemm_xdl_fixed_nk_bias_fp16 grouped_gemm_xdl_fixed_nk_bias_fp16.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_bias_fp16)
-endif()
+add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_bias_fp16)
 
 add_example_executable(example_grouped_gemm_xdl_bf16 grouped_gemm_xdl_bf16.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16)
-endif()
+add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16)
 
 add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8)
-endif()
+add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8)
 
 add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8)
-endif()
+add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8)
 
 if(USE_BITINT_EXTENSION_INT4)
     add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
-    if(result EQUAL 0)
-        add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4)
-    endif()
+    add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4)
 endif()
 list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list AND target EQUAL 0)
         add_custom_target(example_gemm_reduce_xdl)
         add_custom_target(example_gemm_reduce_xdl_max)
         add_custom_target(example_gemm_reduce_xdl_mean_meansquare)
         add_custom_target(example_gemm_add_add_mean_meansquare_xdl)
 
         add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16)
-        endif()
+        add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16)
 
         add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16)
-        endif()
+        add_example_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16)
 
         add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16)
-        endif()
+        add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16)
 
         add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8)
-        endif()
+        add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8)
 
         add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8)
-        endif()
+        add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8)
 
         add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32)
-        endif()
+        add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32)
 
         add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32)
-        endif()
+        add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32)
 
         add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16)
-        endif()
+        add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16)
 
         add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16)
-        endif()
+        add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16)
 
-        add_dependencies(example_gemm_reduce_xdl
-                         example_gemm_reduce_xdl_mean_meansquare
-                         example_gemm_reduce_xdl_max
-                         example_gemm_add_add_mean_meansquare_xdl)
+        add_example_dependencies(example_gemm_reduce_xdl
+                                 example_gemm_reduce_xdl_mean_meansquare
+                                 example_gemm_reduce_xdl_max
+                                 example_gemm_add_add_mean_meansquare_xdl)
 
         if(USE_BITINT_EXTENSION_INT4)
             add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp)
-            if(result EQUAL 0)
-                add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4)
-            endif()
+            add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4)
         endif()
         set(target 1)
     endif()
 endforeach()
@@ -2,34 +2,28 @@ list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
 list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
         add_custom_target(example_grouped_conv_bwd_weight)
 
         add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16)
-        endif()
+        add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16)
 
         add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16)
-        endif()
+        add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16)
 
         add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8)
-        endif()
+        add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8)
         set(target 1)
     endif()
 
     if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
         add_custom_target(example_grouped_conv_bwd_weight)
         add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16)
-        endif()
+        add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16)
         set(target 1)
     endif()
 endforeach()
 
 add_custom_target(example_grouped_conv_bwd_weight_dl)
 add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16)
-endif()
+add_example_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16)
@@ -114,12 +114,15 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n,
                                                           BetaDataType,
                                                           HDataType,
                                                           AccDataType,
+                                                          AccDataType,
                                                           HElementOp,
                                                           2,
                                                           1>;
 
     Tensor<EMeanVarDataType> e_m_n(HostTensorDescriptor{M, N});
     Tensor<AccDataType> c_m_n(HostTensorDescriptor{M, N});
+    Tensor<AccDataType> save_mean({M});
+    Tensor<AccDataType> save_inv_std({M});
 
     auto ref_gemm         = ReferenceGemm{};
     auto ref_gemm_invoker = ref_gemm.MakeInvoker();
@@ -145,7 +148,7 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n,
     auto ref_layernorm_invoker  = ref_layernorm.MakeInvoker();
     auto ref_layernorm_argument = ref_layernorm.MakeArgument(
-        e_m_n, gamma_n, beta_n, h_m_n, h_element_op, {M, N}, {1}, epsilon);
+        e_m_n, gamma_n, beta_n, h_m_n, save_mean, save_inv_std, h_element_op, {M, N}, {1}, epsilon);
     ref_layernorm_invoker.Run(ref_layernorm_argument);
 }
...
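For context, host_gemm_layernorm verifies a fused GEMM + layernorm: the layernorm reference normalizes the GEMM output e_{m,n} row by row, and with the added arguments it now also returns the per-row mean and inverse standard deviation it computed. Assuming the standard layernorm formulation:

$$
h_{m,n} = \mathrm{HElementOp}\!\left(\gamma_n \cdot \frac{e_{m,n} - \mu_m}{\sqrt{\sigma_m^2 + \varepsilon}} + \beta_n\right),
\qquad
\mathrm{save\_inv\_std}_m = \frac{1}{\sqrt{\sigma_m^2 + \varepsilon}} .
$$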
 add_custom_target(example_cgemm_xdl)
 
 add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_cgemm_xdl example_cgemm_xdl_bf16)
-endif()
+add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_bf16)
 
 add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp16)
-endif()
+add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_fp16)
 
 add_example_executable(example_cgemm_xdl_fp32 cgemm_xdl_fp32.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp32)
-endif()
+add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_fp32)
 
 add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_cgemm_xdl example_cgemm_xdl_int8)
-endif()
+add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_int8)
 
 if(USE_BITINT_EXTENSION_INT4)
     add_example_executable(example_cgemm_xdl_int4 cgemm_xdl_int4.cpp)
-    add_dependencies(example_cgemm_xdl example_cgemm_xdl_int4)
+    add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_int4)
 endif()
 
 add_custom_target(example_batched_gemm_xdl)
 
 add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp32)
-endif()
+add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp32)
 
 add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16)
-endif()
+add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16)
 
 add_example_executable(example_batched_gemm_xdl_bf16 batched_gemm_xdl_bf16.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bf16)
-endif()
+add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bf16)
 
 add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int8)
-endif()
+add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int8)
 
 if(USE_BITINT_EXTENSION_INT4)
     add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp)
-    if(result EQUAL 0)
-        add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4)
-    endif()
+    add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4)
 endif()
@@ -3,12 +3,15 @@
 #include "common.hpp"
 
 using XDataType       = ck::half_t;
 using GammaDataType   = ck::half_t;
 using BetaDataType    = ck::half_t;
 using YDataType       = ck::half_t;
-using ComputeDataType = float;
-using PassThrough     = ck::tensor_operation::element_wise::PassThrough;
+using SaveMeanInvStdDataType = float;
+using ComputeDataType        = float;
+using PassThrough            = ck::tensor_operation::element_wise::PassThrough;
+
+#define SAVE_MEAN_INV_STD
 
 constexpr int Rank         = 2;
 constexpr int NumReduceDim = 1;
@@ -19,6 +22,7 @@ using DeviceInstance =
                                                        BetaDataType,
                                                        ComputeDataType,
                                                        YDataType,
+                                                       SaveMeanInvStdDataType,
                                                        PassThrough,
                                                        Rank,
                                                        NumReduceDim,
@@ -33,7 +37,8 @@ using DeviceInstance =
         8, // GammaScalarPerVector
         1, // BetaVecDim (0=M, 1=K)
         8, // BetaScalarPerVector
-        8>; // OutScalarPerVector
+        8,  // YScalarPerVector
+        1>; // SaveMeanInvStdScalarPerVector
 
 #include "run_layernorm_example.inc"
 
 int main() { return run_groupnorm_example<DeviceInstance>(); }
@@ -3,12 +3,15 @@
 #include "common.hpp"
 
 using XDataType       = ck::half_t;
 using GammaDataType   = ck::half_t;
 using BetaDataType    = ck::half_t;
 using YDataType       = ck::half_t;
-using ComputeDataType = float;
-using PassThrough     = ck::tensor_operation::element_wise::PassThrough;
+using SaveMeanInvStdDataType = float;
+using ComputeDataType        = float;
+using PassThrough            = ck::tensor_operation::element_wise::PassThrough;
+
+#define SAVE_MEAN_INV_STD
 
 constexpr int Rank         = 2;
 constexpr int NumReduceDim = 1;
@@ -19,6 +22,7 @@ using DeviceInstance =
                                                        BetaDataType,
                                                        ComputeDataType,
                                                        YDataType,
+                                                       SaveMeanInvStdDataType,
                                                        PassThrough,
                                                        Rank,
                                                        NumReduceDim,
@@ -33,7 +37,8 @@ using DeviceInstance =
         8, // GammaScalarPerVector
         1, // BetaVecDim (0=M, 1=K)
         8, // BetaScalarPerVector
-        8>; // YScalarPerVector
+        8,  // YScalarPerVector
+        1>; // SaveMeanInvStdScalarPerVector
 
 #include "run_layernorm_example.inc"
...
@@ -10,22 +10,13 @@ int run_groupnorm_example()
     ck::index_t M = 1024;
     ck::index_t N = 1024;
-    ck::index_t Stride = N;
-
-    auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
-        return HostTensorDescriptor({len}, {stride});
-    };
-
-    auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) {
-        using namespace ck::literals;
-        return HostTensorDescriptor({row, col}, {stride, 1_uz});
-    };
-
-    Tensor<XDataType> x(f_host_tensor_descriptor2d(M, N, Stride));
-    Tensor<GammaDataType> gamma(f_host_tensor_descriptor1d(N, 1));
-    Tensor<BetaDataType> beta(f_host_tensor_descriptor1d(N, 1));
-    Tensor<YDataType> y(f_host_tensor_descriptor2d(M, N, Stride));
+
+    Tensor<XDataType> x({M, N});
+    Tensor<GammaDataType> gamma({N});
+    Tensor<BetaDataType> beta({N});
+    Tensor<YDataType> y({M, N});
+    Tensor<SaveMeanInvStdDataType> save_mean({M});
+    Tensor<SaveMeanInvStdDataType> save_inv_std({M});
 
     x.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
     gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});
@@ -35,6 +26,11 @@ int run_groupnorm_example()
     DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
     DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
     DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
+#ifdef SAVE_MEAN_INV_STD
+    DeviceMem save_mean_dev(sizeof(SaveMeanInvStdDataType) * save_mean.mDesc.GetElementSpaceSize());
+    DeviceMem save_inv_std_dev(sizeof(SaveMeanInvStdDataType) *
+                               save_inv_std.mDesc.GetElementSpaceSize());
+#endif
 
     x_dev.ToDevice(x.mData.data());
     gamma_dev.ToDevice(gamma.mData.data());
@@ -47,14 +43,23 @@ int run_groupnorm_example()
         {0, 1},
         {0, 1},
         std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
+        std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
+                                 save_mean.mDesc.GetStrides().end()},
+        std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
+                                 save_mean.mDesc.GetStrides().end()},
         {1},
         1e-4,
         x_dev.GetDeviceBuffer(),
         gamma_dev.GetDeviceBuffer(),
        beta_dev.GetDeviceBuffer(),
        y_dev.GetDeviceBuffer(),
+#ifdef SAVE_MEAN_INV_STD
+        save_mean_dev.GetDeviceBuffer(),
+        save_inv_std_dev.GetDeviceBuffer(),
+#else
         nullptr,
         nullptr,
+#endif
         PassThrough{});
 
     if(!device_instance.IsSupportedArgument(argument_ptr.get()))
@@ -72,24 +77,45 @@ int run_groupnorm_example()
     bool pass = true;
     {
-        Tensor<YDataType> host_y(f_host_tensor_descriptor2d(M, N, Stride));
-        using ReferenceInstance = ck::tensor_operation::host::ReferenceLayernorm<XDataType,
-                                                                                 GammaDataType,
-                                                                                 BetaDataType,
-                                                                                 YDataType,
-                                                                                 ComputeDataType,
-                                                                                 PassThrough,
-                                                                                 Rank,
-                                                                                 NumReduceDim>;
+        Tensor<YDataType> host_y({M, N});
+        Tensor<SaveMeanInvStdDataType> host_save_mean({M});
+        Tensor<SaveMeanInvStdDataType> host_save_inv_std({M});
+
+        using ReferenceInstance =
+            ck::tensor_operation::host::ReferenceLayernorm<XDataType,
+                                                           GammaDataType,
+                                                           BetaDataType,
+                                                           YDataType,
+                                                           SaveMeanInvStdDataType,
+                                                           ComputeDataType,
+                                                           PassThrough,
+                                                           Rank,
+                                                           NumReduceDim>;
 
         ReferenceInstance ref;
-        auto ref_argument =
-            ref.MakeArgument(x, gamma, beta, host_y, PassThrough{}, {M, N}, {1}, 1e-4);
+        auto ref_argument = ref.MakeArgument(x,
+                                             gamma,
+                                             beta,
+                                             host_y,
+                                             host_save_mean,
+                                             host_save_inv_std,
+                                             PassThrough{},
+                                             {M, N},
+                                             {1},
+                                             1e-4);
         auto ref_invoker = ref.MakeInvoker();
         ref_invoker.Run(ref_argument);
 
         y_dev.FromDevice(y.mData.data());
-        pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3);
+        pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results (y)", 1e-3, 1e-3);
+
+#ifdef SAVE_MEAN_INV_STD
+        save_mean_dev.FromDevice(save_mean.mData.data());
+        save_inv_std_dev.FromDevice(save_inv_std.mData.data());
+        pass &= ck::utils::check_err(
+            save_mean, host_save_mean, "Error: Incorrect results (mean)", 1e-3, 1e-3);
+        pass &= ck::utils::check_err(
+            save_inv_std, host_save_inv_std, "Error: Incorrect results (inv_std)", 1e-3, 1e-3);
+#endif
     }
 
     return (pass ? 0 : 1);
...
@@ -3,44 +3,38 @@ list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
         add_custom_target(example_grouped_conv_fwd_multiple_d)
 
         add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16)
-        endif()
+        add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16)
 
         add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16)
-        endif()
+        add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16)
 
         add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32)
-        endif()
+        add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32)
 
         add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16)
-        endif()
+        add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16)
 
         add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp)
-        if(result EQUAL 0)
-            add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8)
-        endif()
+        add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8)
 
         if(USE_BITINT_EXTENSION_INT4)
             add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp)
-            if(result EQUAL 0)
-                add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4)
-            endif()
+            add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4)
         endif() # USE_BITINT_EXTENSION_INT4
 
         set(target 1)
     endif()
 endforeach()
 
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
         add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp)
         add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp)
         set(target 1)
     endif()
 endforeach()