Commit 6e3cf8b0 authored by Jing Zhang's avatar Jing Zhang
Browse files

merge develop

parents 4ad62d7f ba58a93f
...@@ -72,8 +72,9 @@ message(STATUS "Build with HIP ${HIP_VERSION}") ...@@ -72,8 +72,9 @@ message(STATUS "Build with HIP ${HIP_VERSION}")
rocm_create_package( rocm_create_package(
NAME CK-${CK_BACKEND} NAME composablekernel
DESCRIPTION "High Performance Composable Kernel for AMD GPUs" DESCRIPTION "High Performance Composable Kernel for AMD GPUs"
MAINTAINER "MIOpen Kernels Dev Team <dl.MIOpen@amd.com>"
LDCONFIG LDCONFIG
) )
...@@ -226,15 +227,12 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) ...@@ -226,15 +227,12 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin)
configure_file("${PROJECT_SOURCE_DIR}/include/ck/hip_version.hpp.in" "${PROJECT_BINARY_DIR}/include/ck/hip_version.hpp")
include_directories(BEFORE include_directories(BEFORE
${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/include
${PROJECT_BINARY_DIR}/include ${PROJECT_BINARY_DIR}/include
${PROJECT_SOURCE_DIR}/library/include ${PROJECT_SOURCE_DIR}/library/include
) )
include(googletest)
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV) if(BUILD_DEV)
...@@ -243,7 +241,31 @@ if(BUILD_DEV) ...@@ -243,7 +241,31 @@ if(BUILD_DEV)
endif() endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
add_subdirectory(library) add_subdirectory(library)
add_subdirectory(example) add_subdirectory(example)
add_subdirectory(test) add_subdirectory(test)
add_subdirectory(profiler) add_subdirectory(profiler)
#Create an interface target for the include only files and call it "composablekernels"
include(CMakePackageConfigHelpers)
set(version 1.0.0)
write_basic_package_version_file(
"${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake"
VERSION "${version}"
COMPATIBILITY AnyNewerVersion
)
configure_package_config_file(${CMAKE_CURRENT_SOURCE_DIR}/Config.cmake.in
"${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake"
INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
NO_CHECK_REQUIRED_COMPONENTS_MACRO
)
install(FILES
"${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake"
"${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake"
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
)
@PACKAGE_INIT@
set(_composable_kernel_supported_components device_operations host_tensor)
foreach(_comp ${composable_kernel_FIND_COMPONENTS})
if(NOT _comp IN_LIST _composable_kernel_supported_components)
set(composable_kernel_FOUND False)
set(composable_kernel_NOT_FOUND_MESSAGE "Unsupported component: ${_comp}")
endif()
include("${CMAKE_CURRENT_LIST_DIR}/composable_kernel${_comp}Targets.cmake")
endforeach()
...@@ -11,13 +11,7 @@ ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/ ...@@ -11,13 +11,7 @@ ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/
RUN apt-get update RUN apt-get update
RUN apt-get install -y wget gnupg RUN apt-get install -y wget gnupg
RUN wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - RUN wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -
RUN if ! [ -z $OSDB_BKC_VERSION ]; then \ RUN sh -c "echo deb [arch=amd64] $DEB_ROCM_REPO ubuntu main > /etc/apt/sources.list.d/rocm.list"
echo "Using BKC VERISION: $OSDB_BKC_VERSION";\
sh -c "echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-osdb-deb/ compute-rocm-dkms-no-npi-hipclang ${OSDB_BKC_VERSION} > /etc/apt/sources.list.d/rocm.list" ;\
cat /etc/apt/sources.list.d/rocm.list;\
else \
sh -c "echo deb [arch=amd64] $DEB_ROCM_REPO ubuntu main > /etc/apt/sources.list.d/rocm.list" ;\
fi
RUN wget --no-check-certificate -qO - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | apt-key add - RUN wget --no-check-certificate -qO - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | apt-key add -
RUN sh -c "echo deb https://apt.kitware.com/ubuntu/ bionic main | tee -a /etc/apt/sources.list" RUN sh -c "echo deb https://apt.kitware.com/ubuntu/ bionic main | tee -a /etc/apt/sources.list"
...@@ -25,18 +19,15 @@ RUN sh -c "echo deb https://apt.kitware.com/ubuntu/ bionic main | tee -a /etc/ap ...@@ -25,18 +19,15 @@ RUN sh -c "echo deb https://apt.kitware.com/ubuntu/ bionic main | tee -a /etc/ap
# Install dependencies # Install dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
apt-utils \ apt-utils \
sshpass \
build-essential \ build-essential \
cmake-data=3.15.1-0kitware1 \ cmake-data=3.15.1-0kitware1 \
cmake=3.15.1-0kitware1 \ cmake=3.15.1-0kitware1 \
curl \ curl \
doxygen \
g++ \ g++ \
gdb \ gdb \
git \ git \
hip-rocclr \ hip-rocclr \
jq \ jq \
lcov \
libelf-dev \ libelf-dev \
libncurses5-dev \ libncurses5-dev \
libnuma-dev \ libnuma-dev \
...@@ -62,8 +53,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- ...@@ -62,8 +53,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
apt-get clean && \ apt-get clean && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
# RUN pip3 install --default-timeout=100000 -r requirements.txt
# Setup ubsan environment to printstacktrace # Setup ubsan environment to printstacktrace
RUN ln -s /usr/bin/llvm-symbolizer-3.8 /usr/local/bin/llvm-symbolizer RUN ln -s /usr/bin/llvm-symbolizer-3.8 /usr/local/bin/llvm-symbolizer
ENV UBSAN_OPTIONS=print_stacktrace=1 ENV UBSAN_OPTIONS=print_stacktrace=1
...@@ -92,5 +81,3 @@ ADD rbuild.ini /rbuild.ini ...@@ -92,5 +81,3 @@ ADD rbuild.ini /rbuild.ini
ADD dev-requirements.txt dev-requirements.txt ADD dev-requirements.txt dev-requirements.txt
RUN rbuild prepare -s develop -d $PREFIX RUN rbuild prepare -s develop -d $PREFIX
RUN groupadd -f render RUN groupadd -f render
# RUN cget install -f min-requirements.txt
# RUN CXXFLAGS='-isystem $PREFIX/include' cget install -f ./mlir-requirements.txt
...@@ -320,7 +320,7 @@ pipeline { ...@@ -320,7 +320,7 @@ pipeline {
{ {
agent{ label rocmnode("gfx908")} agent{ label rocmnode("gfx908")}
environment{ environment{
setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """ setup_args = """ -D CMAKE_CXX_FLAGS=" --offload-arch=gfx900 --offload-arch=gfx906 --offload-arch=gfx908 --offload-arch=gfx90a -O3 " -DBUILD_DEV=On """
} }
steps{ steps{
buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "check", no_reboot:true, build_type: 'Release') buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "check", no_reboot:true, build_type: 'Release')
...@@ -341,6 +341,23 @@ pipeline { ...@@ -341,6 +341,23 @@ pipeline {
} }
} }
stage("Client App")
{
parallel
{
stage("Run Client App")
{
agent{ label rocmnode("gfx908")}
environment{
setup_args = """ -D -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " """
execute_args = """ cd ../test/client_app && rm -rf build && mkdir build && cd build && cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" .. && make """
}
steps{
buildHipClangJobAndReboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
}
}
}
}
stage("Performance Tests") stage("Performance Tests")
{ {
parallel parallel
......
...@@ -43,3 +43,13 @@ Instructions for running each individual examples are under ```example/``` ...@@ -43,3 +43,13 @@ Instructions for running each individual examples are under ```example/```
make -j ckProfiler make -j ckProfiler
``` ```
Instructions for running ckProfiler are under ```profiler/``` Instructions for running ckProfiler are under ```profiler/```
## Caveat
### Kernel Timing and Verification
CK's own kernel timer will warn up kernel once, and then run it multiple times
to get average kernel time. For some kernels that use atomic add, this will cause
output buffer to be accumulated multiple times, causing verfication failure.
To work around it, do not use CK's own timer and do verification at the same time.
CK's own timer and verification in each example and ckProfiler can be enabled or
disabled from command line.
...@@ -19,6 +19,7 @@ list(APPEND GTEST_CMAKE_CXX_FLAGS ...@@ -19,6 +19,7 @@ list(APPEND GTEST_CMAKE_CXX_FLAGS
-Wno-zero-as-null-pointer-constant -Wno-zero-as-null-pointer-constant
-Wno-unused-member-function -Wno-unused-member-function
-Wno-comma -Wno-comma
-Wno-old-style-cast
) )
message(STATUS "Suppressing googltest warnings with flags: ${GTEST_CMAKE_CXX_FLAGS}") message(STATUS "Suppressing googltest warnings with flags: ${GTEST_CMAKE_CXX_FLAGS}")
...@@ -35,4 +36,4 @@ FetchContent_MakeAvailable(googletest) ...@@ -35,4 +36,4 @@ FetchContent_MakeAvailable(googletest)
target_compile_options(gtest PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) target_compile_options(gtest PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
target_compile_options(gtest_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) target_compile_options(gtest_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
target_compile_options(gmock PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) target_compile_options(gmock PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
target_compile_options(gmock_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
...@@ -88,9 +88,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host:: ...@@ -88,9 +88,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host::
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
// GEMM shape // GEMM shape
ck::index_t M = 3840; ck::index_t M = 3840;
...@@ -105,13 +105,13 @@ int main(int argc, char* argv[]) ...@@ -105,13 +105,13 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 10) else if(argc == 10)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); M = std::stoi(argv[4]);
N = std::stoi(argv[5]); N = std::stoi(argv[5]);
...@@ -125,7 +125,7 @@ int main(int argc, char* argv[]) ...@@ -125,7 +125,7 @@ int main(int argc, char* argv[])
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0); exit(0);
} }
...@@ -198,7 +198,7 @@ int main(int argc, char* argv[]) ...@@ -198,7 +198,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
float ave_time = invoker.Run(argument, nrepeat); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = std::size_t num_btype =
...@@ -232,7 +232,7 @@ int main(int argc, char* argv[]) ...@@ -232,7 +232,7 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData); return ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData) ? 0 : 1;
} }
return 0; return 0;
......
...@@ -56,9 +56,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host:: ...@@ -56,9 +56,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host::
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
// GEMM shape // GEMM shape
ck::index_t M = 3840; ck::index_t M = 3840;
...@@ -73,13 +73,13 @@ int main(int argc, char* argv[]) ...@@ -73,13 +73,13 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 10) else if(argc == 10)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); M = std::stoi(argv[4]);
N = std::stoi(argv[5]); N = std::stoi(argv[5]);
...@@ -93,7 +93,7 @@ int main(int argc, char* argv[]) ...@@ -93,7 +93,7 @@ int main(int argc, char* argv[])
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0); exit(0);
} }
...@@ -171,7 +171,7 @@ int main(int argc, char* argv[]) ...@@ -171,7 +171,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
float ave_time = invoker.Run(argument, nrepeat); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = std::size_t num_btype =
...@@ -196,7 +196,7 @@ int main(int argc, char* argv[]) ...@@ -196,7 +196,7 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
} }
return 0; return 0;
......
...@@ -83,9 +83,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host:: ...@@ -83,9 +83,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host::
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
// GEMM shape // GEMM shape
ck::index_t M = 3840; ck::index_t M = 3840;
...@@ -100,13 +100,13 @@ int main(int argc, char* argv[]) ...@@ -100,13 +100,13 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 10) else if(argc == 10)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); M = std::stoi(argv[4]);
N = std::stoi(argv[5]); N = std::stoi(argv[5]);
...@@ -120,7 +120,7 @@ int main(int argc, char* argv[]) ...@@ -120,7 +120,7 @@ int main(int argc, char* argv[])
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0); exit(0);
} }
...@@ -194,7 +194,7 @@ int main(int argc, char* argv[]) ...@@ -194,7 +194,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
float ave_time = invoker.Run(argument, nrepeat); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = std::size_t num_btype =
...@@ -219,7 +219,7 @@ int main(int argc, char* argv[]) ...@@ -219,7 +219,7 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
} }
return 0; return 0;
......
...@@ -86,9 +86,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBias2D<AD ...@@ -86,9 +86,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBias2D<AD
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
// GEMM shape // GEMM shape
ck::index_t M = 3840; ck::index_t M = 3840;
...@@ -106,13 +106,13 @@ int main(int argc, char* argv[]) ...@@ -106,13 +106,13 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 6) else if(argc == 6)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
alpha = std::stof(argv[4]); alpha = std::stof(argv[4]);
beta = std::stof(argv[5]); beta = std::stof(argv[5]);
...@@ -121,7 +121,7 @@ int main(int argc, char* argv[]) ...@@ -121,7 +121,7 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); M = std::stoi(argv[4]);
N = std::stoi(argv[5]); N = std::stoi(argv[5]);
...@@ -138,7 +138,7 @@ int main(int argc, char* argv[]) ...@@ -138,7 +138,7 @@ int main(int argc, char* argv[])
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, alpha, beta\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, alpha, beta\n");
exit(0); exit(0);
} }
...@@ -216,7 +216,7 @@ int main(int argc, char* argv[]) ...@@ -216,7 +216,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
float ave_time = invoker.Run(argument, nrepeat); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = std::size_t num_btype =
...@@ -246,6 +246,8 @@ int main(int argc, char* argv[]) ...@@ -246,6 +246,8 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
} }
return 0;
} }
...@@ -83,9 +83,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBiasActiv ...@@ -83,9 +83,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBiasActiv
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
// GEMM shape // GEMM shape
ck::index_t M = 3840; ck::index_t M = 3840;
...@@ -100,13 +100,13 @@ int main(int argc, char* argv[]) ...@@ -100,13 +100,13 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 10) else if(argc == 10)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); M = std::stoi(argv[4]);
N = std::stoi(argv[5]); N = std::stoi(argv[5]);
...@@ -120,7 +120,7 @@ int main(int argc, char* argv[]) ...@@ -120,7 +120,7 @@ int main(int argc, char* argv[])
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0); exit(0);
} }
...@@ -206,7 +206,7 @@ int main(int argc, char* argv[]) ...@@ -206,7 +206,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
float ave_time = invoker.Run(argument, nrepeat); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
...@@ -232,6 +232,8 @@ int main(int argc, char* argv[]) ...@@ -232,6 +232,8 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
} }
return 0;
} }
...@@ -83,9 +83,9 @@ using ReferenceGemmInstance = ...@@ -83,9 +83,9 @@ using ReferenceGemmInstance =
CElementOp>; CElementOp>;
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
// GEMM shape // GEMM shape
ck::index_t M = 3840; ck::index_t M = 3840;
...@@ -101,13 +101,13 @@ int main(int argc, char* argv[]) ...@@ -101,13 +101,13 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 11) else if(argc == 11)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); M = std::stoi(argv[4]);
N = std::stoi(argv[5]); N = std::stoi(argv[5]);
...@@ -122,7 +122,7 @@ int main(int argc, char* argv[]) ...@@ -122,7 +122,7 @@ int main(int argc, char* argv[])
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, StrideC1\n"); printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, StrideC1\n");
exit(0); exit(0);
} }
...@@ -218,7 +218,7 @@ int main(int argc, char* argv[]) ...@@ -218,7 +218,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
float ave_time = invoker.Run(argument, nrepeat); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M +
...@@ -250,6 +250,8 @@ int main(int argc, char* argv[]) ...@@ -250,6 +250,8 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
} }
return 0;
} }
...@@ -93,7 +93,7 @@ void PrintUseMsg() ...@@ -93,7 +93,7 @@ void PrintUseMsg()
{ {
std::cout << "arg1: verification (0=no, 1=yes)\n" std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: run kernel # of times (>1)\n" << "arg3: time kernel (0=n0, 1=yes)\n"
<< "Following arguments:\n" << "Following arguments:\n"
<< " N, K, C, \n" << " N, K, C, \n"
<< " <filter spatial dimensions>, (ie Y, X for 2D)\n" << " <filter spatial dimensions>, (ie Y, X for 2D)\n"
...@@ -165,9 +165,9 @@ int main(int argc, char* argv[]) ...@@ -165,9 +165,9 @@ int main(int argc, char* argv[])
{ {
using namespace ck::utils::conv; using namespace ck::utils::conv;
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
const int num_dim_spatial = 2; const int num_dim_spatial = 2;
ck::utils::conv::ConvParams params; ck::utils::conv::ConvParams params;
...@@ -176,7 +176,7 @@ int main(int argc, char* argv[]) ...@@ -176,7 +176,7 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
if(argc >= 5) if(argc >= 5)
...@@ -269,7 +269,7 @@ int main(int argc, char* argv[]) ...@@ -269,7 +269,7 @@ int main(int argc, char* argv[])
"not support this problem"); "not support this problem");
} }
float ave_time = invoker.Run(argument, nrepeat); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = get_flops( std::size_t flop = get_flops(
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
...@@ -305,7 +305,8 @@ int main(int argc, char* argv[]) ...@@ -305,7 +305,8 @@ int main(int argc, char* argv[])
OutElementOp{}); OutElementOp{});
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
out_device_buf.FromDevice(device_output.mData.data()); out_device_buf.FromDevice(device_output.mData.data());
ck::utils::check_err( return ck::utils::check_err(device_output.mData, host_output.mData) ? 0 : 1;
host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
} }
return 0;
} }
add_example_executable(example_conv2d_fwd_xdl_bias_relu_add conv2d_fwd_xdl_bias_relu_add.cpp) # FIXME: should fix validation failure
add_example_executable_no_testing(example_conv2d_fwd_xdl_bias_relu_add conv2d_fwd_xdl_bias_relu_add.cpp)
target_link_libraries(example_conv2d_fwd_xdl_bias_relu_add PRIVATE conv_util) target_link_libraries(example_conv2d_fwd_xdl_bias_relu_add PRIVATE conv_util)
...@@ -90,7 +90,7 @@ void PrintUseMsg() ...@@ -90,7 +90,7 @@ void PrintUseMsg()
{ {
std::cout << "arg1: verification (0=no, 1=yes)\n" std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: run kernel # of times (>1)\n" << "arg3: time kernel (0=n0, 1=yes)\n"
<< "Following arguments:\n" << "Following arguments:\n"
<< " N, K, C, \n" << " N, K, C, \n"
<< " <filter spatial dimensions>, (ie Y, X for 2D)\n" << " <filter spatial dimensions>, (ie Y, X for 2D)\n"
...@@ -162,9 +162,9 @@ int main(int argc, char* argv[]) ...@@ -162,9 +162,9 @@ int main(int argc, char* argv[])
{ {
using namespace ck::utils::conv; using namespace ck::utils::conv;
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
const int num_dim_spatial = 2; const int num_dim_spatial = 2;
ck::utils::conv::ConvParams params; ck::utils::conv::ConvParams params;
...@@ -173,7 +173,7 @@ int main(int argc, char* argv[]) ...@@ -173,7 +173,7 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
if(argc >= 5) if(argc >= 5)
...@@ -280,7 +280,7 @@ int main(int argc, char* argv[]) ...@@ -280,7 +280,7 @@ int main(int argc, char* argv[])
"not support this problem"); "not support this problem");
} }
float ave_time = invoker.Run(argument, nrepeat); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = get_flops( std::size_t flop = get_flops(
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
...@@ -320,7 +320,8 @@ int main(int argc, char* argv[]) ...@@ -320,7 +320,8 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
out_device_buf.FromDevice(device_output.mData.data()); out_device_buf.FromDevice(device_output.mData.data());
ck::utils::check_err( return ck::utils::check_err(device_output.mData, host_output.mData) ? 0 : 1;
host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
} }
return 0;
} }
add_example_executable(example_convnd_fwd_xdl convnd_fwd_xdl.cpp) add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp)
target_link_libraries(example_convnd_fwd_xdl PRIVATE conv_util)
add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
target_link_libraries(example_convnd_fwd_xdl_int8 PRIVATE conv_util)
add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
target_link_libraries(example_convnd_fwd_xdl_fp32 PRIVATE conv_util)
target_link_libraries(example_convnd_fwd_xdl_int8 PRIVATE conv_util)
target_link_libraries(example_convnd_fwd_xdl_fp16 PRIVATE conv_util) target_link_libraries(example_convnd_fwd_xdl_fp16 PRIVATE conv_util)
...@@ -43,10 +43,10 @@ template <ck::index_t NumDimSpatial> ...@@ -43,10 +43,10 @@ template <ck::index_t NumDimSpatial>
using DeviceConvNDFwdInstance = ck::tensor_operation::device:: using DeviceConvNDFwdInstance = ck::tensor_operation::device::
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
// clang-format off // clang-format off
InDataType, // InDataType, //
WeiDataType, // WeiDataType, //
OutDataType, // OutDataType, //
AccDataType, // AccDataType, //
InElementOp, // Input Elementwise Operation InElementOp, // Input Elementwise Operation
WeiElementOp, // Weights Elementwise Operation WeiElementOp, // Weights Elementwise Operation
OutElementOp, // Output Elementwise Operation OutElementOp, // Output Elementwise Operation
...@@ -110,7 +110,7 @@ void print_use_msg() ...@@ -110,7 +110,7 @@ void print_use_msg()
{ {
std::cout << "arg1: verification (0=no, 1=yes)\n" std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: run kernel # of times (>1)\n" << "arg3: time kernel (0=n0, 1=yes)\n"
<< "arg4: N spatial dimensions (default 2)\n" << "arg4: N spatial dimensions (default 2)\n"
<< "Following arguments (depending on number of spatial dims):\n" << "Following arguments (depending on number of spatial dims):\n"
<< " N, K, C, \n" << " N, K, C, \n"
...@@ -182,9 +182,9 @@ int main(int argc, char* argv[]) ...@@ -182,9 +182,9 @@ int main(int argc, char* argv[])
{ {
using namespace ck::utils::conv; using namespace ck::utils::conv;
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
int num_dim_spatial = 2; int num_dim_spatial = 2;
ck::utils::conv::ConvParams params; ck::utils::conv::ConvParams params;
...@@ -193,7 +193,7 @@ int main(int argc, char* argv[]) ...@@ -193,7 +193,7 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
num_dim_spatial = std::stoi(argv[4]); num_dim_spatial = std::stoi(argv[4]);
} }
...@@ -277,7 +277,7 @@ int main(int argc, char* argv[]) ...@@ -277,7 +277,7 @@ int main(int argc, char* argv[])
"not support this Conv problem"); "not support this Conv problem");
} }
float ave_time = invoker->Run(argument.get(), nrepeat); float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = get_flops( std::size_t flop = get_flops(
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
...@@ -312,8 +312,8 @@ int main(int argc, char* argv[]) ...@@ -312,8 +312,8 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
out_device_buf.FromDevice(device_output.mData.data()); out_device_buf.FromDevice(device_output.mData.data());
ck::utils::check_err( return ck::utils::check_err(
host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f) ? 0 : 1;
}; };
switch(num_dim_spatial) switch(num_dim_spatial)
...@@ -338,4 +338,5 @@ int main(int argc, char* argv[]) ...@@ -338,4 +338,5 @@ int main(int argc, char* argv[])
} }
} }
} }
return 0;
} }
...@@ -39,10 +39,10 @@ template <ck::index_t NumDimSpatial> ...@@ -39,10 +39,10 @@ template <ck::index_t NumDimSpatial>
using DeviceConvNDFwdInstance = ck::tensor_operation::device:: using DeviceConvNDFwdInstance = ck::tensor_operation::device::
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
// clang-format off // clang-format off
InDataType, // InDataType, //
WeiDataType, // WeiDataType, //
OutDataType, // OutDataType, //
AccDataType, // AccDataType, //
InElementOp, // Input Elementwise Operation InElementOp, // Input Elementwise Operation
WeiElementOp, // Weights Elementwise Operation WeiElementOp, // Weights Elementwise Operation
OutElementOp, // Output Elementwise Operation OutElementOp, // Output Elementwise Operation
...@@ -107,7 +107,7 @@ void print_use_msg() ...@@ -107,7 +107,7 @@ void print_use_msg()
{ {
std::cout << "arg1: verification (0=no, 1=yes)\n" std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: run kernel # of times (>1)\n" << "arg3: time kernel (0=n0, 1=yes)\n"
<< "arg4: N spatial dimensions (default 2)\n" << "arg4: N spatial dimensions (default 2)\n"
<< "Following arguments (depending on number of spatial dims):\n" << "Following arguments (depending on number of spatial dims):\n"
<< " N, K, C, \n" << " N, K, C, \n"
...@@ -179,9 +179,9 @@ int main(int argc, char* argv[]) ...@@ -179,9 +179,9 @@ int main(int argc, char* argv[])
{ {
using namespace ck::utils::conv; using namespace ck::utils::conv;
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
int num_dim_spatial = 2; int num_dim_spatial = 2;
ck::utils::conv::ConvParams params; ck::utils::conv::ConvParams params;
...@@ -190,7 +190,7 @@ int main(int argc, char* argv[]) ...@@ -190,7 +190,7 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
num_dim_spatial = std::stoi(argv[4]); num_dim_spatial = std::stoi(argv[4]);
} }
...@@ -276,7 +276,7 @@ int main(int argc, char* argv[]) ...@@ -276,7 +276,7 @@ int main(int argc, char* argv[])
"not support this Conv problem"); "not support this Conv problem");
} }
float ave_time = invoker->Run(argument.get(), nrepeat); float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = get_flops( std::size_t flop = get_flops(
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
...@@ -311,8 +311,13 @@ int main(int argc, char* argv[]) ...@@ -311,8 +311,13 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
out_device_buf.FromDevice(device_output.mData.data()); out_device_buf.FromDevice(device_output.mData.data());
ck::utils::check_err( return ck::utils::check_err(device_output.mData,
host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); host_output.mData,
"Error: incorrect results!",
1e-5f,
1e-4f)
? 0
: 1;
}; };
switch(num_dim_spatial) switch(num_dim_spatial)
...@@ -337,4 +342,5 @@ int main(int argc, char* argv[]) ...@@ -337,4 +342,5 @@ int main(int argc, char* argv[])
} }
} }
} }
return 0;
} }
...@@ -45,10 +45,10 @@ template <ck::index_t NumDimSpatial> ...@@ -45,10 +45,10 @@ template <ck::index_t NumDimSpatial>
using DeviceConvNDFwdInstance = ck::tensor_operation::device:: using DeviceConvNDFwdInstance = ck::tensor_operation::device::
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
// clang-format off // clang-format off
InDataType, // InDataType, //
WeiDataType, // WeiDataType, //
OutDataType, // OutDataType, //
AccDataType, // AccDataType, //
InElementOp, // Input Elementwise Operation InElementOp, // Input Elementwise Operation
WeiElementOp, // Weights Elementwise Operation WeiElementOp, // Weights Elementwise Operation
OutElementOp, // Output Elementwise Operation OutElementOp, // Output Elementwise Operation
...@@ -112,7 +112,7 @@ void print_use_msg() ...@@ -112,7 +112,7 @@ void print_use_msg()
{ {
std::cout << "arg1: verification (0=no, 1=yes)\n" std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: run kernel # of times (>1)\n" << "arg3: time kernel (0=n0, 1=yes)\n"
<< "arg4: N spatial dimensions (default 2)\n" << "arg4: N spatial dimensions (default 2)\n"
<< "Following arguments (depending on number of spatial dims):\n" << "Following arguments (depending on number of spatial dims):\n"
<< " N, K, C, \n" << " N, K, C, \n"
...@@ -184,9 +184,9 @@ int main(int argc, char* argv[]) ...@@ -184,9 +184,9 @@ int main(int argc, char* argv[])
{ {
using namespace ck::utils::conv; using namespace ck::utils::conv;
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
int num_dim_spatial = 2; int num_dim_spatial = 2;
ck::utils::conv::ConvParams params; ck::utils::conv::ConvParams params;
...@@ -195,7 +195,7 @@ int main(int argc, char* argv[]) ...@@ -195,7 +195,7 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
num_dim_spatial = std::stoi(argv[4]); num_dim_spatial = std::stoi(argv[4]);
} }
...@@ -279,7 +279,7 @@ int main(int argc, char* argv[]) ...@@ -279,7 +279,7 @@ int main(int argc, char* argv[])
"not support this Conv problem"); "not support this Conv problem");
} }
float ave_time = invoker->Run(argument.get(), nrepeat); float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = get_flops( std::size_t flop = get_flops(
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
...@@ -314,8 +314,8 @@ int main(int argc, char* argv[]) ...@@ -314,8 +314,8 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
out_device_buf.FromDevice(device_output.mData.data()); out_device_buf.FromDevice(device_output.mData.data());
ck::utils::check_err( return ck::utils::check_err(
host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f) ? 0 : 1;
}; };
switch(num_dim_spatial) switch(num_dim_spatial)
...@@ -340,4 +340,5 @@ int main(int argc, char* argv[]) ...@@ -340,4 +340,5 @@ int main(int argc, char* argv[])
} }
} }
} }
return 0;
} }
...@@ -77,9 +77,9 @@ using ReferenceConvBwdInstance = ck::tensor_operation::host::ReferenceConvBwdDat ...@@ -77,9 +77,9 @@ using ReferenceConvBwdInstance = ck::tensor_operation::host::ReferenceConvBwdDat
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
// Conv shape // Conv shape
ck::index_t N = 128; ck::index_t N = 128;
...@@ -102,13 +102,13 @@ int main(int argc, char* argv[]) ...@@ -102,13 +102,13 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 19) else if(argc == 19)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]); N = std::stoi(argv[4]);
K = std::stoi(argv[5]); K = std::stoi(argv[5]);
...@@ -130,7 +130,7 @@ int main(int argc, char* argv[]) ...@@ -130,7 +130,7 @@ int main(int argc, char* argv[])
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " printf("arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx\n"); "RightPx\n");
exit(0); exit(0);
...@@ -214,7 +214,7 @@ int main(int argc, char* argv[]) ...@@ -214,7 +214,7 @@ int main(int argc, char* argv[])
"not support this Conv problem"); "not support this Conv problem");
} }
float ave_time = invoker.Run(argument, nrepeat); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;
...@@ -249,6 +249,10 @@ int main(int argc, char* argv[]) ...@@ -249,6 +249,10 @@ int main(int argc, char* argv[])
in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data()); in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data());
ck::utils::check_err(in_n_c_hi_wi_device_result.mData, in_n_c_hi_wi_host_result.mData); return ck::utils::check_err(in_n_c_hi_wi_device_result.mData,
in_n_c_hi_wi_host_result.mData)
? 0
: 1;
} }
return 0;
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment