Commit 2c1ed8b2 authored by Anthony Chang's avatar Anthony Chang
Browse files

Merge remote-tracking branch 'upstream/develop' into gemm-layernorm-4

parents b86b318b 56adf7e9
...@@ -7,7 +7,6 @@ def show_node_info() { ...@@ -7,7 +7,6 @@ def show_node_info() {
echo "NODE_NAME = \$NODE_NAME" echo "NODE_NAME = \$NODE_NAME"
lsb_release -sd lsb_release -sd
uname -r uname -r
cat /sys/module/amdgpu/version
ls /opt/ -la ls /opt/ -la
""" """
} }
...@@ -100,35 +99,45 @@ def buildHipClangJob(Map conf=[:]){ ...@@ -100,35 +99,45 @@ def buildHipClangJob(Map conf=[:]){
def variant = env.STAGE_NAME def variant = env.STAGE_NAME
def retimage def retimage
gitStatusWrapper(credentialsId: '7126e5fe-eb51-4576-b52b-9aaf1de8f0fd', gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') {
try { gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') {
retimage = docker.build("${image}", dockerArgs + '.') if (params.USE_DOCKERFILE){
withDockerContainer(image: image, args: dockerOpts) { try {
timeout(time: 5, unit: 'MINUTES') retimage = docker.build("${image}", dockerArgs + '.')
{ withDockerContainer(image: image, args: dockerOpts) {
sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo' timeout(time: 5, unit: 'MINUTES')
{
sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
}
} }
} }
} catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){
catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){ echo "The job was cancelled or aborted"
echo "The job was cancelled or aborted" throw e
throw e }
} catch(Exception ex) {
catch(Exception ex) { retimage = docker.build("${image}", dockerArgs + "--no-cache .")
retimage = docker.build("${image}", dockerArgs + "--no-cache .") withDockerContainer(image: image, args: dockerOpts) {
withDockerContainer(image: image, args: dockerOpts) { timeout(time: 5, unit: 'MINUTES')
timeout(time: 5, unit: 'MINUTES') {
{ sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo' }
} }
} }
} }
else{
timeout(time: 3, unit: 'HOURS'){
retimage = docker.image('compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compute-rocm-dkms-no-npi-hipclang:9110_ubuntu18.04_py3.6_pytorch_rocm5.0_internal_testing_7ff5b54').pull()
image="b56f8ac0d6ea"
sh "docker images"
}
}
withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
timeout(time: 5, unit: 'HOURS') timeout(time: 5, unit: 'HOURS')
{ {
sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
cmake_build(conf) cmake_build(conf)
} }
} }
...@@ -181,61 +190,92 @@ def runCKProfiler(Map conf=[:]){ ...@@ -181,61 +190,92 @@ def runCKProfiler(Map conf=[:]){
def variant = env.STAGE_NAME def variant = env.STAGE_NAME
def retimage def retimage
gitStatusWrapper(credentialsId: '7126e5fe-eb51-4576-b52b-9aaf1de8f0fd', gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') {
try { gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') {
retimage = docker.build("${image}", dockerArgs + '.') if (params.USE_DOCKERFILE){
withDockerContainer(image: image, args: dockerOpts) { try {
timeout(time: 5, unit: 'MINUTES') retimage = docker.build("${image}", dockerArgs + '.')
{ withDockerContainer(image: image, args: dockerOpts) {
sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo' timeout(time: 5, unit: 'MINUTES')
{
sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
}
} }
} }
} catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){
catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){ echo "The job was cancelled or aborted"
echo "The job was cancelled or aborted" throw e
throw e }
} catch(Exception ex) {
catch(Exception ex) { retimage = docker.build("${image}", dockerArgs + "--no-cache .")
retimage = docker.build("${image}", dockerArgs + "--no-cache .") withDockerContainer(image: image, args: dockerOpts) {
withDockerContainer(image: image, args: dockerOpts) { timeout(time: 5, unit: 'MINUTES')
timeout(time: 5, unit: 'MINUTES') {
{ sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo' }
} }
} }
} }
else{
timeout(time: 3, unit: 'HOURS'){
retimage = docker.image('compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compute-rocm-dkms-no-npi-hipclang:9110_ubuntu18.04_py3.6_pytorch_rocm5.0_internal_testing_7ff5b54').pull()
image="b56f8ac0d6ea"
sh "docker images"
}
}
withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
timeout(time: 5, unit: 'HOURS') timeout(time: 5, unit: 'HOURS')
{ {
cmake_build(conf) cmake_build(conf)
dir("script"){ dir("script"){
def perf_log = "perf_gemm_${gpu_arch}.log" //run gemm performance tests
sh "rm -f ${perf_log}" def gemm_log = "perf_gemm_${gpu_arch}.log"
sh "echo Branch name: ${env.BRANCH_NAME} > ${perf_log}" sh "rm -f ${gemm_log}"
sh "./profile_gemm.sh gemm 0 0 0 1 0 5 | tee -a ${perf_log}" sh "echo Branch name: ${env.BRANCH_NAME} > ${gemm_log}"
sh "./profile_gemm.sh gemm 1 0 0 1 0 5 | tee -a ${perf_log}" sh "echo Node name: ${NODE_NAME} >> ${gemm_log}"
sh "./profile_gemm.sh gemm 2 0 0 1 0 5 | tee -a ${perf_log}" sh "echo GPU_arch name: ${gpu_arch} >> ${gemm_log}"
sh "./profile_gemm.sh gemm 3 0 0 1 0 5 | tee -a ${perf_log}" sh "rocminfo | grep 'Compute Unit:' >> ${gemm_log} "
sh "./profile_gemm.sh gemm 0 1 0 1 0 5 | tee -a ${perf_log}" sh "hipcc --version | grep -e 'HIP version' >> ${gemm_log}"
sh "./profile_gemm.sh gemm 1 1 0 1 0 5 | tee -a ${perf_log}" sh "/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${gemm_log}"
sh "./profile_gemm.sh gemm 2 1 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 0 0 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 3 1 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 1 0 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 0 2 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 2 0 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 1 2 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 3 0 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 2 2 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 0 1 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 3 2 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 1 1 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 0 3 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 2 1 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 1 3 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 3 1 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 2 3 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 0 2 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 3 3 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 1 2 0 1 0 5 | tee -a ${gemm_log}"
//results will be parsed, stored, and analyzed within the python script sh "./profile_gemm.sh gemm 2 2 0 1 0 5 | tee -a ${gemm_log}"
//the script will return 0 if the performance criteria are met sh "./profile_gemm.sh gemm 3 2 0 1 0 5 | tee -a ${gemm_log}"
//or return 1 if the criteria are not met sh "./profile_gemm.sh gemm 0 3 0 1 0 5 | tee -a ${gemm_log}"
archiveArtifacts "${perf_log}" sh "./profile_gemm.sh gemm 1 3 0 1 0 5 | tee -a ${gemm_log}"
sh "python3 parse_perf_data.py ${perf_log} " sh "./profile_gemm.sh gemm 2 3 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 3 3 0 1 0 5 | tee -a ${gemm_log}"
//results will be parsed, stored, and analyzed within the python script
//the script will return 0 if the performance criteria are met
//or return 1 if the criteria are not met
archiveArtifacts "${gemm_log}"
sh "python3 parse_perf_data.py ${gemm_log} "
//run resnet50 test
def resnet_log = "perf_resnet50_${gpu_arch}.log"
sh "rm -f ${resnet_log}"
sh "echo Branch name: ${env.BRANCH_NAME} > ${resnet_log}"
sh "echo Node name: ${NODE_NAME} >> ${resnet_log}"
sh "echo GPU_arch name: ${gpu_arch} >> ${resnet_log}"
sh "rocminfo | grep 'Compute Unit:' >> ${resnet_log} "
sh "hipcc --version | grep -e 'HIP version' >> ${resnet_log}"
sh "/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${resnet_log}"
//first run tests with N=256
sh "./profile_conv.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 256 | tee -a ${resnet_log}"
//then run with N=4
sh "./profile_conv.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 4 | tee -a ${resnet_log}"
archiveArtifacts "${resnet_log}"
//the script will put the results from N=256 and N=4 runs into separate tables
sh "python3 parse_perf_data.py ${resnet_log} "
} }
} }
} }
...@@ -265,9 +305,21 @@ pipeline { ...@@ -265,9 +305,21 @@ pipeline {
options { options {
parallelsAlwaysFailFast() parallelsAlwaysFailFast()
} }
// environment{ parameters {
// variable = value booleanParam(
// } name: "USE_DOCKERFILE",
defaultValue: true,
description: "")
}
environment{
dbuser = "${dbuser}"
dbpassword = "${dbpassword}"
dbsship = "${dbsship}"
dbsshport = "${dbsshport}"
dbsshuser = "${dbsshuser}"
dbsshpassword = "${dbsshpassword}"
status_wrapper_creds = "${status_wrapper_creds}"
}
stages{ stages{
stage("Static checks") { stage("Static checks") {
parallel{ parallel{
...@@ -282,30 +334,6 @@ pipeline { ...@@ -282,30 +334,6 @@ pipeline {
// buildHipClangJobAndReboot(build_cmd: build_cmd, no_reboot:true, prefixpath: '/opt/rocm', build_type: 'debug') // buildHipClangJobAndReboot(build_cmd: build_cmd, no_reboot:true, prefixpath: '/opt/rocm', build_type: 'debug')
// } // }
// } // }
// we will build and run ckProfiler release version later, during the performance test stage
//stage('Build Profiler: Release, gfx908')
//{
// agent { label rocmnode("nogpu")}
// environment{
// setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """
// }
// steps{
// buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release')
// }
//}
//stage('Build Profiler: Debug, gfx908')
//{
// agent { label rocmnode("nogpu")}
// environment{
// setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """
// }
// steps{
// // until we stabilize debug build due to compiler crashes
// catchError(buildResult: 'SUCCESS', stageResult: 'FAILURE') {
// buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Debug')
// }
// }
//}
stage('Clang Format') { stage('Clang Format') {
agent{ label rocmnode("nogpu") } agent{ label rocmnode("nogpu") }
environment{ environment{
...@@ -333,12 +361,11 @@ pipeline { ...@@ -333,12 +361,11 @@ pipeline {
{ {
agent{ label rocmnode("gfx908")} agent{ label rocmnode("gfx908")}
environment{ environment{
setup_args = """ -D CMAKE_CXX_FLAGS=" --offload-arch=gfx900 --offload-arch=gfx906 --offload-arch=gfx908 --offload-arch=gfx90a -O3 " -DBUILD_DEV=On """ setup_args = """ -D CMAKE_CXX_FLAGS=" --offload-arch=gfx908 -O3 " -DBUILD_DEV=On """
} }
steps{ steps{
buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "check", no_reboot:true, build_type: 'Release') buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "check", no_reboot:true, build_type: 'Release', gpu_arch: "gfx908")
} }
} }
stage("Run Tests: gfx90a") stage("Run Tests: gfx90a")
{ {
...@@ -347,11 +374,9 @@ pipeline { ...@@ -347,11 +374,9 @@ pipeline {
setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx90a -O3 " -DBUILD_DEV=On """ setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx90a -O3 " -DBUILD_DEV=On """
} }
steps{ steps{
buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "check", no_reboot:true, build_type: 'Release') buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "check", no_reboot:true, build_type: 'Release', gpu_arch: "gfx90a")
} }
} }
} }
} }
stage("Client App") stage("Client App")
...@@ -380,33 +405,37 @@ pipeline { ...@@ -380,33 +405,37 @@ pipeline {
agent{ label rocmnode("gfx908")} agent{ label rocmnode("gfx908")}
environment{ environment{
setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """ setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """
dbuser = "${dbuser}"
dbpassword = "${dbpassword}"
dbsship = "${dbsship}"
dbsshport = "${dbsshport}"
dbsshuser = "${dbsshuser}"
dbsshpassword = "${dbsshpassword}"
} }
steps{ steps{
runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release', gpu_arch: "gfx908")
}
}
stage("Run ckProfiler: gfx90a")
{
agent{ label rocmnode("gfx90a")}
environment{
setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx90a -O3 " -DBUILD_DEV=On """
}
steps{
runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release', gpu_arch: "gfx90a")
} }
} }
} }
} }
/* enable after the cmake file supports packaging
// enable after the cmake file supports packaging stage("Packages") {
// stage("Packages") { when {
// when { expression { params.BUILD_PACKAGES && params.TARGET_NOGPU && params.DATATYPE_NA }
// expression { params.BUILD_PACKAGES && params.TARGET_NOGPU && params.DATATYPE_NA } }
// } parallel {
// parallel { stage("Package /opt/rocm") {
// stage("Package /opt/rocm") { agent{ label rocmnode("nogpu") }
// agent{ label rocmnode("nogpu") } steps{
// steps{ buildHipClangJobAndReboot( package_build: "true", prefixpath: '/opt/rocm', gpu_arch: "gfx906;gfx908;gfx90a")
// buildHipClangJobAndReboot( package_build: "true", prefixpath: '/opt/rocm', gpu_arch: "gfx906;gfx908;gfx90a") }
// } }
// } }
// } }
// } */
} }
} }
...@@ -27,28 +27,29 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -27,28 +27,29 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = ck::half_t; using ADataType = F16;
using BDataType = ck::half_t; using BDataType = F16;
using CDataType = ck::half_t; using AccDataType = F32;
using AccDataType = float; using CShuffleDataType = F32;
using CDataType = F16;
using ALayout = ck::tensor_layout::gemm::RowMajor; using ALayout = Row;
using BLayout = ck::tensor_layout::gemm::ColumnMajor; using BLayout = Col;
using CLayout = ck::tensor_layout::gemm::RowMajor; using CLayout = Row;
using AElementOp = ck::tensor_operation::element_wise::PassThrough; using AElementOp = PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off // clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
...@@ -69,7 +70,11 @@ int main(int argc, char* argv[]) ...@@ -69,7 +70,11 @@ int main(int argc, char* argv[])
ck::index_t StrideB = 4096; ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096; ck::index_t StrideC = 4096;
if(argc == 4) if(argc == 1)
{
// use default case
}
else if(argc == 4)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
...@@ -93,7 +98,7 @@ int main(int argc, char* argv[]) ...@@ -93,7 +98,7 @@ int main(int argc, char* argv[])
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0); exit(0);
} }
......
...@@ -3,83 +3,103 @@ ...@@ -3,83 +3,103 @@
#include <initializer_list> #include <initializer_list>
#include <cstdlib> #include <cstdlib>
#include <stdlib.h> #include <stdlib.h>
#include <half.hpp>
#include "check_err.hpp" #include "check_err.hpp"
#include "config.hpp" #include "config.hpp"
#include "print.hpp"
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "host_tensor_generator.hpp" #include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "device_gemm_xdl_c_shuffle_bias_activation.hpp" #include "reference_gemm.hpp"
#include "reference_gemm_bias_activation.hpp" #include "gemm_specialization.hpp"
#include "device_gemm_multiple_d_xdl_cshuffle.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using ADataType = ck::half_t; using F16 = ck::half_t;
using BDataType = ck::half_t; using F32 = float;
using CDataType = ck::half_t;
using AccDataType = float; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using CLayout = ck::tensor_layout::gemm::RowMajor;
// C = A * B
using AElementOp = ck::tensor_operation::element_wise::PassThrough; // E = Relu(C + D);
using BElementOp = ck::tensor_operation::element_wise::PassThrough; struct AddRelu
using CElementOp = ck::tensor_operation::element_wise::AddRelu; {
__host__ __device__ void
// clang-format off operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d) const
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle_Bias_Activation< {
ADataType, // ADataType const ck::half_t x = c + d;
BDataType, // BDataType
CDataType, // CDataType e = x > 0 ? x : 0;
AccDataType, // AccDataType }
ALayout, // ALayout };
BLayout, // BLayout
CLayout, // CLayout using ADataType = F16;
AElementOp, // AElementwiseOperation using BDataType = F16;
BElementOp, // BElementwiseOperation using AccDataType = F32;
CElementOp, // CElementwiseOperation using CShuffleDataType = F16;
256, // BlockSize using DDataType = F16;
256, // MPerBlock using DsDataType = ck::Tuple<DDataType>;
128, // NPerBlock using EDataType = F16;
4, // K0PerBlock
8, // K1 using ALayout = Row;
32, // MPerXDL using BLayout = Col;
32, // NPerXDL using ELayout = Row;
4, // MXdlPerWave
2, // NXdlPerWave using AElementOp = PassThrough;
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 using BElementOp = PassThrough;
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder using CDEElementOp = AddRelu;
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_K1 using DeviceOpInstance =
true, // ABlockLdsAddExtraM ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle<ALayout,
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 BLayout,
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder ELayout,
S<1, 0, 2>, // BBlockTransferSrcAccessOrder ADataType,
2, // BBlockTransferSrcVectorDim BDataType,
8, // BBlockTransferSrcScalarPerVector AccDataType,
8, // BBlockTransferDstScalarPerVector_K1 CShuffleDataType,
true, // BBlockLdsAddExtraN DsDataType,
1, // CShuffleMXdlPerWavePerShuffle EDataType,
1, // CShuffleNXdlPerWavePerShuffle AElementOp,
S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl BElementOp,
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl CDEElementOp,
// clang-format on GemmDefault,
1,
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBiasActivation<ADataType, 256,
BDataType, 256,
CDataType, 128,
AElementOp, 32,
BElementOp, 8,
CElementOp>; 8,
32,
32,
4,
2,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
1,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
1,
1,
1,
S<1, 32, 1, 8>,
8>;
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
...@@ -94,9 +114,13 @@ int main(int argc, char* argv[]) ...@@ -94,9 +114,13 @@ int main(int argc, char* argv[])
ck::index_t StrideA = 4096; ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096; ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096; ck::index_t StrideE = 4096;
if(argc == 4) if(argc == 1)
{
// use default case
}
else if(argc == 4)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
...@@ -114,14 +138,14 @@ int main(int argc, char* argv[]) ...@@ -114,14 +138,14 @@ int main(int argc, char* argv[])
StrideA = std::stoi(argv[7]); StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]); StrideB = std::stoi(argv[8]);
StrideC = std::stoi(argv[9]); StrideE = std::stoi(argv[9]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n");
exit(0); exit(0);
} }
...@@ -141,17 +165,14 @@ int main(int argc, char* argv[]) ...@@ -141,17 +165,14 @@ int main(int argc, char* argv[])
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, 0, ELayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
// c0_n[n]
Tensor<CDataType> c0_n(HostTensorDescriptor(
std::vector<std::size_t>({static_cast<std::size_t>(N)}), std::vector<std::size_t>({1})));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
std::cout << "c0_n: " << c0_n.mDesc << std::endl; std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
...@@ -159,59 +180,59 @@ int main(int argc, char* argv[]) ...@@ -159,59 +180,59 @@ int main(int argc, char* argv[])
case 1: case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}); a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}); b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
c0_n.GenerateTensorValue(GeneratorTensor_2<CDataType>{-5, 5}); d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
break; break;
default: default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
c0_n.GenerateTensorValue(GeneratorTensor_3<CDataType>{0.0, 1.0}); d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
} }
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); DeviceMem d_m_n_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpace());
DeviceMem c0_n_device_buf(sizeof(CDataType) * c0_n.mDesc.GetElementSpace()); DeviceMem e_m_n_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
a_m_k_device_buf.ToDevice(a_m_k.mData.data()); a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data()); b_k_n_device_buf.ToDevice(b_k_n.mData.data());
c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data()); d_m_n_device_buf.ToDevice(d_m_n.mData.data());
c0_n_device_buf.ToDevice(c0_n.mData.data());
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{}; auto cde_element_op = CDEElementOp{};
// do GEMM // do GEMM
auto gemm = DeviceGemmInstance{}; auto device_op = DeviceOpInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = device_op.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()), auto argument =
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()), device_op.MakeArgument(a_m_k_device_buf.GetDeviceBuffer(),
static_cast<CDataType*>(c0_n_device_buf.GetDeviceBuffer()), b_k_n_device_buf.GetDeviceBuffer(),
M, std::array<const void*, 1>{d_m_n_device_buf.GetDeviceBuffer()},
N, e_m_n_device_buf.GetDeviceBuffer(),
K, M,
StrideA, N,
StrideB, K,
StrideC, StrideA,
a_element_op, StrideB,
b_element_op, std::array<ck::index_t, 1>{0},
c_element_op); StrideE,
a_element_op,
if(!gemm.IsSupportedArgument(argument)) b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{ {
throw std::runtime_error( throw std::runtime_error("wrong! this device_op instance does not support this problem");
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
} }
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(CDataType) * N; sizeof(EDataType) * M * N + sizeof(EDataType) * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...@@ -220,19 +241,37 @@ int main(int argc, char* argv[]) ...@@ -220,19 +241,37 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl; << std::endl;
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
if(do_verification) if(do_verification)
{ {
e_m_n_device_buf.FromDevice(e_m_n_device_result.mData.data());
Tensor<AccDataType> c_m_n(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
AccDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{}; auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument( auto ref_argument =
a_m_k, b_k_n, c_m_n_host_result, c0_n, a_element_op, b_element_op, c_element_op); ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1; for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
}
}
return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ? 0 : 1;
} }
return 0; return 0;
......
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp)
# Instructions for ```example_gemm_xdl_bias_relu_add``` # Instructions for ```example_gemm_add_add_fastgelu_xdl_fp16```
## Run ```example_gemm_xdl_bias_relu_add``` ## Run ```example_gemm_add_add_fastgelu_xdl_fp16```
```bash ```bash
#arg1: verification (0=no, 1=yes) #arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value) #arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: run kernel # of times (>1) #arg3: time kernel (0=no, 1=yes)
#arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC #arg4 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, StrideE"
./bin/example_gemm_xdl_bias_relu_add 0 1 5 3840 4096 4096 4096 4096 4096 ./bin/example_gemm_add_add_fastgelu_xdl_fp16 1 1 1
``` ```
Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
``` ```
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} d0_m_n: dim 2, lengths {3840, 4096}, strides {0, 1}
c0_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} d1_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
c1_m_n: dim 2, lengths {3840, 4096}, strides {1, 0} e_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
arg.a_grid_desc_k0_m_k1_{512, 3840, 8}
arg.b_grid_desc_k0_n_k1_{512, 4096, 8}
arg.c_grid_desc_m_n_{ 3840, 4096}
arg.c0_grid_desc_m_n_{ 3840, 4096}
arg.c1_grid_desc_m_n_{ 3840, 4096}
launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1}
Warm up Warm up 1 time
Start running 5 times... Start running 10 times...
Perf: 1.27583 ms, 100.992 TFlops, 73.9688 GB/s Perf: 1.26914 ms, 101.525 TFlops, 100.804 GB/s, DeviceGemmMultipleD_Xdl_CShuffle<256, 256, 128, 32, 8, 8>
``` ```
...@@ -3,84 +3,60 @@ ...@@ -3,84 +3,60 @@
#include <initializer_list> #include <initializer_list>
#include <cstdlib> #include <cstdlib>
#include <stdlib.h> #include <stdlib.h>
#include <half.hpp>
#include "check_err.hpp" #include "check_err.hpp"
#include "config.hpp" #include "config.hpp"
#include "print.hpp"
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "host_tensor_generator.hpp" #include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "device_gemm_xdl_c_shuffle_bias_activation_add.hpp" #include "reference_gemm.hpp"
#include "reference_gemm_bias_activation_add.hpp" #include "gemm_specialization.hpp"
#include "device_gemm_multiple_d_xdl_cshuffle.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using ADataType = ck::half_t; using F16 = ck::half_t;
using BDataType = ck::half_t; using F32 = float;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough; using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
using CElementOp = ck::tensor_operation::element_wise::AddReluAdd;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F16;
using D1DataType = F16;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using D0Layout = Row;
using D1Layout = Row;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = AddAddFastGelu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off // clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle
ADataType, // ADataType //######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
BDataType, // BDataType //######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
CDataType, // CDataType //######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
AccDataType, // AccDataType //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ALayout, // ALayout < ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
BLayout, // BLayout
CLayout, // CLayout
AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation
CElementOp, // CElementwiseOperation
256, // BlockSize
256, // MPerBlock
128, // NPerBlock
4, // K0PerBlock
8, // K1
32, // MPerXDL
32, // NPerXDL
4, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on // clang-format on
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemmBiasActivationAdd<ADataType,
BDataType,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = true; bool do_verification = true;
...@@ -94,16 +70,21 @@ int main(int argc, char* argv[]) ...@@ -94,16 +70,21 @@ int main(int argc, char* argv[])
ck::index_t StrideA = 4096; ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096; ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096; ck::index_t StrideD0 = 0;
ck::index_t StrideC1 = 4096; ck::index_t StrideD1 = 4096;
ck::index_t StrideE = 4096;
if(argc == 4) if(argc == 1)
{
// use default case
}
else if(argc == 4)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 11) else if(argc == 12)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
...@@ -115,15 +96,17 @@ int main(int argc, char* argv[]) ...@@ -115,15 +96,17 @@ int main(int argc, char* argv[])
StrideA = std::stoi(argv[7]); StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]); StrideB = std::stoi(argv[8]);
StrideC = std::stoi(argv[9]); StrideD0 = std::stoi(argv[9]);
StrideC1 = std::stoi(argv[10]); StrideD1 = std::stoi(argv[10]);
StrideE = std::stoi(argv[11]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, StrideC1\n"); printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, "
"StrideE\n");
exit(0); exit(0);
} }
...@@ -143,21 +126,16 @@ int main(int argc, char* argv[]) ...@@ -143,21 +126,16 @@ int main(int argc, char* argv[])
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
// c0_n[n] Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<CDataType> c0_n(HostTensorDescriptor(
std::vector<std::size_t>({static_cast<std::size_t>(N)}), std::vector<std::size_t>({1})));
// c1_m_n[m ,n]
Tensor<CDataType> c1_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
std::cout << "c0_n: " << c0_n.mDesc << std::endl; std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
std::cout << "c1_m_n: " << c1_m_n.mDesc << std::endl; std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
...@@ -165,92 +143,102 @@ int main(int argc, char* argv[]) ...@@ -165,92 +143,102 @@ int main(int argc, char* argv[])
case 1: case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}); a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}); b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
c0_n.GenerateTensorValue(GeneratorTensor_2<CDataType>{-5, 5}); d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
c1_m_n.GenerateTensorValue(GeneratorTensor_2<CDataType>{-5, 5}); d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-5, 5});
break; break;
default: default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
c0_n.GenerateTensorValue(GeneratorTensor_3<CDataType>{0.0, 1.0}); d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
c1_m_n.GenerateTensorValue(GeneratorTensor_3<CDataType>{0.0, 1.0}); d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
} }
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace());
DeviceMem c0_n_device_buf(sizeof(CDataType) * c0_n.mDesc.GetElementSpace()); DeviceMem d1_m_n_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpace());
DeviceMem c1_m_n_device_buf(sizeof(CDataType) * c1_m_n.mDesc.GetElementSpace()); DeviceMem e_m_n_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
a_m_k_device_buf.ToDevice(a_m_k.mData.data()); a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data()); b_k_n_device_buf.ToDevice(b_k_n.mData.data());
c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data()); d0_m_n_device_buf.ToDevice(d0_m_n.mData.data());
c0_n_device_buf.ToDevice(c0_n.mData.data()); d1_m_n_device_buf.ToDevice(d1_m_n.mData.data());
c1_m_n_device_buf.ToDevice(c1_m_n.mData.data());
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{}; auto cde_element_op = CDEElementOp{};
// do GEMM // do GEMM
auto gemm = DeviceGemmInstance{}; auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto invoker = gemm.MakeInvoker(); auto argument =
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()), device_op.MakeArgument(a_m_k_device_buf.GetDeviceBuffer(),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()), b_k_n_device_buf.GetDeviceBuffer(),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()), std::array<const void*, 2>{d0_m_n_device_buf.GetDeviceBuffer(),
static_cast<CDataType*>(c0_n_device_buf.GetDeviceBuffer()), d1_m_n_device_buf.GetDeviceBuffer()},
static_cast<CDataType*>(c1_m_n_device_buf.GetDeviceBuffer()), e_m_n_device_buf.GetDeviceBuffer(),
M, M,
N, N,
K, K,
StrideA, StrideA,
StrideB, StrideB,
StrideC, std::array<ck::index_t, 2>{StrideD0, StrideD1},
StrideC1, StrideE,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op); cde_element_op);
if(!gemm.IsSupportedArgument(argument)) if(!device_op.IsSupportedArgument(argument))
{ {
throw std::runtime_error( throw std::runtime_error("wrong! this device_op instance does not support this problem");
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
} }
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(CDataType) * N + sizeof(D0DataType) * N + sizeof(D1DataType) * M * N +
sizeof(CDataType) * M * N; sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl; << device_op.GetTypeString() << std::endl;
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
if(do_verification) if(do_verification)
{ {
Tensor<AccDataType> c_m_n(HostTensorDescriptor(
std::vector<std::size_t>{static_cast<std::size_t>(M), static_cast<std::size_t>(N)}));
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
AccDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{}; auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_m_k, auto ref_argument =
b_k_n, ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
c_m_n_host_result,
c0_n,
c1_m_n,
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1; for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n));
}
}
e_m_n_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ? 0 : 1;
} }
return 0; return 0;
......
add_example_executable(example_gemm_xdl_bias_relu_add gemm_xdl_bias_relu_add.cpp)
...@@ -224,10 +224,10 @@ int main(int argc, char* argv[]) ...@@ -224,10 +224,10 @@ int main(int argc, char* argv[])
{ {
case 0: break; case 0: break;
case 1: case 1:
input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}); input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 2});
weights.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5}); weights.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-2, 2});
bias.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5}); bias.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
residual.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5}); residual.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
break; break;
default: default:
input.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0}); input.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
......
...@@ -33,11 +33,11 @@ constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2; ...@@ -33,11 +33,11 @@ constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2;
constexpr bool PropagateNan = true; constexpr bool PropagateNan = true;
constexpr bool OutputIndex = false; constexpr bool OutputIndex = false;
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType; using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
using InElementwiseOperation = using InElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation; typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation = using AccElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation; typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
using DeviceReduceInstance = DeviceReduceMultiBlock<InDataType, using DeviceReduceInstance = DeviceReduceMultiBlock<InDataType,
AccDataType, AccDataType,
...@@ -247,6 +247,13 @@ int main(int argc, char* argv[]) ...@@ -247,6 +247,13 @@ int main(int argc, char* argv[])
DeviceMem out_index_dev(indicesSizeInBytes); DeviceMem out_index_dev(indicesSizeInBytes);
InElementwiseOperation in_elementwise_op;
AccElementwiseOperation acc_elementwise_op;
std::tie(in_elementwise_op, acc_elementwise_op) =
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
static_cast<int32_t>(reduce_total_length));
if(args.do_verification) if(args.do_verification)
{ {
ReductionHost<InDataType, ReductionHost<InDataType,
...@@ -261,8 +268,13 @@ int main(int argc, char* argv[]) ...@@ -261,8 +268,13 @@ int main(int argc, char* argv[])
OutputIndex> OutputIndex>
hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims); hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims);
hostReduce.Run( hostReduce.Run(alpha,
alpha, in.mData.data(), beta, out_ref.mData.data(), out_indices_ref.mData.data()); in.mData.data(),
beta,
out_ref.mData.data(),
out_indices_ref.mData.data(),
in_elementwise_op,
acc_elementwise_op);
}; };
std::vector<ck::index_t> i_inLengths; std::vector<ck::index_t> i_inLengths;
...@@ -277,20 +289,19 @@ int main(int argc, char* argv[]) ...@@ -277,20 +289,19 @@ int main(int argc, char* argv[])
auto reduce = DeviceReduceInstance{}; auto reduce = DeviceReduceInstance{};
auto argument_ptr = reduce.MakeArgumentPointer( auto argument_ptr = reduce.MakeArgumentPointer(i_inLengths,
i_inLengths, i_inStrides,
i_inStrides, i_outLengths,
i_outLengths, i_outStrides,
i_outStrides, reduceDims,
reduceDims, alpha,
alpha, beta,
beta, in_dev.GetDeviceBuffer(),
in_dev.GetDeviceBuffer(), nullptr,
nullptr, out_dev.GetDeviceBuffer(),
out_dev.GetDeviceBuffer(), out_index_dev.GetDeviceBuffer(),
out_index_dev.GetDeviceBuffer(), in_elementwise_op,
InElementwiseOperation{static_cast<int32_t>(reduce_total_length)}, acc_elementwise_op);
AccElementwiseOperation{static_cast<int32_t>(reduce_total_length)});
if(!reduce.IsSupportedArgument(argument_ptr.get())) if(!reduce.IsSupportedArgument(argument_ptr.get()))
{ {
......
...@@ -31,13 +31,13 @@ constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2; ...@@ -31,13 +31,13 @@ constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2;
constexpr bool PropagateNan = true; constexpr bool PropagateNan = true;
constexpr bool OutputIndex = false; constexpr bool OutputIndex = false;
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType; using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
using InElementwiseOperation = using InElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation; typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation = using AccElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation; typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<AccDataType, AccDataType>; using PassThroughOp = tensor_operation::element_wise::PassThrough;
using DeviceReduceInstance_1 = DeviceReduceMultiBlock<InOutDataType, using DeviceReduceInstance_1 = DeviceReduceMultiBlock<InOutDataType,
AccDataType, AccDataType,
...@@ -184,6 +184,13 @@ int main(int argc, char* argv[]) ...@@ -184,6 +184,13 @@ int main(int argc, char* argv[])
if(beta != 0.0f) if(beta != 0.0f)
out_dev.ToDevice(out.mData.data()); out_dev.ToDevice(out.mData.data());
InElementwiseOperation in_elementwise_op;
AccElementwiseOperation acc_elementwise_op;
std::tie(in_elementwise_op, acc_elementwise_op) =
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
static_cast<int32_t>(reduce_total_length));
if(do_verify) if(do_verify)
{ {
ReductionHost<InOutDataType, ReductionHost<InOutDataType,
...@@ -198,7 +205,13 @@ int main(int argc, char* argv[]) ...@@ -198,7 +205,13 @@ int main(int argc, char* argv[])
OutputIndex> OutputIndex>
hostReduce(in_1.mDesc, out_ref.mDesc, invariantDims, reduceDims); hostReduce(in_1.mDesc, out_ref.mDesc, invariantDims, reduceDims);
hostReduce.Run(alpha, in_1.mData.data(), beta, out_ref.mData.data(), nullptr); hostReduce.Run(alpha,
in_1.mData.data(),
beta,
out_ref.mData.data(),
nullptr,
in_elementwise_op,
acc_elementwise_op);
}; };
std::vector<ck::index_t> i_inLengths_1; std::vector<ck::index_t> i_inLengths_1;
...@@ -217,20 +230,19 @@ int main(int argc, char* argv[]) ...@@ -217,20 +230,19 @@ int main(int argc, char* argv[])
auto reduce_1 = DeviceReduceInstance_1{}; auto reduce_1 = DeviceReduceInstance_1{};
auto argument_ptr_1 = reduce_1.MakeArgumentPointer( auto argument_ptr_1 = reduce_1.MakeArgumentPointer(i_inLengths_1,
i_inLengths_1, i_inStrides_1,
i_inStrides_1, i_inLengths_2,
i_inLengths_2, i_inStrides_2,
i_inStrides_2, reduceDims_1,
reduceDims_1, 1.0f,
1.0f, 0.0f,
0.0f, in_1_dev.GetDeviceBuffer(),
in_1_dev.GetDeviceBuffer(), nullptr,
nullptr, in_2_dev.GetDeviceBuffer(),
in_2_dev.GetDeviceBuffer(), nullptr,
nullptr, in_elementwise_op,
InElementwiseOperation{static_cast<int32_t>(reduce_total_length)}, PassThroughOp{});
PassThroughOp{});
if(!reduce_1.IsSupportedArgument(argument_ptr_1.get())) if(!reduce_1.IsSupportedArgument(argument_ptr_1.get()))
{ {
...@@ -243,20 +255,19 @@ int main(int argc, char* argv[]) ...@@ -243,20 +255,19 @@ int main(int argc, char* argv[])
auto reduce_2 = DeviceReduceInstance_2{}; auto reduce_2 = DeviceReduceInstance_2{};
auto argument_ptr_2 = reduce_2.MakeArgumentPointer( auto argument_ptr_2 = reduce_2.MakeArgumentPointer(i_inLengths_2,
i_inLengths_2, i_inStrides_2,
i_inStrides_2, i_outLengths,
i_outLengths, i_outStrides,
i_outStrides, reduceDims_2,
reduceDims_2, alpha,
alpha, beta,
beta, in_2_dev.GetDeviceBuffer(),
in_2_dev.GetDeviceBuffer(), nullptr,
nullptr, out_dev.GetDeviceBuffer(),
out_dev.GetDeviceBuffer(), nullptr,
nullptr, PassThroughOp{},
PassThroughOp{}, acc_elementwise_op);
AccElementwiseOperation{static_cast<int32_t>(reduce_total_length)});
if(!reduce_2.IsSupportedArgument(argument_ptr_2.get())) if(!reduce_2.IsSupportedArgument(argument_ptr_2.get()))
{ {
......
...@@ -31,16 +31,15 @@ static void pool_host_verify(const Tensor<InDataType>& in, ...@@ -31,16 +31,15 @@ static void pool_host_verify(const Tensor<InDataType>& in,
const std::array<ck::index_t, 2>& in_left_pads, const std::array<ck::index_t, 2>& in_left_pads,
const std::array<ck::index_t, 2>& /*in_right_pads*/) const std::array<ck::index_t, 2>& /*in_right_pads*/)
{ {
const int32_t divider = window_spatial_lengths[0] * window_spatial_lengths[1]; const int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1];
using ReduceOperation = typename ck::reduce_binary_operator<AccDataType, ReduceOpId>::opType; using ReduceOperation = typename ck::reduce_binary_operator<ReduceOpId>::opType;
using InElementwiseOperation = typename ck::
reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation = typename ck::
reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation;
const InElementwiseOperation in_elementwise_op(divider); auto elementwise_ops =
const AccElementwiseOperation acc_elementwise_op(divider); ck::reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(reduceLength);
auto in_elementwise_op = std::get<0>(elementwise_ops);
auto acc_elementwise_op = std::get<1>(elementwise_ops);
if constexpr(!OutputIndex) if constexpr(!OutputIndex)
{ {
...@@ -48,7 +47,7 @@ static void pool_host_verify(const Tensor<InDataType>& in, ...@@ -48,7 +47,7 @@ static void pool_host_verify(const Tensor<InDataType>& in,
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>; ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOperation::GetIdentityValue(); auto accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y) for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
{ {
...@@ -86,7 +85,7 @@ static void pool_host_verify(const Tensor<InDataType>& in, ...@@ -86,7 +85,7 @@ static void pool_host_verify(const Tensor<InDataType>& in,
AccDataType, AccDataType,
IndexDataType>; IndexDataType>;
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOperation::GetIdentityValue(); auto accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
IndexDataType accuIndex = 0; IndexDataType accuIndex = 0;
for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y) for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "reference_gemm.hpp" #include "reference_gemm.hpp"
#include "gemm_specialization.hpp" #include "gemm_specialization.hpp"
#include "element_wise_reduce_operation.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -41,9 +40,8 @@ using CLayout = ck::tensor_layout::gemm::RowMajor; ...@@ -41,9 +40,8 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough; using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using DsReduceOp = ck::Tuple<ck::reduce::Max<ReduceAccDataType>>; using DsReduceOp = ck::Tuple<ck::reduce::Max>;
using DsElementOp = ck::Tuple< using DsElementOp = ck::Tuple<ck::tensor_operation::element_wise::PassThrough>;
ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, false>>;
using DGlobalMemOp = using DGlobalMemOp =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicMax>; ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicMax>;
...@@ -236,10 +234,14 @@ int main(int argc, char* argv[]) ...@@ -236,10 +234,14 @@ int main(int argc, char* argv[])
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
ReduceAccDataType d_acc = d_reduce_op.GetIdentityValue(); ReduceAccDataType d_acc = d_reduce_op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
d_reduce_op(d_acc, c_m_n_host_result(m, n)); {
ReduceAccDataType curr_val =
ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
d_reduce_op(d_acc, curr_val);
};
d_m_host_result(m) = d_acc; d_m_host_result(m) = d_acc;
} }
......
...@@ -41,18 +41,15 @@ using CLayout = ck::tensor_layout::gemm::RowMajor; ...@@ -41,18 +41,15 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough; using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using D0ReduceOp = ck::reduce::Add<ReduceAccDataType>; using D0ReduceOp = ck::reduce::Add;
using D1ReduceOp = ck::reduce::Add<ReduceAccDataType>; using D1ReduceOp = ck::reduce::Add;
using DxsReduceOp = ck::Tuple<D0ReduceOp, D1ReduceOp>; using DxsReduceOp = ck::Tuple<D0ReduceOp, D1ReduceOp>;
using UnaryIdenticElementOp = using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, false>; using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
using UnaryDivElementOp = using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, true>; using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using UnarySquareElementOp = using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
ck::tensor_operation::element_wise::UnarySquare<ReduceAccDataType, ReduceAccDataType, false>;
using DxsInElementOp = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOp = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
using DGlobalMemOp = using DGlobalMemOp =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd, ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
...@@ -67,7 +64,7 @@ using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_ ...@@ -67,7 +64,7 @@ using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| //######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; < Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
...@@ -204,8 +201,8 @@ int main(int argc, char* argv[]) ...@@ -204,8 +201,8 @@ int main(int argc, char* argv[])
auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()), auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer())); static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()));
auto dxs_in_element_op = DxsInElementOp{}; auto dxs_in_element_op = DxsInElementOps{};
auto dxs_out_element_op = DxsOutElementOp{M, M}; auto dxs_out_element_op = DxsOutElementOps{N, N};
// do GEMM // do GEMM
auto gemm = DeviceGemmReduceInstance{}; auto gemm = DeviceGemmReduceInstance{};
...@@ -261,14 +258,14 @@ int main(int argc, char* argv[]) ...@@ -261,14 +258,14 @@ int main(int argc, char* argv[])
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
float d0_acc = d0_reduce_op.GetIdentityValue(); auto d0_acc = d0_reduce_op.GetIdentityValue<ReduceAccDataType>();
float d1_acc = d1_reduce_op.GetIdentityValue(); auto d1_acc = d1_reduce_op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
float c_val = ck::type_convert<float>(c_m_n_host_result(m, n)); auto c_val = ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
float d0_val = 0; ReduceAccDataType d0_val;
float d1_val = 0; ReduceAccDataType d1_val;
dxs_in_element_op(ck::Number<0>{})(d0_val, c_val); dxs_in_element_op(ck::Number<0>{})(d0_val, c_val);
dxs_in_element_op(ck::Number<1>{})(d1_val, c_val); dxs_in_element_op(ck::Number<1>{})(d1_val, c_val);
......
...@@ -39,16 +39,14 @@ using CLayout = ck::tensor_layout::gemm::RowMajor; ...@@ -39,16 +39,14 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough; using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using D0ReduceOp = ck::reduce::Add<ReduceAccDataType>; using D0ReduceOp = ck::reduce::Add;
using D1ReduceOp = ck::reduce::Add<ReduceAccDataType>; using D1ReduceOp = ck::reduce::Add;
using DxsReduceOp = ck::Tuple<D0ReduceOp, D1ReduceOp>; using DxsReduceOp = ck::Tuple<D0ReduceOp, D1ReduceOp>;
using UnaryIdenticElementOp = using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, false>; using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
using UnarySquareElementOp = using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
ck::tensor_operation::element_wise::UnarySquare<ReduceAccDataType, ReduceAccDataType, false>; using DxsOutElementOps = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;
using DxsInElementOp = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOp = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;
using DGlobalMemOp = using DGlobalMemOp =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd, ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
...@@ -63,7 +61,7 @@ using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatc ...@@ -63,7 +61,7 @@ using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatc
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| //######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; < Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on // clang-format on
using ReferenceBatchedGemmInstance = ck::tensor_operation::host:: using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
...@@ -206,8 +204,8 @@ int main(int argc, char* argv[]) ...@@ -206,8 +204,8 @@ int main(int argc, char* argv[])
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
DxsInElementOp{}, DxsInElementOps{},
DxsOutElementOp{}, DxsOutElementOps{},
BatchCount); BatchCount);
if(!batched_gemm.IsSupportedArgument(argument)) if(!batched_gemm.IsSupportedArgument(argument))
...@@ -259,14 +257,15 @@ int main(int argc, char* argv[]) ...@@ -259,14 +257,15 @@ int main(int argc, char* argv[])
{ {
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
float d0_acc = d0_reduce_op.GetIdentityValue(); auto d0_acc = d0_reduce_op.GetIdentityValue<ReduceAccDataType>();
float d1_acc = d1_reduce_op.GetIdentityValue(); auto d1_acc = d1_reduce_op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
float c_val = ck::type_convert<float>(c_g_m_n_host_result(batch, m, n)); auto c_val =
float d0_val = 0; ck::type_convert<ReduceAccDataType>(c_g_m_n_host_result(batch, m, n));
float d1_val = 0; ReduceAccDataType d0_val;
ReduceAccDataType d1_val;
UnaryIdenticElementOp{}(d0_val, c_val); UnaryIdenticElementOp{}(d0_val, c_val);
UnarySquareElementOp{}(d1_val, c_val); UnarySquareElementOp{}(d1_val, c_val);
......
...@@ -42,8 +42,7 @@ using ABDataType = F16; ...@@ -42,8 +42,7 @@ using ABDataType = F16;
using CDataType = F16; using CDataType = F16;
using EltwiseComputeDataType = F32; using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::binary_element_wise:: using Add = ck::tensor_operation::element_wise::Add;
Add<EltwiseComputeDataType, EltwiseComputeDataType, EltwiseComputeDataType>;
using DeviceElementwiseAddInstance = using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType, ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
......
...@@ -17,8 +17,7 @@ using ABDataType = F16; ...@@ -17,8 +17,7 @@ using ABDataType = F16;
using CDataType = F16; using CDataType = F16;
using EltwiseComputeDataType = F32; using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::binary_element_wise:: using Add = ck::tensor_operation::element_wise::Add;
Add<EltwiseComputeDataType, EltwiseComputeDataType, EltwiseComputeDataType>;
using DeviceElementwiseAddInstance = using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType, ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
......
...@@ -42,8 +42,7 @@ using ABDataType = F16; ...@@ -42,8 +42,7 @@ using ABDataType = F16;
using CDataType = F16; using CDataType = F16;
using EltwiseComputeDataType = F32; using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::binary_element_wise:: using Add = ck::tensor_operation::element_wise::Add;
Add<EltwiseComputeDataType, EltwiseComputeDataType, EltwiseComputeDataType>;
using DeviceElementwiseAddInstance = using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType, ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
......
...@@ -42,8 +42,7 @@ using ABDataType = F16; ...@@ -42,8 +42,7 @@ using ABDataType = F16;
using CDataType = F16; using CDataType = F16;
using EltwiseComputeDataType = F32; using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::binary_element_wise:: using Add = ck::tensor_operation::element_wise::Add;
Add<EltwiseComputeDataType, EltwiseComputeDataType, EltwiseComputeDataType>;
using DeviceElementwiseAddInstance = using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType, ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
......
add_example_executable(example_convnd_bwd_weight_xdl convnd_bwd_weight_xdl.cpp) add_example_executable(example_convnd_bwd_weight_xdl convnd_bwd_weight_xdl.cpp)
target_link_libraries(example_convnd_bwd_weight_xdl PRIVATE conv_util) add_example_executable(example_convnd_bwd_weight_xdl_bf16_splitk convnd_bwd_weight_xdl_bf16_splitk.cpp)
\ No newline at end of file target_link_libraries(example_convnd_bwd_weight_xdl PRIVATE conv_util)
target_link_libraries(example_convnd_bwd_weight_xdl_bf16_splitk PRIVATE conv_util)
\ No newline at end of file
...@@ -297,52 +297,15 @@ int main(int argc, char* argv[]) ...@@ -297,52 +297,15 @@ int main(int argc, char* argv[])
split_k); split_k);
// alloc work space // alloc work space
size_t bwd_weight_workspace_size = conv->GetWorkSpaceSize(argument.get()); float ave_time = 0.f;
float ave_time = 0.f; if(!conv->IsSupportedArgument(argument.get()))
if(std::is_same<InDataType, ck::bhalf_t>::value && split_k > 1)
{ {
DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size); std::cout << "wrong! device_conv with the specified compilation parameters does "
wei_work_space_device_buf.SetZero(); "not support this Conv problem"
argument = conv->MakeArgumentPointer( << std::endl;
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), return 1;
static_cast<AccDataType*>(wei_work_space_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
params.N_,
params.K_,
params.C_,
params.input_spatial_lengths_,
params.filter_spatial_lengths_,
output_spatial_lengths,
params.conv_filter_strides_,
params.conv_filter_dilations_,
params.input_left_pads_,
params.input_right_pads_,
InElementOp{},
WeiElementOp{},
OutElementOp{},
split_k);
if(!conv->IsSupportedArgument(argument.get()))
{
std::cout << "wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
<< std::endl;
return 1;
}
ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
}
else
{
if(!conv->IsSupportedArgument(argument.get()))
{
std::cout << "wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
<< std::endl;
return 1;
}
ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
} }
ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = ck::utils::conv::get_flops( std::size_t flop = ck::utils::conv::get_flops(
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment