Commit b134b7d6 authored by carlushuang's avatar carlushuang
Browse files

Merge remote-tracking branch 'origin/develop' into cpu_avx2

parents 090ba885 9f71ff48
cmake_minimum_required(VERSION 3.5) cmake_minimum_required(VERSION 3.14)
# Check support for CUDA/HIP in Cmake # Check support for CUDA/HIP in Cmake
project(composable_kernel) project(composable_kernel)
...@@ -27,6 +27,8 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) ...@@ -27,6 +27,8 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF) set(CMAKE_CXX_EXTENSIONS OFF)
message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}")
option(CK_TIME_KERNEL "Turning off will disable kernel timing globally" ON)
## OpenMP ## OpenMP
if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
# workaround issue hipcc in rocm3.5 cannot find openmp # workaround issue hipcc in rocm3.5 cannot find openmp
...@@ -72,8 +74,9 @@ message(STATUS "Build with HIP ${HIP_VERSION}") ...@@ -72,8 +74,9 @@ message(STATUS "Build with HIP ${HIP_VERSION}")
rocm_create_package( rocm_create_package(
NAME CK-${CK_BACKEND} NAME composablekernel
DESCRIPTION "High Performance Composable Kernel for AMD GPUs" DESCRIPTION "High Performance Composable Kernel for AMD GPUs"
MAINTAINER "MIOpen Kernels Dev Team <dl.MIOpen@amd.com>"
LDCONFIG LDCONFIG
) )
...@@ -226,7 +229,7 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) ...@@ -226,7 +229,7 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin)
configure_file("${PROJECT_SOURCE_DIR}/include/ck/hip_version.hpp.in" "${PROJECT_BINARY_DIR}/include/ck/hip_version.hpp") configure_file("${PROJECT_SOURCE_DIR}/include/ck/options.hpp.in" "${PROJECT_BINARY_DIR}/include/ck/options.hpp")
include_directories(BEFORE include_directories(BEFORE
${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/include
...@@ -234,6 +237,7 @@ include_directories(BEFORE ...@@ -234,6 +237,7 @@ include_directories(BEFORE
${PROJECT_SOURCE_DIR}/library/include ${PROJECT_SOURCE_DIR}/library/include
) )
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV) if(BUILD_DEV)
add_compile_options(-Werror) add_compile_options(-Werror)
...@@ -241,7 +245,31 @@ if(BUILD_DEV) ...@@ -241,7 +245,31 @@ if(BUILD_DEV)
endif() endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
add_subdirectory(library) add_subdirectory(library)
add_subdirectory(example) add_subdirectory(example)
add_subdirectory(test) add_subdirectory(test)
add_subdirectory(profiler) add_subdirectory(profiler)
# Generate the CMake package files (config + version) in the build tree and
# install them under <libdir>/cmake/composable_kernel, the conventional
# location searched by find_package(composable_kernel).
include(CMakePackageConfigHelpers)

# Version recorded in composable_kernelConfigVersion.cmake.
set(version 1.0.0)

# AnyNewerVersion: an installed package satisfies any request for the same or
# an older version.
write_basic_package_version_file(
    "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake"
    VERSION "${version}"
    COMPATIBILITY AnyNewerVersion
)

# Expand Config.cmake.in into the package config file. The
# check_required_components() helper macro is intentionally not generated.
configure_package_config_file(
    ${CMAKE_CURRENT_SOURCE_DIR}/Config.cmake.in
    "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake"
    INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
    NO_CHECK_REQUIRED_COMPONENTS_MACRO
)

# Install both generated files side by side so the version file is picked up
# together with the config file.
install(
    FILES
        "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake"
        "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake"
    DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
)
@PACKAGE_INIT@

# Package configuration template for composable_kernel.
#
# For every component requested through find_package(... COMPONENTS ...), load
# the matching exported-targets file that sits next to this config file.
# Requesting a component outside the supported set marks the package as not
# found and records the reason, so that find_package() can report it (fatally
# for REQUIRED, silently for QUIET) instead of aborting on a missing include.

# Components that provide a composable_kernel<component>Targets.cmake file.
set(_composable_kernel_supported_components device_operations host_tensor)

foreach(_comp ${composable_kernel_FIND_COMPONENTS})
    if(NOT _comp IN_LIST _composable_kernel_supported_components)
        set(composable_kernel_FOUND False)
        set(composable_kernel_NOT_FOUND_MESSAGE "Unsupported component: ${_comp}")
    else()
        # Only supported components have a targets file; including one for an
        # unknown component would raise a hard "file not found" error.
        include("${CMAKE_CURRENT_LIST_DIR}/composable_kernel${_comp}Targets.cmake")
    endif()
endforeach()
FROM ubuntu:18.04 FROM ubuntu:18.04
ARG ROCMVERSION=5.0 ARG ROCMVERSION=5.1
ARG OSDB_BKC_VERSION ARG OSDB_BKC_VERSION
RUN set -xe RUN set -xe
...@@ -11,13 +11,7 @@ ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/ ...@@ -11,13 +11,7 @@ ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/
RUN apt-get update RUN apt-get update
RUN apt-get install -y wget gnupg RUN apt-get install -y wget gnupg
RUN wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - RUN wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -
RUN if ! [ -z $OSDB_BKC_VERSION ]; then \ RUN sh -c "echo deb [arch=amd64] $DEB_ROCM_REPO ubuntu main > /etc/apt/sources.list.d/rocm.list"
echo "Using BKC VERISION: $OSDB_BKC_VERSION";\
sh -c "echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-osdb-deb/ compute-rocm-dkms-no-npi-hipclang ${OSDB_BKC_VERSION} > /etc/apt/sources.list.d/rocm.list" ;\
cat /etc/apt/sources.list.d/rocm.list;\
else \
sh -c "echo deb [arch=amd64] $DEB_ROCM_REPO ubuntu main > /etc/apt/sources.list.d/rocm.list" ;\
fi
RUN wget --no-check-certificate -qO - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | apt-key add - RUN wget --no-check-certificate -qO - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | apt-key add -
RUN sh -c "echo deb https://apt.kitware.com/ubuntu/ bionic main | tee -a /etc/apt/sources.list" RUN sh -c "echo deb https://apt.kitware.com/ubuntu/ bionic main | tee -a /etc/apt/sources.list"
...@@ -25,18 +19,15 @@ RUN sh -c "echo deb https://apt.kitware.com/ubuntu/ bionic main | tee -a /etc/ap ...@@ -25,18 +19,15 @@ RUN sh -c "echo deb https://apt.kitware.com/ubuntu/ bionic main | tee -a /etc/ap
# Install dependencies # Install dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
apt-utils \ apt-utils \
sshpass \
build-essential \ build-essential \
cmake-data=3.15.1-0kitware1 \ cmake-data=3.15.1-0kitware1 \
cmake=3.15.1-0kitware1 \ cmake=3.15.1-0kitware1 \
curl \ curl \
doxygen \
g++ \ g++ \
gdb \ gdb \
git \ git \
hip-rocclr \ hip-rocclr \
jq \ jq \
lcov \
libelf-dev \ libelf-dev \
libncurses5-dev \ libncurses5-dev \
libnuma-dev \ libnuma-dev \
...@@ -62,8 +53,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- ...@@ -62,8 +53,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
apt-get clean && \ apt-get clean && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
# RUN pip3 install --default-timeout=100000 -r requirements.txt
# Setup ubsan environment to printstacktrace # Setup ubsan environment to printstacktrace
RUN ln -s /usr/bin/llvm-symbolizer-3.8 /usr/local/bin/llvm-symbolizer RUN ln -s /usr/bin/llvm-symbolizer-3.8 /usr/local/bin/llvm-symbolizer
ENV UBSAN_OPTIONS=print_stacktrace=1 ENV UBSAN_OPTIONS=print_stacktrace=1
...@@ -92,5 +81,3 @@ ADD rbuild.ini /rbuild.ini ...@@ -92,5 +81,3 @@ ADD rbuild.ini /rbuild.ini
ADD dev-requirements.txt dev-requirements.txt ADD dev-requirements.txt dev-requirements.txt
RUN rbuild prepare -s develop -d $PREFIX RUN rbuild prepare -s develop -d $PREFIX
RUN groupadd -f render RUN groupadd -f render
# RUN cget install -f min-requirements.txt
# RUN CXXFLAGS='-isystem $PREFIX/include' cget install -f ./mlir-requirements.txt
...@@ -140,6 +140,10 @@ def reboot(){ ...@@ -140,6 +140,10 @@ def reboot(){
build job: 'reboot-slaves', propagate: false , parameters: [string(name: 'server', value: "${env.NODE_NAME}"),] build job: 'reboot-slaves', propagate: false , parameters: [string(name: 'server', value: "${env.NODE_NAME}"),]
} }
def buildHipClangJobAndReboot(Map conf=[:]){ def buildHipClangJobAndReboot(Map conf=[:]){
try{ try{
buildHipClangJob(conf) buildHipClangJob(conf)
...@@ -156,6 +160,93 @@ def buildHipClangJobAndReboot(Map conf=[:]){ ...@@ -156,6 +160,93 @@ def buildHipClangJobAndReboot(Map conf=[:]){
} }
} }
/**
 * Builds the "composable_kernels" docker image, smoke-tests it with clinfo, then
 * builds the project inside the container and runs the GEMM profiling scripts.
 *
 * Keys read from conf in this function (cmake_build(conf) may read more):
 *   prefixpath        passed to docker build as PREFIX   (default "/opt/rocm")
 *   gpu_arch          passed to docker build as GPU_ARCH (default "gfx908");
 *                     also used in the log/artifact file names
 *   enforce_xnack_on  when true, the container is started with HSA_XNACK=1
 *
 * Returns the value of the last successful docker.build call.
 * Exceptions from the build/profiling steps propagate to the caller.
 */
def runCKProfiler(Map conf=[:]){
    show_node_info()

    env.HSA_ENABLE_SDMA=0
    checkout scm

    def image = "composable_kernels"
    def prefixpath = conf.get("prefixpath", "/opt/rocm")
    def gpu_arch = conf.get("gpu_arch", "gfx908")

    // Jenkins is complaining about the render group
    // def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
    def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
    if (conf.get("enforce_xnack_on", false)) {
        dockerOpts = dockerOpts + " --env HSA_XNACK=1"
    }
    def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg GPU_ARCH='${gpu_arch}' "

    def variant = env.STAGE_NAME
    def retimage

    gitStatusWrapper(credentialsId: '7126e5fe-eb51-4576-b52b-9aaf1de8f0fd', gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') {
        try {
            // First attempt: build using the docker layer cache and check the
            // container can run clinfo.
            retimage = docker.build("${image}", dockerArgs + '.')
            withDockerContainer(image: image, args: dockerOpts) {
                timeout(time: 5, unit: 'MINUTES') {
                    sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
                }
            }
        }
        catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){
            // An abort must not trigger the rebuild fallback below.
            echo "The job was cancelled or aborted"
            throw e
        }
        catch(Exception ex) {
            // Any other failure: rebuild the image from scratch and probe again.
            retimage = docker.build("${image}", dockerArgs + "--no-cache .")
            withDockerContainer(image: image, args: dockerOpts) {
                timeout(time: 5, unit: 'MINUTES') {
                    sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
                }
            }
        }

        withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
            timeout(time: 5, unit: 'HOURS') {
                cmake_build(conf)
                dir("script"){
                    def perf_log = "perf_gemm_${gpu_arch}.log"
                    def artifact = "profile_gemm_${gpu_arch}.txt"
                    // Profiling runs are best-effort: their output is collected in
                    // perf_log and a failing run does not stop the stage.
                    sh "./profile_gemm.sh gemm 0 0 0 1 0 5 | tee ${perf_log} ||true"
                    sh "./profile_gemm.sh gemm 0 1 0 1 0 5 | tee -a ${perf_log} ||true"
                    sh "./profile_gemm.sh gemm 0 2 0 1 0 5 | tee -a ${perf_log} ||true"
                    sh "./profile_gemm.sh gemm 0 3 0 1 0 5 | tee -a ${perf_log} || true"
                    //results will be parsed, stored, and analyzed within the python script
                    //the script will return 0 if the performance criteria are met
                    //or return 1 if the criteria are not met
                    // Run the pipeline under bash with pipefail so the script's exit
                    // status decides the step result; with a plain pipe the shell
                    // would report tee's status and hide a failed check.
                    sh "bash -o pipefail -c 'python3 parse_perf_data.py ${perf_log} | tee ${artifact}'"
                }
            }
        }
    }
    return retimage
}
/**
 * Wrapper around runCKProfiler(conf) for the performance-test stage.
 *
 * Any exception is logged and rethrown. Afterwards, whether or not the run
 * succeeded, reboot() is called unless conf has a true "no_reboot" entry.
 */
def runPerfTest(Map conf=[:]){
    try {
        runCKProfiler(conf)
    } catch (e) {
        echo "throwing error exception in performance tests"
        echo 'Exception occurred: ' + e.toString()
        throw e
    } finally {
        def skipReboot = conf.get("no_reboot", false)
        if (!skipReboot) {
            reboot()
        }
    }
}
pipeline { pipeline {
agent none agent none
options { options {
...@@ -178,18 +269,19 @@ pipeline { ...@@ -178,18 +269,19 @@ pipeline {
// buildHipClangJobAndReboot(build_cmd: build_cmd, no_reboot:true, prefixpath: '/opt/rocm', build_type: 'debug') // buildHipClangJobAndReboot(build_cmd: build_cmd, no_reboot:true, prefixpath: '/opt/rocm', build_type: 'debug')
// } // }
// } // }
stage('Build Profiler: Release, gfx908') // we will build and run ckProfiler release version later, during the performance test stage
{ //stage('Build Profiler: Release, gfx908')
agent { label rocmnode("nogpu")} //{
environment{ // agent { label rocmnode("nogpu")}
setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """ // environment{
} // setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """
steps{ // }
buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') // steps{
} // buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release')
} // }
//}
stage('Build Profiler: Debug, gfx908') stage('Build Profiler: Debug, gfx908')
{ {
agent { label rocmnode("nogpu")} agent { label rocmnode("nogpu")}
environment{ environment{
setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """ setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """
...@@ -228,7 +320,18 @@ pipeline { ...@@ -228,7 +320,18 @@ pipeline {
{ {
agent{ label rocmnode("gfx908")} agent{ label rocmnode("gfx908")}
environment{ environment{
setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """ setup_args = """ -D CMAKE_CXX_FLAGS=" --offload-arch=gfx900 --offload-arch=gfx906 --offload-arch=gfx908 --offload-arch=gfx90a -O3 " -DBUILD_DEV=On """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "check", no_reboot:true, build_type: 'Release')
}
}
stage("Run Tests: gfx90a")
{
agent{ label rocmnode("gfx90a")}
environment{
setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx90a -O3 " -DBUILD_DEV=On """
} }
steps{ steps{
buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "check", no_reboot:true, build_type: 'Release') buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "check", no_reboot:true, build_type: 'Release')
...@@ -238,6 +341,41 @@ pipeline { ...@@ -238,6 +341,41 @@ pipeline {
} }
} }
// Installs the library into ../install, then configures and builds the
// stand-alone client application in test/client_app against that install tree.
stage("Client App")
{
    parallel
    {
        stage("Run Client App")
        {
            agent{ label rocmnode("gfx908")}
            environment{
                // Every cache entry needs its own -D. A dangling "-D" before
                // -DBUILD_DEV and a bare CMAKE_CXX_FLAGS=... would not be seen
                // by cmake as the intended definitions.
                setup_args = """ -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " """
                // Runs after the install target: fresh out-of-tree build of the client app.
                execute_args = """ cd ../test/client_app && rm -rf build && mkdir build && cd build && cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" .. && make """
            }
            steps{
                buildHipClangJobAndReboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
            }
        }
    }
}
// Performance stage: hands the ckProfiler configuration to runPerfTest on a
// node selected by rocmnode("gfx908").
stage("Performance Tests")
{
parallel
{
stage("Run ckProfiler: gfx908")
{
agent{ label rocmnode("gfx908")}
environment{
// Extra cmake arguments: compile for gfx908 with -O3 and BUILD_DEV enabled.
setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """
}
steps{
// no_reboot:true makes runPerfTest skip its reboot() call afterwards.
runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release')
}
}
}
}
// enable after the cmake file supports packaging // enable after the cmake file supports packaging
// stage("Packages") { // stage("Packages") {
// when { // when {
......
...@@ -20,7 +20,7 @@ mkdir build && cd build ...@@ -20,7 +20,7 @@ mkdir build && cd build
cmake \ cmake \
-D BUILD_DEV=OFF \ -D BUILD_DEV=OFF \
-D CMAKE_BUILD_TYPE=Release \ -D CMAKE_BUILD_TYPE=Release \
-D CMAKE_CXX_FLAGS=" --offload-arch=gfx908 --offload-arch=gfx90a -O3 \ -D CMAKE_CXX_FLAGS=" --offload-arch=gfx908 --offload-arch=gfx90a -O3" \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_PREFIX_PATH=/opt/rocm \
.. ..
...@@ -43,3 +43,13 @@ Instructions for running each individual examples are under ```example/``` ...@@ -43,3 +43,13 @@ Instructions for running each individual examples are under ```example/```
make -j ckProfiler make -j ckProfiler
``` ```
Instructions for running ckProfiler are under ```profiler/``` Instructions for running ckProfiler are under ```profiler/```
## Caveat
### Kernel Timing and Verification
CK's own kernel timer will warm up the kernel once, and then run it multiple times
to get the average kernel time. For some kernels that use atomic add, this will cause
the output buffer to be accumulated multiple times, causing verification failure.
To work around it, do not use CK's own timer and do verification at the same time.
CK's own timer and verification in each example and ckProfiler can be enabled or
disabled from the command line.
...@@ -66,7 +66,7 @@ else() ...@@ -66,7 +66,7 @@ else()
-Wunreachable-code -Wunreachable-code
-Wunused -Wunused
-Wno-sign-compare -Wsign-compare
-Wno-extra-semi-stmt -Wno-extra-semi-stmt
) )
if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang") if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang")
......
# Makes GoogleTest available through FetchContent and relaxes a set of
# warnings on its targets.
include(FetchContent)

# Optional location of an existing GoogleTest checkout. When non-empty it is
# forwarded as FetchContent's source-directory override for "googletest".
set(GOOGLETEST_DIR "" CACHE STRING "Location of local GoogleTest repo to build against")
if(GOOGLETEST_DIR)
    set(FETCHCONTENT_SOURCE_DIR_GOOGLETEST ${GOOGLETEST_DIR} CACHE STRING "GoogleTest source directory override")
endif()

message(STATUS "Fetching GoogleTest")

# Warning suppressions applied privately to the GoogleTest targets below.
list(APPEND GTEST_CMAKE_CXX_FLAGS
    -Wno-undef
    -Wno-reserved-identifier
    -Wno-global-constructors
    -Wno-missing-noreturn
    -Wno-disabled-macro-expansion
    -Wno-used-but-marked-unused
    -Wno-switch-enum
    -Wno-zero-as-null-pointer-constant
    -Wno-unused-member-function
    -Wno-comma
    -Wno-old-style-cast
)
message(STATUS "Suppressing googltest warnings with flags: ${GTEST_CMAKE_CXX_FLAGS}")

# Pinned to a fixed commit.
FetchContent_Declare(
    googletest
    GIT_REPOSITORY https://github.com/google/googletest.git
    GIT_TAG        b85864c64758dec007208e56af933fc3f52044ee
)

# Will be necessary for windows build
# set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
FetchContent_MakeAvailable(googletest)

# Apply the suppressions to each GoogleTest/GoogleMock target.
foreach(_gtest_target IN ITEMS gtest gtest_main gmock gmock_main)
    target_compile_options(${_gtest_target} PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
endforeach()
...@@ -11,8 +11,7 @@ ...@@ -11,8 +11,7 @@
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "host_tensor_generator.hpp" #include "host_tensor_generator.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_gemm_xdl.hpp" #include "device_gemm_xdl_cshuffle.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "reference_gemm.hpp" #include "reference_gemm.hpp"
#include "gemm_specialization.hpp" #include "gemm_specialization.hpp"
...@@ -37,47 +36,51 @@ using ALayout = ck::tensor_layout::gemm::RowMajor; ...@@ -37,47 +36,51 @@ using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor; using CLayout = ck::tensor_layout::gemm::RowMajor;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off // clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle< using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
ADataType, // ADataType <ALayout, // typename ALayout
BDataType, // BDataType BLayout, // typename BLayout
CDataType, // CDataType CLayout, // typename CLayout
AccDataType, // AccDataType ADataType, // typename ADataType
CDataType, // CShuffleDataType BDataType, // typename BDataType
ALayout, // ALayout CDataType, // typename CDataType
BLayout, // BLayout AccDataType, // typename GemmAccDataType
CLayout, // CLayout CDataType, // typename CShuffleDataType
PassThrough, // AElementwiseOperation PassThrough, // typename AElementwiseOperation
PassThrough, // BElementwiseOperation PassThrough, // typename BElementwiseOperation
PassThrough, // CElementwiseOperation PassThrough, // typename CElementwiseOperation
256, // BlockSize GemmDefault, // GemmSpecialization GemmSpec
256, // MPerBlock 1, // index_t NumGemmKPrefetchStage
128, // NPerBlock 256, // index_t BlockSize
32, // KPerBlock 256, // index_t MPerBlock
8, // AK1 128, // index_t NPerBlock
8, // BK1 32, // index_t KPerBlock
32, // MPerXDL 8, // index_t AK1
32, // NPerXDL 8, // index_t BK1
4, // MXdlPerWave 32, // index_t MPerXDL
2, // NXdlPerWave 32, // index_t NPerXDL
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 4, // index_t MXdlPerWave
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder 2, // index_t NXdlPerWave
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1
2, // ABlockTransferSrcVectorDim S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder
8, // ABlockTransferSrcScalarPerVector S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder
8, // ABlockTransferDstScalarPerVector_K1 2, // index_t ABlockTransferSrcVectorDim
true, // ABlockLdsAddExtraM 8, // index_t ABlockTransferSrcScalarPerVector
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 8, // index_t ABlockTransferDstScalarPerVector_AK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder 1, // index_t ABlockLdsExtraM
S<1, 0, 2>, // BBlockTransferSrcAccessOrder S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1
2, // BBlockTransferSrcVectorDim S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder
8, // BBlockTransferSrcScalarPerVector S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder
8, // BBlockTransferDstScalarPerVector_K1 2, // index_t BBlockTransferSrcVectorDim
true, // BBlockLdsAddExtraN 8, // index_t BBlockTransferSrcScalarPerVector
1, // CShuffleMXdlPerWavePerShuffle 8, // index_t BBlockTransferDstScalarPerVector_BK1
1, // CShuffleNXdlPerWavePerShuffle 1, // index_t BBlockLdsExtraN
S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl 1, // index_t CShuffleMXdlPerWavePerShuffle
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl 1, // index_t CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
...@@ -85,9 +88,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host:: ...@@ -85,9 +88,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host::
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
// GEMM shape // GEMM shape
ck::index_t M = 3840; ck::index_t M = 3840;
...@@ -102,13 +105,13 @@ int main(int argc, char* argv[]) ...@@ -102,13 +105,13 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 10) else if(argc == 10)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); M = std::stoi(argv[4]);
N = std::stoi(argv[5]); N = std::stoi(argv[5]);
...@@ -122,7 +125,7 @@ int main(int argc, char* argv[]) ...@@ -122,7 +125,7 @@ int main(int argc, char* argv[])
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0); exit(0);
} }
...@@ -195,7 +198,7 @@ int main(int argc, char* argv[]) ...@@ -195,7 +198,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
float ave_time = invoker.Run(argument, nrepeat); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = std::size_t num_btype =
...@@ -229,7 +232,7 @@ int main(int argc, char* argv[]) ...@@ -229,7 +232,7 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData); return ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData) ? 0 : 1;
} }
return 0; return 0;
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
#include <cstdlib> #include <cstdlib>
#include <stdlib.h> #include <stdlib.h>
#include <half.hpp> #include <half.hpp>
#include "check_err.hpp" #include "check_err.hpp"
#include "config.hpp" #include "config.hpp"
#include "device.hpp" #include "device.hpp"
...@@ -12,7 +11,6 @@ ...@@ -12,7 +11,6 @@
#include "host_tensor_generator.hpp" #include "host_tensor_generator.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_gemm_xdl.hpp" #include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "device_gemm_xdl_cshuffle.hpp" #include "device_gemm_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "reference_gemm.hpp" #include "reference_gemm.hpp"
...@@ -46,11 +44,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa ...@@ -46,11 +44,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// clang-format off // clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| DataType| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| //######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| //######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; < Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
...@@ -58,9 +56,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host:: ...@@ -58,9 +56,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host::
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
// GEMM shape // GEMM shape
ck::index_t M = 3840; ck::index_t M = 3840;
...@@ -75,13 +73,13 @@ int main(int argc, char* argv[]) ...@@ -75,13 +73,13 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 10) else if(argc == 10)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); M = std::stoi(argv[4]);
N = std::stoi(argv[5]); N = std::stoi(argv[5]);
...@@ -95,7 +93,7 @@ int main(int argc, char* argv[]) ...@@ -95,7 +93,7 @@ int main(int argc, char* argv[])
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0); exit(0);
} }
...@@ -173,7 +171,7 @@ int main(int argc, char* argv[]) ...@@ -173,7 +171,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
float ave_time = invoker.Run(argument, nrepeat); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = std::size_t num_btype =
...@@ -198,7 +196,7 @@ int main(int argc, char* argv[]) ...@@ -198,7 +196,7 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
} }
return 0; return 0;
......
...@@ -11,8 +11,7 @@ ...@@ -11,8 +11,7 @@
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "host_tensor_generator.hpp" #include "host_tensor_generator.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_gemm_xdl.hpp" #include "device_gemm_xdl_cshuffle.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "reference_gemm.hpp" #include "reference_gemm.hpp"
#include "gemm_specialization.hpp" #include "gemm_specialization.hpp"
...@@ -20,64 +19,63 @@ ...@@ -20,64 +19,63 @@
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = int8_t; using ADataType = int8_t;
using BDataType = int8_t; using BDataType = int8_t;
using CDataType = int32_t; using CDataType = int8_t;
using AccDataType = int32_t; using AccDataType = int32_t;
using CShuffleDataType = int32_t; using CShuffleDataType = int8_t;
using ALayout = ck::tensor_layout::gemm::RowMajor; using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor; using CLayout = ck::tensor_layout::gemm::RowMajor;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off // clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle< using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<
ADataType, // ADataType ALayout, // typename ALayout
BDataType, // BDataType BLayout, // typename BLayout
CDataType, // CDataType CLayout, // typename CLayout
AccDataType, // AccDataType ADataType, // typename ADataType
CShuffleDataType, // CShuffleDataType BDataType, // typename BDataType
ALayout, // ALayout CDataType, // typename CDataType
BLayout, // BLayout AccDataType, // typename GemmAccDataType
CLayout, // CLayout CShuffleDataType, // typename CShuffleDataType
PassThrough, // AElementwiseOperation PassThrough, // typename AElementwiseOperation
PassThrough, // BElementwiseOperation PassThrough, // typename BElementwiseOperation
PassThrough, // CElementwiseOperation PassThrough, // typename CElementwiseOperation
256, // BlockSize GemmDefault, // GemmSpecialization GemmSpec
256, // MPerBlock 1, // index_t NumGemmKPrefetchStage
128, // NPerBlock 256, // index_t BlockSize
64, // KPerBlock 256, // index_t MPerBlock
16, // AK1 128, // index_t NPerBlock
16, // BK1 64, // index_t KPerBlock
32, // MPerXDL 16, // index_t AK1
32, // NPerXDL 16, // index_t BK1
4, // MXdlPerWave 32, // index_t MPerXDL
2, // NXdlPerWave 32, // index_t NPerXDL
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 4, // index_t MXdlPerWave
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder 2, // index_t NXdlPerWave
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1
2, // ABlockTransferSrcVectorDim S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder
16, // ABlockTransferSrcScalarPerVector S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder
16, // ABlockTransferDstScalarPerVector_K1 2, // index_t ABlockTransferSrcVectorDim
true, // ABlockLdsAddExtraM 16, // index_t ABlockTransferSrcScalarPerVector
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 16, // index_t ABlockTransferDstScalarPerVector_AK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder 1, // index_t ABlockLdsExtraM
S<1, 0, 2>, // BBlockTransferSrcAccessOrder S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1
2, // BBlockTransferSrcVectorDim S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder
16, // BBlockTransferSrcScalarPerVector S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder
16, // BBlockTransferDstScalarPerVector_K1 2, // index_t BBlockTransferSrcVectorDim
true, // BBlockLdsAddExtraN 8, // index_t BBlockTransferSrcScalarPerVector
1, // CShuffleMXdlPerWavePerShuffle 8, // index_t BBlockTransferDstScalarPerVector_BK1
1, // CShuffleNXdlPerWavePerShuffle 1, // index_t BBlockLdsExtraN
S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl 1, // index_t CShuffleMXdlPerWavePerShuffle
4>; // CBlockTransferScalarPerVector_NWaveNPerXdl 1, // index_t CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
...@@ -85,9 +83,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host:: ...@@ -85,9 +83,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host::
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
// GEMM shape // GEMM shape
ck::index_t M = 3840; ck::index_t M = 3840;
...@@ -102,13 +100,13 @@ int main(int argc, char* argv[]) ...@@ -102,13 +100,13 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 10) else if(argc == 10)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); M = std::stoi(argv[4]);
N = std::stoi(argv[5]); N = std::stoi(argv[5]);
...@@ -122,7 +120,7 @@ int main(int argc, char* argv[]) ...@@ -122,7 +120,7 @@ int main(int argc, char* argv[])
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0); exit(0);
} }
...@@ -196,7 +194,7 @@ int main(int argc, char* argv[]) ...@@ -196,7 +194,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
float ave_time = invoker.Run(argument, nrepeat); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = std::size_t num_btype =
...@@ -221,7 +219,7 @@ int main(int argc, char* argv[]) ...@@ -221,7 +219,7 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
} }
return 0; return 0;
......
...@@ -86,9 +86,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBias2D<AD ...@@ -86,9 +86,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBias2D<AD
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
// GEMM shape // GEMM shape
ck::index_t M = 3840; ck::index_t M = 3840;
...@@ -106,13 +106,13 @@ int main(int argc, char* argv[]) ...@@ -106,13 +106,13 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 6) else if(argc == 6)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
alpha = std::stof(argv[4]); alpha = std::stof(argv[4]);
beta = std::stof(argv[5]); beta = std::stof(argv[5]);
...@@ -121,7 +121,7 @@ int main(int argc, char* argv[]) ...@@ -121,7 +121,7 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); M = std::stoi(argv[4]);
N = std::stoi(argv[5]); N = std::stoi(argv[5]);
...@@ -138,7 +138,7 @@ int main(int argc, char* argv[]) ...@@ -138,7 +138,7 @@ int main(int argc, char* argv[])
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, alpha, beta\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, alpha, beta\n");
exit(0); exit(0);
} }
...@@ -216,7 +216,7 @@ int main(int argc, char* argv[]) ...@@ -216,7 +216,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
float ave_time = invoker.Run(argument, nrepeat); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = std::size_t num_btype =
...@@ -246,6 +246,8 @@ int main(int argc, char* argv[]) ...@@ -246,6 +246,8 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
} }
return 0;
} }
...@@ -83,9 +83,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBiasActiv ...@@ -83,9 +83,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBiasActiv
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
// GEMM shape // GEMM shape
ck::index_t M = 3840; ck::index_t M = 3840;
...@@ -100,13 +100,13 @@ int main(int argc, char* argv[]) ...@@ -100,13 +100,13 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 10) else if(argc == 10)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); M = std::stoi(argv[4]);
N = std::stoi(argv[5]); N = std::stoi(argv[5]);
...@@ -120,7 +120,7 @@ int main(int argc, char* argv[]) ...@@ -120,7 +120,7 @@ int main(int argc, char* argv[])
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0); exit(0);
} }
...@@ -206,7 +206,7 @@ int main(int argc, char* argv[]) ...@@ -206,7 +206,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
float ave_time = invoker.Run(argument, nrepeat); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
...@@ -232,6 +232,8 @@ int main(int argc, char* argv[]) ...@@ -232,6 +232,8 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
} }
return 0;
} }
...@@ -83,9 +83,9 @@ using ReferenceGemmInstance = ...@@ -83,9 +83,9 @@ using ReferenceGemmInstance =
CElementOp>; CElementOp>;
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
// GEMM shape // GEMM shape
ck::index_t M = 3840; ck::index_t M = 3840;
...@@ -101,13 +101,13 @@ int main(int argc, char* argv[]) ...@@ -101,13 +101,13 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 11) else if(argc == 11)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); M = std::stoi(argv[4]);
N = std::stoi(argv[5]); N = std::stoi(argv[5]);
...@@ -122,7 +122,7 @@ int main(int argc, char* argv[]) ...@@ -122,7 +122,7 @@ int main(int argc, char* argv[])
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, StrideC1\n"); printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, StrideC1\n");
exit(0); exit(0);
} }
...@@ -218,7 +218,7 @@ int main(int argc, char* argv[]) ...@@ -218,7 +218,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
float ave_time = invoker.Run(argument, nrepeat); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M +
...@@ -250,6 +250,8 @@ int main(int argc, char* argv[]) ...@@ -250,6 +250,8 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
} }
return 0;
} }
add_example_executable(example_conv2d_fwd_xdl_bias_relu conv2d_fwd_xdl_bias_relu.cpp) add_example_executable(example_conv2d_fwd_xdl_bias_relu conv2d_fwd_xdl_bias_relu.cpp)
target_link_libraries(example_conv2d_fwd_xdl_bias_relu PRIVATE conv_fwd_util) target_link_libraries(example_conv2d_fwd_xdl_bias_relu PRIVATE conv_util)
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "check_err.hpp" #include "check_err.hpp"
#include "config.hpp" #include "config.hpp"
#include "conv_fwd_util.hpp" #include "conv_util.hpp"
#include "device.hpp" #include "device.hpp"
#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp" #include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
...@@ -93,7 +93,7 @@ void PrintUseMsg() ...@@ -93,7 +93,7 @@ void PrintUseMsg()
{ {
std::cout << "arg1: verification (0=no, 1=yes)\n" std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: run kernel # of times (>1)\n" << "arg3: time kernel (0=n0, 1=yes)\n"
<< "Following arguments:\n" << "Following arguments:\n"
<< " N, K, C, \n" << " N, K, C, \n"
<< " <filter spatial dimensions>, (ie Y, X for 2D)\n" << " <filter spatial dimensions>, (ie Y, X for 2D)\n"
...@@ -120,40 +120,40 @@ ck::utils::conv::ConvParams ParseConvParams(int argc, char* argv[]) ...@@ -120,40 +120,40 @@ ck::utils::conv::ConvParams ParseConvParams(int argc, char* argv[])
ck::utils::conv::ConvParams params; ck::utils::conv::ConvParams params;
int arg_idx = 4; int arg_idx = 4;
params.num_dim_spatial = num_dim_spatial; params.num_dim_spatial_ = num_dim_spatial;
params.N = std::stoi(argv[arg_idx++]); params.N_ = std::stoi(argv[arg_idx++]);
params.K = std::stoi(argv[arg_idx++]); params.K_ = std::stoi(argv[arg_idx++]);
params.C = std::stoi(argv[arg_idx++]); params.C_ = std::stoi(argv[arg_idx++]);
params.filter_spatial_lengths.resize(num_dim_spatial); params.filter_spatial_lengths_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]); params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
} }
params.input_spatial_lengths.resize(num_dim_spatial); params.input_spatial_lengths_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]); params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
} }
params.conv_filter_strides.resize(num_dim_spatial); params.conv_filter_strides_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]); params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]);
} }
params.conv_filter_dilations.resize(num_dim_spatial); params.conv_filter_dilations_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]); params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]);
} }
params.input_left_pads.resize(num_dim_spatial); params.input_left_pads_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.input_left_pads[i] = std::stoi(argv[arg_idx++]); params.input_left_pads_[i] = std::stoi(argv[arg_idx++]);
} }
params.input_right_pads.resize(num_dim_spatial); params.input_right_pads_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.input_right_pads[i] = std::stoi(argv[arg_idx++]); params.input_right_pads_[i] = std::stoi(argv[arg_idx++]);
} }
return params; return params;
...@@ -165,9 +165,9 @@ int main(int argc, char* argv[]) ...@@ -165,9 +165,9 @@ int main(int argc, char* argv[])
{ {
using namespace ck::utils::conv; using namespace ck::utils::conv;
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
const int num_dim_spatial = 2; const int num_dim_spatial = 2;
ck::utils::conv::ConvParams params; ck::utils::conv::ConvParams params;
...@@ -176,7 +176,7 @@ int main(int argc, char* argv[]) ...@@ -176,7 +176,7 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
if(argc >= 5) if(argc >= 5)
...@@ -184,21 +184,21 @@ int main(int argc, char* argv[]) ...@@ -184,21 +184,21 @@ int main(int argc, char* argv[])
params = ParseConvParams(argc, argv); params = ParseConvParams(argc, argv);
} }
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N), std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N_),
static_cast<std::size_t>(params.C)}; static_cast<std::size_t>(params.C_)};
input_dims.insert(std::end(input_dims), input_dims.insert(std::end(input_dims),
std::begin(params.input_spatial_lengths), std::begin(params.input_spatial_lengths_),
std::end(params.input_spatial_lengths)); std::end(params.input_spatial_lengths_));
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K), std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K_),
static_cast<std::size_t>(params.C)}; static_cast<std::size_t>(params.C_)};
filter_dims.insert(std::end(filter_dims), filter_dims.insert(std::end(filter_dims),
std::begin(params.filter_spatial_lengths), std::begin(params.filter_spatial_lengths_),
std::end(params.filter_spatial_lengths)); std::end(params.filter_spatial_lengths_));
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths(); const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N), std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N_),
static_cast<std::size_t>(params.K)}; static_cast<std::size_t>(params.K_)};
output_dims.insert(std::end(output_dims), output_dims.insert(std::end(output_dims),
std::begin(output_spatial_lengths), std::begin(output_spatial_lengths),
std::end(output_spatial_lengths)); std::end(output_spatial_lengths));
...@@ -211,7 +211,7 @@ int main(int argc, char* argv[]) ...@@ -211,7 +211,7 @@ int main(int argc, char* argv[])
get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); get_output_host_tensor_descriptor(output_dims, num_dim_spatial));
// bias: assume contiguous 1d vector // bias: assume contiguous 1d vector
Tensor<OutDataType> bias( Tensor<OutDataType> bias(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(params.K)}))); HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(params.K_)})));
std::cout << "input: " << input.mDesc << std::endl; std::cout << "input: " << input.mDesc << std::endl;
std::cout << "weights: " << weights.mDesc << std::endl; std::cout << "weights: " << weights.mDesc << std::endl;
...@@ -248,16 +248,16 @@ int main(int argc, char* argv[]) ...@@ -248,16 +248,16 @@ int main(int argc, char* argv[])
static_cast<const WeiDataType*>(wei_device_buf.GetDeviceBuffer()), static_cast<const WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<const OutDataType*>(bias_device_buf.GetDeviceBuffer()), static_cast<const OutDataType*>(bias_device_buf.GetDeviceBuffer()),
params.N, params.N_,
params.K, params.K_,
params.C, params.C_,
params.input_spatial_lengths, params.input_spatial_lengths_,
params.filter_spatial_lengths, params.filter_spatial_lengths_,
output_spatial_lengths, output_spatial_lengths,
params.conv_filter_strides, params.conv_filter_strides_,
params.conv_filter_dilations, params.conv_filter_dilations_,
params.input_left_pads, params.input_left_pads_,
params.input_right_pads, params.input_right_pads_,
InElementOp{}, InElementOp{},
WeiElementOp{}, WeiElementOp{},
OutElementOp{}); OutElementOp{});
...@@ -269,18 +269,18 @@ int main(int argc, char* argv[]) ...@@ -269,18 +269,18 @@ int main(int argc, char* argv[])
"not support this problem"); "not support this problem");
} }
float ave_time = invoker.Run(argument, nrepeat); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = get_flops( std::size_t flop = get_flops(
params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
std::size_t num_btype = std::size_t num_btype =
get_btype<InDataType, WeiDataType, OutDataType>(params.N, get_btype<InDataType, WeiDataType, OutDataType>(params.N_,
params.C, params.C_,
params.K, params.K_,
params.input_spatial_lengths, params.input_spatial_lengths_,
params.filter_spatial_lengths, params.filter_spatial_lengths_,
output_spatial_lengths) + output_spatial_lengths) +
sizeof(OutDataType) * (params.K); sizeof(OutDataType) * (params.K_);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time;
...@@ -296,16 +296,17 @@ int main(int argc, char* argv[]) ...@@ -296,16 +296,17 @@ int main(int argc, char* argv[])
weights, weights,
host_output, host_output,
bias, bias,
params.conv_filter_strides, params.conv_filter_strides_,
params.conv_filter_dilations, params.conv_filter_dilations_,
params.input_left_pads, params.input_left_pads_,
params.input_right_pads, params.input_right_pads_,
InElementOp{}, InElementOp{},
WeiElementOp{}, WeiElementOp{},
OutElementOp{}); OutElementOp{});
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
out_device_buf.FromDevice(device_output.mData.data()); out_device_buf.FromDevice(device_output.mData.data());
ck::utils::check_err( return ck::utils::check_err(device_output.mData, host_output.mData) ? 0 : 1;
host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
} }
return 0;
} }
add_example_executable(example_conv2d_fwd_xdl_bias_relu_add conv2d_fwd_xdl_bias_relu_add.cpp) # FIXME: should fix validation failure
target_link_libraries(example_conv2d_fwd_xdl_bias_relu_add PRIVATE conv_fwd_util) add_example_executable_no_testing(example_conv2d_fwd_xdl_bias_relu_add conv2d_fwd_xdl_bias_relu_add.cpp)
target_link_libraries(example_conv2d_fwd_xdl_bias_relu_add PRIVATE conv_util)
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "check_err.hpp" #include "check_err.hpp"
#include "config.hpp" #include "config.hpp"
#include "conv_fwd_util.hpp" #include "conv_util.hpp"
#include "device.hpp" #include "device.hpp"
#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp" #include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
...@@ -90,7 +90,7 @@ void PrintUseMsg() ...@@ -90,7 +90,7 @@ void PrintUseMsg()
{ {
std::cout << "arg1: verification (0=no, 1=yes)\n" std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: run kernel # of times (>1)\n" << "arg3: time kernel (0=n0, 1=yes)\n"
<< "Following arguments:\n" << "Following arguments:\n"
<< " N, K, C, \n" << " N, K, C, \n"
<< " <filter spatial dimensions>, (ie Y, X for 2D)\n" << " <filter spatial dimensions>, (ie Y, X for 2D)\n"
...@@ -117,40 +117,40 @@ ck::utils::conv::ConvParams ParseConvParams(int argc, char* argv[]) ...@@ -117,40 +117,40 @@ ck::utils::conv::ConvParams ParseConvParams(int argc, char* argv[])
ck::utils::conv::ConvParams params; ck::utils::conv::ConvParams params;
int arg_idx = 4; int arg_idx = 4;
params.num_dim_spatial = num_dim_spatial; params.num_dim_spatial_ = num_dim_spatial;
params.N = std::stoi(argv[arg_idx++]); params.N_ = std::stoi(argv[arg_idx++]);
params.K = std::stoi(argv[arg_idx++]); params.K_ = std::stoi(argv[arg_idx++]);
params.C = std::stoi(argv[arg_idx++]); params.C_ = std::stoi(argv[arg_idx++]);
params.filter_spatial_lengths.resize(num_dim_spatial); params.filter_spatial_lengths_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]); params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
} }
params.input_spatial_lengths.resize(num_dim_spatial); params.input_spatial_lengths_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]); params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
} }
params.conv_filter_strides.resize(num_dim_spatial); params.conv_filter_strides_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]); params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]);
} }
params.conv_filter_dilations.resize(num_dim_spatial); params.conv_filter_dilations_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]); params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]);
} }
params.input_left_pads.resize(num_dim_spatial); params.input_left_pads_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.input_left_pads[i] = std::stoi(argv[arg_idx++]); params.input_left_pads_[i] = std::stoi(argv[arg_idx++]);
} }
params.input_right_pads.resize(num_dim_spatial); params.input_right_pads_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.input_right_pads[i] = std::stoi(argv[arg_idx++]); params.input_right_pads_[i] = std::stoi(argv[arg_idx++]);
} }
return params; return params;
...@@ -162,9 +162,9 @@ int main(int argc, char* argv[]) ...@@ -162,9 +162,9 @@ int main(int argc, char* argv[])
{ {
using namespace ck::utils::conv; using namespace ck::utils::conv;
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
const int num_dim_spatial = 2; const int num_dim_spatial = 2;
ck::utils::conv::ConvParams params; ck::utils::conv::ConvParams params;
...@@ -173,7 +173,7 @@ int main(int argc, char* argv[]) ...@@ -173,7 +173,7 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
if(argc >= 5) if(argc >= 5)
...@@ -181,21 +181,21 @@ int main(int argc, char* argv[]) ...@@ -181,21 +181,21 @@ int main(int argc, char* argv[])
params = ParseConvParams(argc, argv); params = ParseConvParams(argc, argv);
} }
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N), std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N_),
static_cast<std::size_t>(params.C)}; static_cast<std::size_t>(params.C_)};
input_dims.insert(std::end(input_dims), input_dims.insert(std::end(input_dims),
std::begin(params.input_spatial_lengths), std::begin(params.input_spatial_lengths_),
std::end(params.input_spatial_lengths)); std::end(params.input_spatial_lengths_));
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K), std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K_),
static_cast<std::size_t>(params.C)}; static_cast<std::size_t>(params.C_)};
filter_dims.insert(std::end(filter_dims), filter_dims.insert(std::end(filter_dims),
std::begin(params.filter_spatial_lengths), std::begin(params.filter_spatial_lengths_),
std::end(params.filter_spatial_lengths)); std::end(params.filter_spatial_lengths_));
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths(); const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N), std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N_),
static_cast<std::size_t>(params.K)}; static_cast<std::size_t>(params.K_)};
output_dims.insert(std::end(output_dims), output_dims.insert(std::end(output_dims),
std::begin(output_spatial_lengths), std::begin(output_spatial_lengths),
std::end(output_spatial_lengths)); std::end(output_spatial_lengths));
...@@ -209,7 +209,7 @@ int main(int argc, char* argv[]) ...@@ -209,7 +209,7 @@ int main(int argc, char* argv[])
// bias: assume contiguous 1d vector // bias: assume contiguous 1d vector
Tensor<OutDataType> bias( Tensor<OutDataType> bias(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(params.K)}))); HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(params.K_)})));
// residual: assume same layout as output tensor // residual: assume same layout as output tensor
Tensor<OutDataType> residual(get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); Tensor<OutDataType> residual(get_output_host_tensor_descriptor(output_dims, num_dim_spatial));
...@@ -259,16 +259,16 @@ int main(int argc, char* argv[]) ...@@ -259,16 +259,16 @@ int main(int argc, char* argv[])
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<const OutDataType*>(bias_device_buf.GetDeviceBuffer()), static_cast<const OutDataType*>(bias_device_buf.GetDeviceBuffer()),
static_cast<const OutDataType*>(resi_device_buf.GetDeviceBuffer()), static_cast<const OutDataType*>(resi_device_buf.GetDeviceBuffer()),
params.N, params.N_,
params.K, params.K_,
params.C, params.C_,
params.input_spatial_lengths, params.input_spatial_lengths_,
params.filter_spatial_lengths, params.filter_spatial_lengths_,
output_spatial_lengths, output_spatial_lengths,
params.conv_filter_strides, params.conv_filter_strides_,
params.conv_filter_dilations, params.conv_filter_dilations_,
params.input_left_pads, params.input_left_pads_,
params.input_right_pads, params.input_right_pads_,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op); out_element_op);
...@@ -280,20 +280,20 @@ int main(int argc, char* argv[]) ...@@ -280,20 +280,20 @@ int main(int argc, char* argv[])
"not support this problem"); "not support this problem");
} }
float ave_time = invoker.Run(argument, nrepeat); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = get_flops( std::size_t flop = get_flops(
params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
std::size_t num_btype = std::size_t num_btype =
get_btype<InDataType, WeiDataType, OutDataType>(params.N, get_btype<InDataType, WeiDataType, OutDataType>(params.N_,
params.C, params.C_,
params.K, params.K_,
params.input_spatial_lengths, params.input_spatial_lengths_,
params.filter_spatial_lengths, params.filter_spatial_lengths_,
output_spatial_lengths) + output_spatial_lengths) +
sizeof(OutDataType) * (params.K) + sizeof(OutDataType) * (params.K_) +
sizeof(OutDataType) * sizeof(OutDataType) *
(params.N * params.K * output_spatial_lengths[0] * output_spatial_lengths[1]); (params.N_ * params.K_ * output_spatial_lengths[0] * output_spatial_lengths[1]);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time;
...@@ -310,17 +310,18 @@ int main(int argc, char* argv[]) ...@@ -310,17 +310,18 @@ int main(int argc, char* argv[])
host_output, host_output,
bias, bias,
residual, residual,
params.conv_filter_strides, params.conv_filter_strides_,
params.conv_filter_dilations, params.conv_filter_dilations_,
params.input_left_pads, params.input_left_pads_,
params.input_right_pads, params.input_right_pads_,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op); out_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
out_device_buf.FromDevice(device_output.mData.data()); out_device_buf.FromDevice(device_output.mData.data());
ck::utils::check_err( return ck::utils::check_err(device_output.mData, host_output.mData) ? 0 : 1;
host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
} }
return 0;
} }
add_example_executable(example_convnd_fwd_xdl convnd_fwd_xdl.cpp) add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp)
target_link_libraries(example_convnd_fwd_xdl PRIVATE conv_fwd_util)
add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
target_link_libraries(example_convnd_fwd_xdl_int8 PRIVATE conv_fwd_util)
add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
target_link_libraries(example_convnd_fwd_xdl_fp16 PRIVATE conv_fwd_util) target_link_libraries(example_convnd_fwd_xdl_fp32 PRIVATE conv_util)
target_link_libraries(example_convnd_fwd_xdl_int8 PRIVATE conv_util)
target_link_libraries(example_convnd_fwd_xdl_fp16 PRIVATE conv_util)
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include "check_err.hpp" #include "check_err.hpp"
#include "config.hpp" #include "config.hpp"
#include "conv_fwd_util.hpp" #include "conv_util.hpp"
#include "device.hpp" #include "device.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
...@@ -43,10 +43,10 @@ template <ck::index_t NumDimSpatial> ...@@ -43,10 +43,10 @@ template <ck::index_t NumDimSpatial>
using DeviceConvNDFwdInstance = ck::tensor_operation::device:: using DeviceConvNDFwdInstance = ck::tensor_operation::device::
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
// clang-format off // clang-format off
InDataType, // InDataType, //
WeiDataType, // WeiDataType, //
OutDataType, // OutDataType, //
AccDataType, // AccDataType, //
InElementOp, // Input Elementwise Operation InElementOp, // Input Elementwise Operation
WeiElementOp, // Weights Elementwise Operation WeiElementOp, // Weights Elementwise Operation
OutElementOp, // Output Elementwise Operation OutElementOp, // Output Elementwise Operation
...@@ -110,7 +110,7 @@ void print_use_msg() ...@@ -110,7 +110,7 @@ void print_use_msg()
{ {
std::cout << "arg1: verification (0=no, 1=yes)\n" std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: run kernel # of times (>1)\n" << "arg3: time kernel (0=no, 1=yes)\n"
<< "arg4: N spatial dimensions (default 2)\n" << "arg4: N spatial dimensions (default 2)\n"
<< "Following arguments (depending on number of spatial dims):\n" << "Following arguments (depending on number of spatial dims):\n"
<< " N, K, C, \n" << " N, K, C, \n"
...@@ -137,40 +137,40 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, int argc, cha ...@@ -137,40 +137,40 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, int argc, cha
ck::utils::conv::ConvParams params; ck::utils::conv::ConvParams params;
int arg_idx = 5; int arg_idx = 5;
params.num_dim_spatial = num_dim_spatial; params.num_dim_spatial_ = num_dim_spatial;
params.N = std::stoi(argv[arg_idx++]); params.N_ = std::stoi(argv[arg_idx++]);
params.K = std::stoi(argv[arg_idx++]); params.K_ = std::stoi(argv[arg_idx++]);
params.C = std::stoi(argv[arg_idx++]); params.C_ = std::stoi(argv[arg_idx++]);
params.filter_spatial_lengths.resize(num_dim_spatial); params.filter_spatial_lengths_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]); params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
} }
params.input_spatial_lengths.resize(num_dim_spatial); params.input_spatial_lengths_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]); params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
} }
params.conv_filter_strides.resize(num_dim_spatial); params.conv_filter_strides_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]); params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]);
} }
params.conv_filter_dilations.resize(num_dim_spatial); params.conv_filter_dilations_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]); params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]);
} }
params.input_left_pads.resize(num_dim_spatial); params.input_left_pads_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.input_left_pads[i] = std::stoi(argv[arg_idx++]); params.input_left_pads_[i] = std::stoi(argv[arg_idx++]);
} }
params.input_right_pads.resize(num_dim_spatial); params.input_right_pads_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.input_right_pads[i] = std::stoi(argv[arg_idx++]); params.input_right_pads_[i] = std::stoi(argv[arg_idx++]);
} }
return params; return params;
...@@ -182,9 +182,9 @@ int main(int argc, char* argv[]) ...@@ -182,9 +182,9 @@ int main(int argc, char* argv[])
{ {
using namespace ck::utils::conv; using namespace ck::utils::conv;
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
int num_dim_spatial = 2; int num_dim_spatial = 2;
ck::utils::conv::ConvParams params; ck::utils::conv::ConvParams params;
...@@ -193,7 +193,7 @@ int main(int argc, char* argv[]) ...@@ -193,7 +193,7 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
num_dim_spatial = std::stoi(argv[4]); num_dim_spatial = std::stoi(argv[4]);
} }
...@@ -202,21 +202,21 @@ int main(int argc, char* argv[]) ...@@ -202,21 +202,21 @@ int main(int argc, char* argv[])
params = parse_conv_params(num_dim_spatial, argc, argv); params = parse_conv_params(num_dim_spatial, argc, argv);
} }
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N), std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N_),
static_cast<std::size_t>(params.C)}; static_cast<std::size_t>(params.C_)};
input_dims.insert(std::end(input_dims), input_dims.insert(std::end(input_dims),
std::begin(params.input_spatial_lengths), std::begin(params.input_spatial_lengths_),
std::end(params.input_spatial_lengths)); std::end(params.input_spatial_lengths_));
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K), std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K_),
static_cast<std::size_t>(params.C)}; static_cast<std::size_t>(params.C_)};
filter_dims.insert(std::end(filter_dims), filter_dims.insert(std::end(filter_dims),
std::begin(params.filter_spatial_lengths), std::begin(params.filter_spatial_lengths_),
std::end(params.filter_spatial_lengths)); std::end(params.filter_spatial_lengths_));
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths(); const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N), std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N_),
static_cast<std::size_t>(params.K)}; static_cast<std::size_t>(params.K_)};
output_dims.insert(std::end(output_dims), output_dims.insert(std::end(output_dims),
std::begin(output_spatial_lengths), std::begin(output_spatial_lengths),
std::end(output_spatial_lengths)); std::end(output_spatial_lengths));
...@@ -256,16 +256,16 @@ int main(int argc, char* argv[]) ...@@ -256,16 +256,16 @@ int main(int argc, char* argv[])
conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()), static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
params.N, params.N_,
params.K, params.K_,
params.C, params.C_,
params.input_spatial_lengths, params.input_spatial_lengths_,
params.filter_spatial_lengths, params.filter_spatial_lengths_,
output_spatial_lengths, output_spatial_lengths,
params.conv_filter_strides, params.conv_filter_strides_,
params.conv_filter_dilations, params.conv_filter_dilations_,
params.input_left_pads, params.input_left_pads_,
params.input_right_pads, params.input_right_pads_,
InElementOp{}, InElementOp{},
WeiElementOp{}, WeiElementOp{},
OutElementOp{}); OutElementOp{});
...@@ -277,16 +277,16 @@ int main(int argc, char* argv[]) ...@@ -277,16 +277,16 @@ int main(int argc, char* argv[])
"not support this Conv problem"); "not support this Conv problem");
} }
float ave_time = invoker->Run(argument.get(), nrepeat); float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = get_flops( std::size_t flop = get_flops(
params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
std::size_t num_btype = get_btype<InDataType, WeiDataType, OutDataType>( std::size_t num_btype = get_btype<InDataType, WeiDataType, OutDataType>(
params.N, params.N_,
params.C, params.C_,
params.K, params.K_,
params.input_spatial_lengths, params.input_spatial_lengths_,
params.filter_spatial_lengths, params.filter_spatial_lengths_,
output_spatial_lengths); output_spatial_lengths);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...@@ -302,18 +302,18 @@ int main(int argc, char* argv[]) ...@@ -302,18 +302,18 @@ int main(int argc, char* argv[])
auto ref_argument = ref_conv.MakeArgument(input, auto ref_argument = ref_conv.MakeArgument(input,
weights, weights,
host_output, host_output,
params.conv_filter_strides, params.conv_filter_strides_,
params.conv_filter_dilations, params.conv_filter_dilations_,
params.input_left_pads, params.input_left_pads_,
params.input_right_pads, params.input_right_pads_,
InElementOp{}, InElementOp{},
WeiElementOp{}, WeiElementOp{},
OutElementOp{}); OutElementOp{});
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
out_device_buf.FromDevice(device_output.mData.data()); out_device_buf.FromDevice(device_output.mData.data());
ck::utils::check_err( return ck::utils::check_err(
host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f) ? 0 : 1;
}; };
switch(num_dim_spatial) switch(num_dim_spatial)
...@@ -338,4 +338,5 @@ int main(int argc, char* argv[]) ...@@ -338,4 +338,5 @@ int main(int argc, char* argv[])
} }
} }
} }
return 0;
} }
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include "check_err.hpp" #include "check_err.hpp"
#include "config.hpp" #include "config.hpp"
#include "conv_fwd_util.hpp" #include "conv_util.hpp"
#include "device.hpp" #include "device.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
...@@ -39,10 +39,10 @@ template <ck::index_t NumDimSpatial> ...@@ -39,10 +39,10 @@ template <ck::index_t NumDimSpatial>
using DeviceConvNDFwdInstance = ck::tensor_operation::device:: using DeviceConvNDFwdInstance = ck::tensor_operation::device::
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
// clang-format off // clang-format off
InDataType, // InDataType, //
WeiDataType, // WeiDataType, //
OutDataType, // OutDataType, //
AccDataType, // AccDataType, //
InElementOp, // Input Elementwise Operation InElementOp, // Input Elementwise Operation
WeiElementOp, // Weights Elementwise Operation WeiElementOp, // Weights Elementwise Operation
OutElementOp, // Output Elementwise Operation OutElementOp, // Output Elementwise Operation
...@@ -107,7 +107,7 @@ void print_use_msg() ...@@ -107,7 +107,7 @@ void print_use_msg()
{ {
std::cout << "arg1: verification (0=no, 1=yes)\n" std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: run kernel # of times (>1)\n" << "arg3: time kernel (0=no, 1=yes)\n"
<< "arg4: N spatial dimensions (default 2)\n" << "arg4: N spatial dimensions (default 2)\n"
<< "Following arguments (depending on number of spatial dims):\n" << "Following arguments (depending on number of spatial dims):\n"
<< " N, K, C, \n" << " N, K, C, \n"
...@@ -134,40 +134,40 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, int argc, cha ...@@ -134,40 +134,40 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, int argc, cha
ck::utils::conv::ConvParams params; ck::utils::conv::ConvParams params;
int arg_idx = 5; int arg_idx = 5;
params.num_dim_spatial = num_dim_spatial; params.num_dim_spatial_ = num_dim_spatial;
params.N = std::stoi(argv[arg_idx++]); params.N_ = std::stoi(argv[arg_idx++]);
params.K = std::stoi(argv[arg_idx++]); params.K_ = std::stoi(argv[arg_idx++]);
params.C = std::stoi(argv[arg_idx++]); params.C_ = std::stoi(argv[arg_idx++]);
params.filter_spatial_lengths.resize(num_dim_spatial); params.filter_spatial_lengths_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]); params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
} }
params.input_spatial_lengths.resize(num_dim_spatial); params.input_spatial_lengths_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]); params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
} }
params.conv_filter_strides.resize(num_dim_spatial); params.conv_filter_strides_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]); params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]);
} }
params.conv_filter_dilations.resize(num_dim_spatial); params.conv_filter_dilations_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]); params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]);
} }
params.input_left_pads.resize(num_dim_spatial); params.input_left_pads_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.input_left_pads[i] = std::stoi(argv[arg_idx++]); params.input_left_pads_[i] = std::stoi(argv[arg_idx++]);
} }
params.input_right_pads.resize(num_dim_spatial); params.input_right_pads_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.input_right_pads[i] = std::stoi(argv[arg_idx++]); params.input_right_pads_[i] = std::stoi(argv[arg_idx++]);
} }
return params; return params;
...@@ -179,9 +179,9 @@ int main(int argc, char* argv[]) ...@@ -179,9 +179,9 @@ int main(int argc, char* argv[])
{ {
using namespace ck::utils::conv; using namespace ck::utils::conv;
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
int num_dim_spatial = 2; int num_dim_spatial = 2;
ck::utils::conv::ConvParams params; ck::utils::conv::ConvParams params;
...@@ -190,7 +190,7 @@ int main(int argc, char* argv[]) ...@@ -190,7 +190,7 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
num_dim_spatial = std::stoi(argv[4]); num_dim_spatial = std::stoi(argv[4]);
} }
...@@ -199,21 +199,21 @@ int main(int argc, char* argv[]) ...@@ -199,21 +199,21 @@ int main(int argc, char* argv[])
params = parse_conv_params(num_dim_spatial, argc, argv); params = parse_conv_params(num_dim_spatial, argc, argv);
} }
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N), std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N_),
static_cast<std::size_t>(params.C)}; static_cast<std::size_t>(params.C_)};
input_dims.insert(std::end(input_dims), input_dims.insert(std::end(input_dims),
std::begin(params.input_spatial_lengths), std::begin(params.input_spatial_lengths_),
std::end(params.input_spatial_lengths)); std::end(params.input_spatial_lengths_));
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K), std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K_),
static_cast<std::size_t>(params.C)}; static_cast<std::size_t>(params.C_)};
filter_dims.insert(std::end(filter_dims), filter_dims.insert(std::end(filter_dims),
std::begin(params.filter_spatial_lengths), std::begin(params.filter_spatial_lengths_),
std::end(params.filter_spatial_lengths)); std::end(params.filter_spatial_lengths_));
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths(); const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N), std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N_),
static_cast<std::size_t>(params.K)}; static_cast<std::size_t>(params.K_)};
output_dims.insert(std::end(output_dims), output_dims.insert(std::end(output_dims),
std::begin(output_spatial_lengths), std::begin(output_spatial_lengths),
std::end(output_spatial_lengths)); std::end(output_spatial_lengths));
...@@ -255,16 +255,16 @@ int main(int argc, char* argv[]) ...@@ -255,16 +255,16 @@ int main(int argc, char* argv[])
conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()), static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
params.N, params.N_,
params.K, params.K_,
params.C, params.C_,
params.input_spatial_lengths, params.input_spatial_lengths_,
params.filter_spatial_lengths, params.filter_spatial_lengths_,
output_spatial_lengths, output_spatial_lengths,
params.conv_filter_strides, params.conv_filter_strides_,
params.conv_filter_dilations, params.conv_filter_dilations_,
params.input_left_pads, params.input_left_pads_,
params.input_right_pads, params.input_right_pads_,
InElementOp{}, InElementOp{},
WeiElementOp{}, WeiElementOp{},
OutElementOp{}); OutElementOp{});
...@@ -276,16 +276,16 @@ int main(int argc, char* argv[]) ...@@ -276,16 +276,16 @@ int main(int argc, char* argv[])
"not support this Conv problem"); "not support this Conv problem");
} }
float ave_time = invoker->Run(argument.get(), nrepeat); float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = get_flops( std::size_t flop = get_flops(
params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
std::size_t num_btype = std::size_t num_btype =
get_btype<InDataType, WeiDataType, OutDataType>(params.N, get_btype<InDataType, WeiDataType, OutDataType>(params.N_,
params.C, params.C_,
params.K, params.K_,
params.input_spatial_lengths, params.input_spatial_lengths_,
params.filter_spatial_lengths, params.filter_spatial_lengths_,
output_spatial_lengths); output_spatial_lengths);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...@@ -301,18 +301,23 @@ int main(int argc, char* argv[]) ...@@ -301,18 +301,23 @@ int main(int argc, char* argv[])
auto ref_argument = ref_conv.MakeArgument(input, auto ref_argument = ref_conv.MakeArgument(input,
weights, weights,
host_output, host_output,
params.conv_filter_strides, params.conv_filter_strides_,
params.conv_filter_dilations, params.conv_filter_dilations_,
params.input_left_pads, params.input_left_pads_,
params.input_right_pads, params.input_right_pads_,
InElementOp{}, InElementOp{},
WeiElementOp{}, WeiElementOp{},
OutElementOp{}); OutElementOp{});
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
out_device_buf.FromDevice(device_output.mData.data()); out_device_buf.FromDevice(device_output.mData.data());
ck::utils::check_err( return ck::utils::check_err(device_output.mData,
host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); host_output.mData,
"Error: incorrect results!",
1e-5f,
1e-4f)
? 0
: 1;
}; };
switch(num_dim_spatial) switch(num_dim_spatial)
...@@ -337,4 +342,5 @@ int main(int argc, char* argv[]) ...@@ -337,4 +342,5 @@ int main(int argc, char* argv[])
} }
} }
} }
return 0;
} }
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment