Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
926008bc
Unverified
Commit
926008bc
authored
Jul 21, 2022
by
Chao Liu
Committed by
GitHub
Jul 21, 2022
Browse files
Merge branch 'develop' into batched_gemm_multiD
parents
e9652cbc
d8415a96
Changes
30
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1352 additions
and
663 deletions
+1352
-663
Jenkinsfile
Jenkinsfile
+183
-82
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
+51
-41
example/28_grouped_gemm_bias/CMakeLists.txt
example/28_grouped_gemm_bias/CMakeLists.txt
+1
-0
example/28_grouped_gemm_bias/grouped_gemm_bias_xdl_fp16.cpp
example/28_grouped_gemm_bias/grouped_gemm_bias_xdl_fp16.cpp
+278
-0
example/CMakeLists.txt
example/CMakeLists.txt
+2
-1
include/ck/tensor_operation/gpu/device/device_gemm.hpp
include/ck/tensor_operation/gpu/device/device_gemm.hpp
+0
-29
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
...de/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
+69
-0
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
+430
-258
library/include/ck/library/host_tensor/host_tensor.hpp
library/include/ck/library/host_tensor/host_tensor.hpp
+0
-49
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp
...ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp
+134
-0
library/include/ck/library/utility/check_err.hpp
library/include/ck/library/utility/check_err.hpp
+17
-24
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp
...device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp
+32
-13
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp
...device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp
+32
-14
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
...device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
+32
-22
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
...device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
+30
-34
profiler/include/profile_batched_gemm_reduce_impl.hpp
profiler/include/profile_batched_gemm_reduce_impl.hpp
+10
-7
profiler/include/profile_conv_bwd_weight_impl.hpp
profiler/include/profile_conv_bwd_weight_impl.hpp
+3
-3
profiler/include/profile_convnd_bwd_data_impl.hpp
profiler/include/profile_convnd_bwd_data_impl.hpp
+2
-1
profiler/include/profile_convnd_bwd_weight_impl.hpp
profiler/include/profile_convnd_bwd_weight_impl.hpp
+2
-6
profiler/include/profile_grouped_gemm_impl.hpp
profiler/include/profile_grouped_gemm_impl.hpp
+44
-79
No files found.
Jenkinsfile
View file @
926008bc
...
@@ -11,6 +11,12 @@ def show_node_info() {
...
@@ -11,6 +11,12 @@ def show_node_info() {
"""
"""
}
}
def
runShell
(
String
command
){
def
responseCode
=
sh
returnStatus:
true
,
script:
"${command} &> tmp.txt"
def
output
=
readFile
(
file:
"tmp.txt"
)
return
(
output
!=
""
)
}
def
cmake_build
(
Map
conf
=[:]){
def
cmake_build
(
Map
conf
=[:]){
def
compiler
=
conf
.
get
(
"compiler"
,
"/opt/rocm/bin/hipcc"
)
def
compiler
=
conf
.
get
(
"compiler"
,
"/opt/rocm/bin/hipcc"
)
...
@@ -60,7 +66,7 @@ def cmake_build(Map conf=[:]){
...
@@ -60,7 +66,7 @@ def cmake_build(Map conf=[:]){
"""
"""
def
setup_cmd
=
conf
.
get
(
"setup_cmd"
,
"${cmake_envs} cmake ${setup_args} .. "
)
def
setup_cmd
=
conf
.
get
(
"setup_cmd"
,
"${cmake_envs} cmake ${setup_args} .. "
)
// reduce parallelism when compiling, clang uses too much memory
// reduce parallelism when compiling, clang uses too much memory
def
build_cmd
=
conf
.
get
(
"build_cmd"
,
"${build_envs} dumb-init make -j\$(( \$(nproc) /
1
)) ${config_targets}"
)
def
build_cmd
=
conf
.
get
(
"build_cmd"
,
"${build_envs} dumb-init make -j\$(( \$(nproc) /
2
)) ${config_targets}"
)
def
execute_cmd
=
conf
.
get
(
"execute_cmd"
,
""
)
def
execute_cmd
=
conf
.
get
(
"execute_cmd"
,
""
)
def
cmd
=
conf
.
get
(
"cmd"
,
"""
def
cmd
=
conf
.
get
(
"cmd"
,
"""
...
@@ -113,7 +119,14 @@ def buildHipClangJob(Map conf=[:]){
...
@@ -113,7 +119,14 @@ def buildHipClangJob(Map conf=[:]){
retimage
=
docker
.
build
(
"${image}"
,
dockerArgs
+
'.'
)
retimage
=
docker
.
build
(
"${image}"
,
dockerArgs
+
'.'
)
withDockerContainer
(
image:
image
,
args:
dockerOpts
)
{
withDockerContainer
(
image:
image
,
args:
dockerOpts
)
{
timeout
(
time:
5
,
unit:
'MINUTES'
){
timeout
(
time:
5
,
unit:
'MINUTES'
){
sh
'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
sh
'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo | tee clinfo.log'
if
(
runShell
(
'grep -n "Number of devices:.*. 0" clinfo.log'
)
){
echo
"GPU not found"
throw
e
}
else
{
echo
"GPU is OK"
}
}
}
}
}
}
}
...
@@ -125,7 +138,14 @@ def buildHipClangJob(Map conf=[:]){
...
@@ -125,7 +138,14 @@ def buildHipClangJob(Map conf=[:]){
retimage
=
docker
.
build
(
"${image}"
,
dockerArgs
+
" --no-cache ."
)
retimage
=
docker
.
build
(
"${image}"
,
dockerArgs
+
" --no-cache ."
)
withDockerContainer
(
image:
image
,
args:
dockerOpts
)
{
withDockerContainer
(
image:
image
,
args:
dockerOpts
)
{
timeout
(
time:
5
,
unit:
'MINUTES'
){
timeout
(
time:
5
,
unit:
'MINUTES'
){
sh
'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
sh
'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo |tee clinfo.log'
if
(
runShell
(
'grep -n "Number of devices:.*. 0" clinfo.log'
)
){
echo
"GPU not found"
throw
e
}
else
{
echo
"GPU is OK"
}
}
}
}
}
}
}
...
@@ -133,7 +153,14 @@ def buildHipClangJob(Map conf=[:]){
...
@@ -133,7 +153,14 @@ def buildHipClangJob(Map conf=[:]){
withDockerContainer
(
image:
image
,
args:
dockerOpts
+
' -v=/var/jenkins/:/var/jenkins'
)
{
withDockerContainer
(
image:
image
,
args:
dockerOpts
+
' -v=/var/jenkins/:/var/jenkins'
)
{
timeout
(
time:
5
,
unit:
'HOURS'
)
timeout
(
time:
5
,
unit:
'HOURS'
)
{
{
sh
'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
sh
'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo | tee clinfo.log'
if
(
runShell
(
'grep -n "Number of devices:.*. 0" clinfo.log'
)
){
echo
"GPU not found"
throw
e
}
else
{
echo
"GPU is OK"
}
cmake_build
(
conf
)
cmake_build
(
conf
)
}
}
}
}
...
@@ -145,7 +172,6 @@ def reboot(){
...
@@ -145,7 +172,6 @@ def reboot(){
build
job:
'reboot-slaves'
,
propagate:
false
,
parameters:
[
string
(
name:
'server'
,
value:
"${env.NODE_NAME}"
),]
build
job:
'reboot-slaves'
,
propagate:
false
,
parameters:
[
string
(
name:
'server'
,
value:
"${env.NODE_NAME}"
),]
}
}
def
buildHipClangJobAndReboot
(
Map
conf
=[:]){
def
buildHipClangJobAndReboot
(
Map
conf
=[:]){
try
{
try
{
buildHipClangJob
(
conf
)
buildHipClangJob
(
conf
)
...
@@ -162,7 +188,6 @@ def buildHipClangJobAndReboot(Map conf=[:]){
...
@@ -162,7 +188,6 @@ def buildHipClangJobAndReboot(Map conf=[:]){
}
}
}
}
def
runCKProfiler
(
Map
conf
=[:]){
def
runCKProfiler
(
Map
conf
=[:]){
show_node_info
()
show_node_info
()
...
@@ -189,7 +214,6 @@ def runCKProfiler(Map conf=[:]){
...
@@ -189,7 +214,6 @@ def runCKProfiler(Map conf=[:]){
}
}
def
variant
=
env
.
STAGE_NAME
def
variant
=
env
.
STAGE_NAME
def
retimage
def
retimage
gitStatusWrapper
(
credentialsId:
"${status_wrapper_creds}"
,
gitHubContext:
"Jenkins - ${variant}"
,
account:
'ROCmSoftwarePlatform'
,
repo:
'composable_kernel'
)
{
gitStatusWrapper
(
credentialsId:
"${status_wrapper_creds}"
,
gitHubContext:
"Jenkins - ${variant}"
,
account:
'ROCmSoftwarePlatform'
,
repo:
'composable_kernel'
)
{
...
@@ -197,7 +221,14 @@ def runCKProfiler(Map conf=[:]){
...
@@ -197,7 +221,14 @@ def runCKProfiler(Map conf=[:]){
retimage
=
docker
.
build
(
"${image}"
,
dockerArgs
+
'.'
)
retimage
=
docker
.
build
(
"${image}"
,
dockerArgs
+
'.'
)
withDockerContainer
(
image:
image
,
args:
dockerOpts
)
{
withDockerContainer
(
image:
image
,
args:
dockerOpts
)
{
timeout
(
time:
5
,
unit:
'MINUTES'
){
timeout
(
time:
5
,
unit:
'MINUTES'
){
sh
'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
sh
'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo | tee clinfo.log'
if
(
runShell
(
'grep -n "Number of devices:.*. 0" clinfo.log'
)
){
echo
"GPU not found"
throw
e
}
else
{
echo
"GPU is OK"
}
}
}
}
}
}
}
...
@@ -209,89 +240,69 @@ def runCKProfiler(Map conf=[:]){
...
@@ -209,89 +240,69 @@ def runCKProfiler(Map conf=[:]){
retimage
=
docker
.
build
(
"${image}"
,
dockerArgs
+
" --no-cache ."
)
retimage
=
docker
.
build
(
"${image}"
,
dockerArgs
+
" --no-cache ."
)
withDockerContainer
(
image:
image
,
args:
dockerOpts
)
{
withDockerContainer
(
image:
image
,
args:
dockerOpts
)
{
timeout
(
time:
5
,
unit:
'MINUTES'
){
timeout
(
time:
5
,
unit:
'MINUTES'
){
sh
'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
sh
'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo | tee clinfo.log'
if
(
runShell
(
'grep -n "Number of devices:.*. 0" clinfo.log'
)
){
echo
"GPU not found"
throw
e
}
else
{
echo
"GPU is OK"
}
}
}
}
}
}
}
withDockerContainer
(
image:
image
,
args:
dockerOpts
+
' -v=/var/jenkins/:/var/jenkins'
)
{
withDockerContainer
(
image:
image
,
args:
dockerOpts
+
' -v=/var/jenkins/:/var/jenkins'
)
{
timeout
(
time:
5
,
unit:
'HOURS'
)
timeout
(
time:
24
,
unit:
'HOURS'
)
{
{
cmake_build
(
conf
)
cmake_build
(
conf
)
dir
(
"script"
){
dir
(
"script"
){
//run gemm performance tests
if
(
params
.
RUN_FULL_QA
){
def
gemm_log
=
"perf_gemm_${gpu_arch}.log"
def
qa_log
=
"qa_${gpu_arch}.log"
sh
"rm -f ${gemm_log}"
sh
"echo Branch name: ${env.BRANCH_NAME} > ${gemm_log}"
sh
"echo Node name: ${NODE_NAME} >> ${gemm_log}"
sh
"echo GPU_arch name: ${gpu_arch} >> ${gemm_log}"
sh
"rocminfo | grep 'Compute Unit:' >> ${gemm_log} "
sh
"hipcc --version | grep -e 'HIP version' >> ${gemm_log}"
if
(
params
.
USE_9110
){
if
(
params
.
USE_9110
){
sh
"echo Environment type: CI_9110 >> ${gemm_log
}"
sh
"./run_full_performance_tests.sh 1 QA_9110 ${gpu_arch} ${env.BRANCH_NAME} ${NODE_NAME
}"
}
}
else
{
else
{
sh
"echo Environment type: CI_release >> ${gemm_log}"
sh
"./run_full_performance_tests.sh 1 QA_release ${gpu_arch} ${env.BRANCH_NAME} ${NODE_NAME}"
}
}
sh
"/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${gemm_log}"
archiveArtifacts
"perf_gemm_${gpu_arch}.log"
sh
"./profile_gemm.sh gemm 0 0 0 1 0 5 | tee -a ${gemm_log}"
archiveArtifacts
"perf_resnet50_N256_${gpu_arch}.log"
sh
"./profile_gemm.sh gemm 1 0 0 1 0 5 | tee -a ${gemm_log}"
archiveArtifacts
"perf_resnet50_N4_${gpu_arch}.log"
sh
"./profile_gemm.sh gemm 2 0 0 1 0 5 | tee -a ${gemm_log}"
archiveArtifacts
"perf_bathced_gemm_${gpu_arch}.log"
sh
"./profile_gemm.sh gemm 3 0 0 1 0 5 | tee -a ${gemm_log}"
archiveArtifacts
"perf_grouped_gemm_${gpu_arch}.log"
sh
"./profile_gemm.sh gemm 0 1 0 1 0 5 | tee -a ${gemm_log}"
archiveArtifacts
"perf_fwd_conv_${gpu_arch}.log"
sh
"./profile_gemm.sh gemm 1 1 0 1 0 5 | tee -a ${gemm_log}"
archiveArtifacts
"perf_bwd_conv_${gpu_arch}.log"
sh
"./profile_gemm.sh gemm 2 1 0 1 0 5 | tee -a ${gemm_log}"
archiveArtifacts
"perf_fusion_${gpu_arch}.log"
sh
"./profile_gemm.sh gemm 3 1 0 1 0 5 | tee -a ${gemm_log}"
archiveArtifacts
"perf_reduction_${gpu_arch}.log"
sh
"./profile_gemm.sh gemm 0 2 0 1 0 5 | tee -a ${gemm_log}"
// stash perf files to master
sh
"./profile_gemm.sh gemm 1 2 0 1 0 5 | tee -a ${gemm_log}"
stash
name:
"perf_gemm_${gpu_arch}.log"
sh
"./profile_gemm.sh gemm 2 2 0 1 0 5 | tee -a ${gemm_log}"
stash
name:
"perf_resnet50_N256_${gpu_arch}.log"
sh
"./profile_gemm.sh gemm 3 2 0 1 0 5 | tee -a ${gemm_log}"
stash
name:
"perf_resnet50_N4_${gpu_arch}.log"
sh
"./profile_gemm.sh gemm 0 3 0 1 0 5 | tee -a ${gemm_log}"
stash
name:
"perf_bathced_gemm_${gpu_arch}.log"
sh
"./profile_gemm.sh gemm 1 3 0 1 0 5 | tee -a ${gemm_log}"
stash
name:
"perf_grouped_gemm_${gpu_arch}.log"
sh
"./profile_gemm.sh gemm 2 3 0 1 0 5 | tee -a ${gemm_log}"
stash
name:
"perf_fwd_conv_${gpu_arch}.log"
sh
"./profile_gemm.sh gemm 3 3 0 1 0 5 | tee -a ${gemm_log}"
stash
name:
"perf_bwd_conv_${gpu_arch}.log"
//results will be parsed, stored, and analyzed within the python script
stash
name:
"perf_fusion_${gpu_arch}.log"
//the script will return 0 if the performance criteria are met
stash
name:
"perf_reduction_${gpu_arch}.log"
//or return 1 if the criteria are not met
//we will process results on the master node
archiveArtifacts
"${gemm_log}"
sh
"python3 process_perf_data.py ${gemm_log} "
//run resnet50 test
def
resnet256_log
=
"perf_resnet50_N256_${gpu_arch}.log"
sh
"rm -f ${resnet256_log}"
sh
"echo Branch name: ${env.BRANCH_NAME} > ${resnet256_log}"
sh
"echo Node name: ${NODE_NAME} >> ${resnet256_log}"
sh
"echo GPU_arch name: ${gpu_arch} >> ${resnet256_log}"
sh
"rocminfo | grep 'Compute Unit:' >> ${resnet256_log} "
sh
"hipcc --version | grep -e 'HIP version' >> ${resnet256_log}"
if
(
params
.
USE_9110
){
sh
"echo Environment type: CI_9110 >> ${resnet256_log}"
}
}
else
{
else
{
sh
"echo Environment type: CI_release >> ${resnet256_log}"
}
sh
"/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${resnet256_log}"
//first run tests with N=256
sh
"./profile_resnet50.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 256 | tee -a ${resnet256_log}"
archiveArtifacts
"${resnet256_log}"
sh
"python3 process_perf_data.py ${resnet256_log} "
//then run with N=4
def
resnet4_log
=
"perf_resnet50_N4_${gpu_arch}.log"
sh
"rm -f ${resnet4_log}"
sh
"echo Branch name: ${env.BRANCH_NAME} > ${resnet4_log}"
sh
"echo Node name: ${NODE_NAME} >> ${resnet4_log}"
sh
"echo GPU_arch name: ${gpu_arch} >> ${resnet4_log}"
sh
"rocminfo | grep 'Compute Unit:' >> ${resnet4_log} "
sh
"hipcc --version | grep -e 'HIP version' >> ${resnet4_log}"
if
(
params
.
USE_9110
){
if
(
params
.
USE_9110
){
sh
"echo Environment type: CI_9110 >> ${resnet4_log
}"
sh
"./run_performance_tests.sh 0 CI_9110 ${gpu_arch} ${env.BRANCH_NAME} ${NODE_NAME
}"
}
}
else
{
else
{
sh
"echo Environment type: CI_release >> ${resnet4_log
}"
sh
"./run_performance_tests.sh 0 CI_release ${gpu_arch} ${env.BRANCH_NAME} ${NODE_NAME
}"
}
}
sh
"/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${resnet4_log}"
archiveArtifacts
"perf_gemm_${gpu_arch}.log"
sh
"./profile_resnet50.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 4 | tee -a ${resnet4_log}"
archiveArtifacts
"perf_resnet50_N256_${gpu_arch}.log"
archiveArtifacts
"${resnet4_log}"
archiveArtifacts
"perf_resnet50_N4_${gpu_arch}.log"
sh
"python3 process_perf_data.py ${resnet4_log} "
// stash perf files to master
stash
name:
"perf_gemm_${gpu_arch}.log"
stash
name:
"perf_resnet50_N256_${gpu_arch}.log"
stash
name:
"perf_resnet50_N4_${gpu_arch}.log"
//we will process the results on the master node
}
}
}
}
}
}
}
...
@@ -299,7 +310,6 @@ def runCKProfiler(Map conf=[:]){
...
@@ -299,7 +310,6 @@ def runCKProfiler(Map conf=[:]){
return
retimage
return
retimage
}
}
def
runPerfTest
(
Map
conf
=[:]){
def
runPerfTest
(
Map
conf
=[:]){
try
{
try
{
runCKProfiler
(
conf
)
runCKProfiler
(
conf
)
...
@@ -316,8 +326,76 @@ def runPerfTest(Map conf=[:]){
...
@@ -316,8 +326,76 @@ def runPerfTest(Map conf=[:]){
}
}
}
}
def
process_results
(
Map
conf
=[:]){
env
.
HSA_ENABLE_SDMA
=
0
checkout
scm
def
image
=
"composable_kernels"
def
prefixpath
=
"/opt/rocm"
def
gpu_arch
=
conf
.
get
(
"gpu_arch"
,
"gfx908"
)
// Jenkins is complaining about the render group
def
dockerOpts
=
"--cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
if
(
conf
.
get
(
"enforce_xnack_on"
,
false
))
{
dockerOpts
=
dockerOpts
+
" --env HSA_XNACK=1"
}
def
dockerArgs
=
"--build-arg PREFIX=${prefixpath} --build-arg GPU_ARCH='${gpu_arch}' --build-arg compiler_version='release' "
def
variant
=
env
.
STAGE_NAME
def
retimage
gitStatusWrapper
(
credentialsId:
"${status_wrapper_creds}"
,
gitHubContext:
"Jenkins - ${variant}"
,
account:
'ROCmSoftwarePlatform'
,
repo:
'composable_kernel'
)
{
try
{
retimage
=
docker
.
build
(
"${image}"
,
dockerArgs
+
'.'
)
}
catch
(
org
.
jenkinsci
.
plugins
.
workflow
.
steps
.
FlowInterruptedException
e
){
echo
"The job was cancelled or aborted"
throw
e
}
}
withDockerContainer
(
image:
image
,
args:
dockerOpts
+
' -v=/var/jenkins/:/var/jenkins'
)
{
timeout
(
time:
1
,
unit:
'HOURS'
){
try
{
dir
(
"script"
){
if
(
params
.
RUN_FULL_QA
){
// unstash perf files to master
unstash
"perf_gemm_${gpu_arch}.log"
unstash
"perf_resnet50_N256_${gpu_arch}.log"
unstash
"perf_resnet50_N4_${gpu_arch}.log"
unstash
"perf_bathced_gemm_${gpu_arch}.log"
unstash
"perf_grouped_gemm_${gpu_arch}.log"
unstash
"perf_fwd_conv_${gpu_arch}.log"
unstash
"perf_bwd_conv_${gpu_arch}.log"
unstash
"perf_fusion_${gpu_arch}.log"
unstash
"perf_reduction_${gpu_arch}.log"
sh
"./process_qa_data.sh ${gpu_arch}"
}
else
{
// unstash perf files to master
unstash
"perf_gemm_${gpu_arch}.log"
unstash
"perf_resnet50_N256_${gpu_arch}.log"
unstash
"perf_resnet50_N4_${gpu_arch}.log"
sh
"./process_perf_data.sh ${gpu_arch}"
}
}
}
catch
(
e
){
echo
"throwing error exception while processing performance test results"
echo
'Exception occurred: '
+
e
.
toString
()
throw
e
}
}
}
}
//launch develop branch daily at 23:00 in FULL_QA mode
CRON_SETTINGS
=
BRANCH_NAME
==
"develop"
?
'''0 23 * * * % RUN_FULL_QA=true;USE_9110=true'''
:
""
pipeline
{
pipeline
{
agent
none
agent
none
triggers
{
cron
(
CRON_SETTINGS
)
}
options
{
options
{
parallelsAlwaysFailFast
()
parallelsAlwaysFailFast
()
}
}
...
@@ -325,7 +403,11 @@ pipeline {
...
@@ -325,7 +403,11 @@ pipeline {
booleanParam
(
booleanParam
(
name:
"USE_9110"
,
name:
"USE_9110"
,
defaultValue:
true
,
defaultValue:
true
,
description:
""
)
description:
"Select compiler version: 9110 (default) or release"
)
booleanParam
(
name:
"RUN_FULL_QA"
,
defaultValue:
false
,
description:
"Select whether to run small set of performance tests (default) or full QA"
)
}
}
environment
{
environment
{
dbuser
=
"${dbuser}"
dbuser
=
"${dbuser}"
...
@@ -438,6 +520,25 @@ pipeline {
...
@@ -438,6 +520,25 @@ pipeline {
}
}
}
}
}
}
stage
(
"Process Performance Test Results"
)
{
parallel
{
stage
(
"Process results for gfx908"
){
agent
{
label
'mici'
}
steps
{
process_results
(
gpu_arch:
"gfx908"
)
}
}
stage
(
"Process results for gfx90a"
){
agent
{
label
'mici'
}
steps
{
process_results
(
gpu_arch:
"gfx90a"
)
}
}
}
}
/* enable after the cmake file supports packaging
/* enable after the cmake file supports packaging
stage("Packages") {
stage("Packages") {
when {
when {
...
...
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
View file @
926008bc
...
@@ -29,34 +29,39 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
...
@@ -29,34 +29,39 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
ck
::
half_t
;
using
ADataType
=
F16
;
using
BDataType
=
ck
::
half_t
;
using
BDataType
=
F16
;
using
CDataType
=
ck
::
half_t
;
using
AccDataType
=
F32
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
F16
;
using
DsDataType
=
ck
::
Tuple
<>
;
using
EDataType
=
F16
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ALayout
=
Row
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
BLayout
=
Col
;
using
C
Layout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
E
Layout
=
Row
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
C
DE
ElementOp
=
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// static constexpr auto GemmMNPadding =
// ck::tensor_operation::device::GemmSpecialization::MNPadding;
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedGemmXdl
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedGemmXdl
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num|
// clang-format off
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch|
//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| |
//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
,
1
>
;
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ALayout
,
BLayout
,
ELayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
BDataType
,
EDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
>
;
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
...
@@ -81,11 +86,11 @@ int main(int argc, char* argv[])
...
@@ -81,11 +86,11 @@ int main(int argc, char* argv[])
int
group_count
=
rand
()
%
16
+
1
;
int
group_count
=
rand
()
%
16
+
1
;
// GEMM shape
// GEMM shape
std
::
vector
<
ck
::
tensor_operation
::
device
::
Gemm
Shape
>
gemm_
shape
s
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
Gemm
Desc
>
gemm_
desc
s
;
std
::
vector
<
const
void
*>
p_a
,
p_b
;
std
::
vector
<
const
void
*>
p_a
,
p_b
;
std
::
vector
<
void
*>
p_c
;
std
::
vector
<
void
*>
p_c
;
gemm_
shape
s
.
reserve
(
group_count
);
gemm_
desc
s
.
reserve
(
group_count
);
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
{
...
@@ -93,7 +98,11 @@ int main(int argc, char* argv[])
...
@@ -93,7 +98,11 @@ int main(int argc, char* argv[])
int
N
=
128
+
128
*
i
;
int
N
=
128
+
128
*
i
;
int
K
=
64
+
64
*
i
;
int
K
=
64
+
64
*
i
;
gemm_shapes
.
push_back
({
M
,
N
,
K
,
K
,
K
,
N
});
int
stride_A
=
K
;
int
stride_B
=
K
;
int
stride_C
=
N
;
gemm_descs
.
push_back
({
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
,
{}});
}
}
auto
f_host_tensor_descriptor
=
auto
f_host_tensor_descriptor
=
...
@@ -111,10 +120,9 @@ int main(int argc, char* argv[])
...
@@ -111,10 +120,9 @@ int main(int argc, char* argv[])
};
};
std
::
vector
<
Tensor
<
ADataType
>>
a_tensors
;
std
::
vector
<
Tensor
<
ADataType
>>
a_tensors
;
;
std
::
vector
<
Tensor
<
BDataType
>>
b_tensors
;
std
::
vector
<
Tensor
<
BDataType
>>
b_tensors
;
std
::
vector
<
Tensor
<
C
DataType
>>
c_host_tensors
;
std
::
vector
<
Tensor
<
E
DataType
>>
c_host_tensors
;
std
::
vector
<
Tensor
<
C
DataType
>>
c_device_tensors
;
std
::
vector
<
Tensor
<
E
DataType
>>
c_device_tensors
;
a_tensors
.
reserve
(
group_count
);
a_tensors
.
reserve
(
group_count
);
b_tensors
.
reserve
(
group_count
);
b_tensors
.
reserve
(
group_count
);
...
@@ -131,25 +139,25 @@ int main(int argc, char* argv[])
...
@@ -131,25 +139,25 @@ int main(int argc, char* argv[])
std
::
size_t
flop
=
0
,
num_btype
=
0
;
std
::
size_t
flop
=
0
,
num_btype
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
shape
s
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
desc
s
.
size
();
i
++
)
{
{
a_tensors
.
push_back
(
Tensor
<
ADataType
>
(
f_host_tensor_descriptor
(
a_tensors
.
push_back
(
Tensor
<
ADataType
>
(
f_host_tensor_descriptor
(
gemm_
shape
s
[
i
].
M
,
gemm_
shape
s
[
i
].
K
,
gemm_
shape
s
[
i
].
S
tride
A
,
ALayout
{})));
gemm_
desc
s
[
i
].
M
_
,
gemm_
desc
s
[
i
].
K
_
,
gemm_
desc
s
[
i
].
s
tride
_A_
,
ALayout
{})));
b_tensors
.
push_back
(
Tensor
<
BDataType
>
(
f_host_tensor_descriptor
(
b_tensors
.
push_back
(
Tensor
<
BDataType
>
(
f_host_tensor_descriptor
(
gemm_
shape
s
[
i
].
K
,
gemm_
shape
s
[
i
].
N
,
gemm_
shape
s
[
i
].
S
tride
B
,
BLayout
{})));
gemm_
desc
s
[
i
].
K
_
,
gemm_
desc
s
[
i
].
N
_
,
gemm_
desc
s
[
i
].
s
tride
_B_
,
BLayout
{})));
c_host_tensors
.
push_back
(
Tensor
<
C
DataType
>
(
f_host_tensor_descriptor
(
c_host_tensors
.
push_back
(
Tensor
<
E
DataType
>
(
f_host_tensor_descriptor
(
gemm_
shape
s
[
i
].
M
,
gemm_
shape
s
[
i
].
N
,
gemm_
shape
s
[
i
].
S
tride
C
,
C
Layout
{})));
gemm_
desc
s
[
i
].
M
_
,
gemm_
desc
s
[
i
].
N
_
,
gemm_
desc
s
[
i
].
s
tride
_C_
,
E
Layout
{})));
c_device_tensors
.
push_back
(
Tensor
<
C
DataType
>
(
f_host_tensor_descriptor
(
c_device_tensors
.
push_back
(
Tensor
<
E
DataType
>
(
f_host_tensor_descriptor
(
gemm_
shape
s
[
i
].
M
,
gemm_
shape
s
[
i
].
N
,
gemm_
shape
s
[
i
].
S
tride
C
,
C
Layout
{})));
gemm_
desc
s
[
i
].
M
_
,
gemm_
desc
s
[
i
].
N
_
,
gemm_
desc
s
[
i
].
s
tride
_C_
,
E
Layout
{})));
std
::
cout
<<
"gemm["
<<
i
<<
"] a_m_k: "
<<
a_tensors
[
i
].
mDesc
std
::
cout
<<
"gemm["
<<
i
<<
"] a_m_k: "
<<
a_tensors
[
i
].
mDesc
<<
" b_k_n: "
<<
b_tensors
[
i
].
mDesc
<<
" c_m_n: "
<<
c_device_tensors
[
i
].
mDesc
<<
" b_k_n: "
<<
b_tensors
[
i
].
mDesc
<<
" c_m_n: "
<<
c_device_tensors
[
i
].
mDesc
<<
std
::
endl
;
<<
std
::
endl
;
flop
+=
std
::
size_t
(
2
)
*
gemm_
shape
s
[
i
].
M
*
gemm_
shape
s
[
i
].
K
*
gemm_
shape
s
[
i
].
N
;
flop
+=
std
::
size_t
(
2
)
*
gemm_
desc
s
[
i
].
M
_
*
gemm_
desc
s
[
i
].
K
_
*
gemm_
desc
s
[
i
].
N
_
;
num_btype
+=
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSize
()
+
num_btype
+=
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSize
()
+
sizeof
(
BDataType
)
*
b_tensors
[
i
].
mDesc
.
GetElementSize
()
+
sizeof
(
BDataType
)
*
b_tensors
[
i
].
mDesc
.
GetElementSize
()
+
sizeof
(
C
DataType
)
*
c_device_tensors
[
i
].
mDesc
.
GetElementSize
();
sizeof
(
E
DataType
)
*
c_device_tensors
[
i
].
mDesc
.
GetElementSize
();
switch
(
init_method
)
switch
(
init_method
)
{
{
...
@@ -168,14 +176,14 @@ int main(int argc, char* argv[])
...
@@ -168,14 +176,14 @@ int main(int argc, char* argv[])
}
}
}
}
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
shape
s
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
desc
s
.
size
();
i
++
)
{
{
a_tensors_device
.
emplace_back
(
a_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSpace
()));
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSpace
()));
b_tensors_device
.
emplace_back
(
b_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
BDataType
)
*
b_tensors
[
i
].
mDesc
.
GetElementSpace
()));
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
BDataType
)
*
b_tensors
[
i
].
mDesc
.
GetElementSpace
()));
c_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
c_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
C
DataType
)
*
c_device_tensors
[
i
].
mDesc
.
GetElementSpace
()));
sizeof
(
E
DataType
)
*
c_device_tensors
[
i
].
mDesc
.
GetElementSpace
()));
a_tensors_device
[
i
]
->
ToDevice
(
a_tensors
[
i
].
mData
.
data
());
a_tensors_device
[
i
]
->
ToDevice
(
a_tensors
[
i
].
mData
.
data
());
b_tensors_device
[
i
]
->
ToDevice
(
b_tensors
[
i
].
mData
.
data
());
b_tensors_device
[
i
]
->
ToDevice
(
b_tensors
[
i
].
mData
.
data
());
...
@@ -187,14 +195,16 @@ int main(int argc, char* argv[])
...
@@ -187,14 +195,16 @@ int main(int argc, char* argv[])
auto
a_element_op
=
AElementOp
{};
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
c_element_op
=
CElementOp
{};
auto
c_element_op
=
C
DE
ElementOp
{};
auto
gemm
=
DeviceGemmInstance
{};
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
invoker
=
gemm
.
MakeInvoker
();
std
::
vector
<
std
::
array
<
const
void
*
,
0
>>
p_Ds
=
{};
// do GEMM
// do GEMM
auto
argument
=
auto
argument
=
gemm
.
MakeArgument
(
gemm
.
MakeArgument
(
p_a
,
p_b
,
p_c
,
gemm_
shape
s
,
a_element_op
,
b_element_op
,
c_element_op
);
p_a
,
p_b
,
p_Ds
,
p_c
,
gemm_
desc
s
,
a_element_op
,
b_element_op
,
c_element_op
);
DeviceMem
gemm_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
DeviceMem
gemm_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
...
@@ -219,7 +229,7 @@ int main(int argc, char* argv[])
...
@@ -219,7 +229,7 @@ int main(int argc, char* argv[])
bool
pass
=
true
;
bool
pass
=
true
;
if
(
do_verification
)
if
(
do_verification
)
{
{
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
shape
s
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
desc
s
.
size
();
i
++
)
{
{
c_tensors_device
[
i
]
->
FromDevice
(
c_device_tensors
[
i
].
mData
.
data
());
c_tensors_device
[
i
]
->
FromDevice
(
c_device_tensors
[
i
].
mData
.
data
());
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
...
...
example/28_grouped_gemm_bias/CMakeLists.txt
0 → 100644
View file @
926008bc
add_example_executable
(
example_grouped_gemm_bias_xdl_fp16 grouped_gemm_bias_xdl_fp16.cpp
)
example/28_grouped_gemm_bias/grouped_gemm_bias_xdl_fp16.cpp
0 → 100644
View file @
926008bc
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F16
;
using
D0DataType
=
F16
;
using
DsDataType
=
ck
::
Tuple
<
D0DataType
>
;
using
EDataType
=
F16
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
ELayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CDEElementOp
=
Add
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedGemmXdl
// clang-format off
//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ALayout
,
BLayout
,
ELayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
// clang-format on
int
main
(
int
argc
,
char
*
argv
[])
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
if
(
argc
==
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=n0, 1=yes)
\n
"
);
exit
(
0
);
}
int
group_count
=
rand
()
%
16
+
1
;
// GEMM shape
std
::
vector
<
ck
::
tensor_operation
::
device
::
GemmDesc
>
gemm_descs
;
std
::
vector
<
const
void
*>
p_a
,
p_b
;
std
::
vector
<
std
::
array
<
const
void
*
,
1
>>
p_ds
;
std
::
vector
<
void
*>
p_c
;
gemm_descs
.
reserve
(
group_count
);
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
int
M
=
256
+
256
*
i
;
int
N
=
128
+
128
*
i
;
int
K
=
64
+
64
*
i
;
int
stride_A
=
K
;
int
stride_B
=
K
;
int
stride_C
=
N
;
std
::
vector
<
ck
::
index_t
>
stride_Ds
=
{
0
};
gemm_descs
.
push_back
({
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
,
stride_Ds
});
}
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
std
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
stride
,
1
}));
}
else
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
1
,
stride
}));
}
};
std
::
vector
<
Tensor
<
ADataType
>>
a_tensors
;
std
::
vector
<
Tensor
<
BDataType
>>
b_tensors
;
std
::
vector
<
Tensor
<
D0DataType
>>
d0_tensors
;
std
::
vector
<
Tensor
<
EDataType
>>
e_host_tensors
;
std
::
vector
<
Tensor
<
EDataType
>>
e_device_tensors
;
a_tensors
.
reserve
(
group_count
);
b_tensors
.
reserve
(
group_count
);
d0_tensors
.
reserve
(
group_count
);
e_host_tensors
.
reserve
(
group_count
);
e_device_tensors
.
reserve
(
group_count
);
using
DeviceMemPtr
=
std
::
unique_ptr
<
DeviceMem
>
;
std
::
vector
<
DeviceMemPtr
>
a_tensors_device
,
b_tensors_device
,
d0_tensors_device
,
e_tensors_device
;
a_tensors_device
.
reserve
(
group_count
);
b_tensors_device
.
reserve
(
group_count
);
d0_tensors_device
.
reserve
(
group_count
);
e_tensors_device
.
reserve
(
group_count
);
std
::
size_t
flop
=
0
,
num_btype
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
gemm_descs
.
size
();
i
++
)
{
a_tensors
.
push_back
(
Tensor
<
ADataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
M_
,
gemm_descs
[
i
].
K_
,
gemm_descs
[
i
].
stride_A_
,
ALayout
{})));
b_tensors
.
push_back
(
Tensor
<
BDataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
K_
,
gemm_descs
[
i
].
N_
,
gemm_descs
[
i
].
stride_B_
,
BLayout
{})));
d0_tensors
.
push_back
(
Tensor
<
D0DataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
M_
,
gemm_descs
[
i
].
N_
,
gemm_descs
[
i
].
stride_Ds_
[
0
],
ELayout
{})));
e_host_tensors
.
push_back
(
Tensor
<
EDataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
M_
,
gemm_descs
[
i
].
N_
,
gemm_descs
[
i
].
stride_C_
,
ELayout
{})));
e_device_tensors
.
push_back
(
Tensor
<
EDataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
M_
,
gemm_descs
[
i
].
N_
,
gemm_descs
[
i
].
stride_C_
,
ELayout
{})));
std
::
cout
<<
"gemm["
<<
i
<<
"] a_m_k: "
<<
a_tensors
[
i
].
mDesc
<<
" b_k_n: "
<<
b_tensors
[
i
].
mDesc
<<
" c_m_n: "
<<
e_device_tensors
[
i
].
mDesc
<<
std
::
endl
;
flop
+=
std
::
size_t
(
2
)
*
gemm_descs
[
i
].
M_
*
gemm_descs
[
i
].
K_
*
gemm_descs
[
i
].
N_
;
num_btype
+=
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSize
()
+
sizeof
(
BDataType
)
*
b_tensors
[
i
].
mDesc
.
GetElementSize
()
+
sizeof
(
EDataType
)
*
e_device_tensors
[
i
].
mDesc
.
GetElementSize
();
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
d0_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
break
;
case
2
:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
d0_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
break
;
default:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
0
>
{});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
d0_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
0
>
{});
}
}
for
(
std
::
size_t
i
=
0
;
i
<
gemm_descs
.
size
();
i
++
)
{
a_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSpace
()));
b_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
BDataType
)
*
b_tensors
[
i
].
mDesc
.
GetElementSpace
()));
d0_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
D0DataType
)
*
d0_tensors
[
i
].
mDesc
.
GetElementSpace
()));
e_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
EDataType
)
*
e_device_tensors
[
i
].
mDesc
.
GetElementSpace
()));
a_tensors_device
[
i
]
->
ToDevice
(
a_tensors
[
i
].
mData
.
data
());
b_tensors_device
[
i
]
->
ToDevice
(
b_tensors
[
i
].
mData
.
data
());
d0_tensors_device
[
i
]
->
ToDevice
(
d0_tensors
[
i
].
mData
.
data
());
p_a
.
push_back
(
a_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_b
.
push_back
(
b_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_ds
.
push_back
({
d0_tensors_device
[
i
]
->
GetDeviceBuffer
()});
p_c
.
push_back
(
e_tensors_device
[
i
]
->
GetDeviceBuffer
());
}
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
cde_element_op
=
CDEElementOp
{};
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
// do GEMM
auto
argument
=
gemm
.
MakeArgument
(
p_a
,
p_b
,
p_ds
,
p_c
,
gemm_descs
,
a_element_op
,
b_element_op
,
cde_element_op
);
DeviceMem
gemm_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
gemm
.
SetWorkSpacePointer
(
&
argument
,
gemm_desc_workspace
.
GetDeviceBuffer
());
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
throw
std
::
runtime_error
(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"
);
}
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
bool
pass
=
true
;
if
(
do_verification
)
{
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
EDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
PassThrough
>
;
for
(
std
::
size_t
i
=
0
;
i
<
gemm_descs
.
size
();
i
++
)
{
e_tensors_device
[
i
]
->
FromDevice
(
e_device_tensors
[
i
].
mData
.
data
());
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_tensors
[
i
],
b_tensors
[
i
],
e_host_tensors
[
i
],
a_element_op
,
b_element_op
,
PassThrough
{});
ref_invoker
.
Run
(
ref_argument
);
for
(
int
m
=
0
;
m
<
gemm_descs
[
i
].
M_
;
++
m
)
{
for
(
int
n
=
0
;
n
<
gemm_descs
[
i
].
N_
;
++
n
)
{
cde_element_op
(
e_host_tensors
[
i
](
m
,
n
),
e_host_tensors
[
i
](
m
,
n
),
d0_tensors
[
i
](
m
,
n
));
}
}
pass
&=
ck
::
utils
::
check_err
(
e_device_tensors
[
i
].
mData
,
e_host_tensors
[
i
].
mData
);
}
}
return
pass
?
0
:
1
;
}
example/CMakeLists.txt
View file @
926008bc
...
@@ -46,4 +46,5 @@ add_subdirectory(24_batched_gemm_c_permute)
...
@@ -46,4 +46,5 @@ add_subdirectory(24_batched_gemm_c_permute)
add_subdirectory
(
25_gemm_bias_c_permute
)
add_subdirectory
(
25_gemm_bias_c_permute
)
add_subdirectory
(
26_contraction
)
add_subdirectory
(
26_contraction
)
add_subdirectory
(
27_layernorm
)
add_subdirectory
(
27_layernorm
)
add_subdirectory
(
28_grouped_gemm_bias
)
add_subdirectory
(
29_batched_gemm_multi_d
)
add_subdirectory
(
29_batched_gemm_multi_d
)
\ No newline at end of file
include/ck/tensor_operation/gpu/device/device_gemm.hpp
View file @
926008bc
...
@@ -12,12 +12,6 @@ namespace ck {
...
@@ -12,12 +12,6 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
struct
GemmShape
{
ck
::
index_t
M
,
N
,
K
;
ck
::
index_t
StrideA
,
StrideB
,
StrideC
;
};
template
<
typename
ALayout
,
template
<
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
CLayout
,
typename
CLayout
,
...
@@ -65,29 +59,6 @@ using DeviceGemmPtr = std::unique_ptr<DeviceGemm<ALayout,
...
@@ -65,29 +59,6 @@ using DeviceGemmPtr = std::unique_ptr<DeviceGemm<ALayout,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
CElementwiseOperation
>>
;
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGroupedGemm
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
vector
<
const
void
*>&
p_a
,
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
GemmShape
>&
gemm_shapes
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
ck
::
index_t
KBatch
=
1
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
using
DeviceGroupedGemmPtr
=
std
::
unique_ptr
<
DeviceGroupedGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
0 → 100644
View file @
926008bc
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
struct
GemmDesc
{
ck
::
index_t
M_
,
N_
,
K_
;
ck
::
index_t
stride_A_
,
stride_B_
,
stride_C_
;
std
::
vector
<
ck
::
index_t
>
stride_Ds_
;
};
template
<
typename
ALayout
,
typename
BLayout
,
typename
DELayout
,
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGroupedGemm
:
public
BaseOperator
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
vector
<
const
void
*>&
p_a
,
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
p_ds
,
std
::
vector
<
void
*>&
p_e
,
std
::
vector
<
GemmDesc
>&
gemm_desc
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
ALayout
,
typename
BLayout
,
typename
DELayout
,
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
using
DeviceGroupedGemmPtr
=
std
::
unique_ptr
<
DeviceGroupedGemm
<
ALayout
,
BLayout
,
DELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
View file @
926008bc
This diff is collapsed.
Click to expand it.
library/include/ck/library/host_tensor/host_tensor.hpp
View file @
926008bc
...
@@ -381,52 +381,3 @@ HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens,
...
@@ -381,52 +381,3 @@ HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens,
:
mLens
(
lens
.
begin
(),
lens
.
end
()),
mStrides
(
strides
.
begin
(),
strides
.
end
())
:
mLens
(
lens
.
begin
(),
lens
.
end
()),
mStrides
(
strides
.
begin
(),
strides
.
end
())
{
{
}
}
#if 1
// FIXME: remove
template
<
typename
T
>
float
check_error
(
const
Tensor
<
T
>&
ref
,
const
Tensor
<
T
>&
result
)
{
float
l1_error
=
0
;
float
linf_error
=
-
1
;
float
linf_rel_error
=
-
1
;
float
linf_ref_value
=
0
,
linf_result_value
=
0
;
float
linf_rel_ref_value
=
0
,
linf_rel_result_value
=
0
;
constexpr
float
eps
=
1e-10
;
for
(
std
::
size_t
i
=
0
;
i
<
ref
.
mData
.
size
();
++
i
)
{
float
ref_v
=
ck
::
type_convert
<
float
>
(
ref
.
mData
[
i
]);
float
result_v
=
ck
::
type_convert
<
float
>
(
result
.
mData
[
i
]);
float
diff
=
std
::
abs
(
ref_v
-
result_v
);
float
rel_diff
=
diff
/
std
::
max
(
std
::
abs
(
ref_v
),
eps
);
l1_error
+=
diff
;
if
(
linf_error
<
diff
)
{
linf_error
=
diff
;
linf_ref_value
=
ref_v
;
linf_result_value
=
result_v
;
}
if
(
linf_rel_error
<
rel_diff
)
{
linf_rel_error
=
rel_diff
;
linf_rel_ref_value
=
ref_v
;
linf_rel_result_value
=
result_v
;
}
}
std
::
cout
<<
"Absolute Error L1 Norm (sum of abs diff): "
<<
l1_error
<<
std
::
endl
;
std
::
cout
<<
"Absolute Error L-inf Norm (max abs diff): "
<<
linf_error
<<
", ref "
<<
linf_ref_value
<<
", result "
<<
linf_result_value
<<
std
::
endl
;
std
::
cout
<<
"Relative Error L-inf Norm (max relative abs diff): "
<<
linf_rel_error
<<
", ref "
<<
linf_rel_ref_value
<<
", result "
<<
linf_rel_result_value
<<
std
::
endl
;
return
linf_error
;
}
#endif
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp
0 → 100644
View file @
926008bc
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
DsType
=
Tuple
<>
;
void
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemm
<
Row
,
Row
,
Row
,
F16
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemm
<
Row
,
Col
,
Row
,
F16
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemm
<
Col
,
Row
,
Row
,
F16
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemm
<
Col
,
Col
,
Row
,
F16
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
EDataType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGroupedGemm
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>>
{
using
DeviceOp
=
DeviceGroupedGemm
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
EDataType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/utility/check_err.hpp
View file @
926008bc
...
@@ -29,9 +29,8 @@ check_err(const std::vector<T>& out,
...
@@ -29,9 +29,8 @@ check_err(const std::vector<T>& out,
{
{
if
(
out
.
size
()
!=
ref
.
size
())
if
(
out
.
size
()
!=
ref
.
size
())
{
{
std
::
cout
<<
"out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
std
::
cout
<<
msg
<<
" out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
<<
std
::
endl
<<
std
::
endl
;
<<
msg
<<
std
::
endl
;
return
false
;
return
false
;
}
}
...
@@ -48,9 +47,8 @@ check_err(const std::vector<T>& out,
...
@@ -48,9 +47,8 @@ check_err(const std::vector<T>& out,
err_count
++
;
err_count
++
;
if
(
err_count
<
5
)
if
(
err_count
<
5
)
{
{
std
::
cout
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"out["
<<
i
<<
"] != ref["
std
::
cout
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
i
<<
"]: "
<<
out
[
i
]
<<
" != "
<<
ref
[
i
]
<<
std
::
endl
<<
"] != ref["
<<
i
<<
"]: "
<<
out
[
i
]
<<
" != "
<<
ref
[
i
]
<<
std
::
endl
;
<<
msg
<<
std
::
endl
;
}
}
res
=
false
;
res
=
false
;
}
}
...
@@ -72,9 +70,8 @@ check_err(const std::vector<T>& out,
...
@@ -72,9 +70,8 @@ check_err(const std::vector<T>& out,
{
{
if
(
out
.
size
()
!=
ref
.
size
())
if
(
out
.
size
()
!=
ref
.
size
())
{
{
std
::
cout
<<
"out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
std
::
cout
<<
msg
<<
" out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
<<
std
::
endl
<<
std
::
endl
;
<<
msg
<<
std
::
endl
;
return
false
;
return
false
;
}
}
...
@@ -94,9 +91,8 @@ check_err(const std::vector<T>& out,
...
@@ -94,9 +91,8 @@ check_err(const std::vector<T>& out,
err_count
++
;
err_count
++
;
if
(
err_count
<
5
)
if
(
err_count
<
5
)
{
{
std
::
cout
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"out["
<<
i
<<
"] != ref["
std
::
cout
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
<<
msg
<<
std
::
endl
;
}
}
res
=
false
;
res
=
false
;
}
}
...
@@ -118,9 +114,8 @@ check_err(const std::vector<T>& out,
...
@@ -118,9 +114,8 @@ check_err(const std::vector<T>& out,
{
{
if
(
out
.
size
()
!=
ref
.
size
())
if
(
out
.
size
()
!=
ref
.
size
())
{
{
std
::
cout
<<
"out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
std
::
cout
<<
msg
<<
" out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
<<
std
::
endl
<<
std
::
endl
;
<<
msg
<<
std
::
endl
;
return
false
;
return
false
;
}
}
...
@@ -139,9 +134,8 @@ check_err(const std::vector<T>& out,
...
@@ -139,9 +134,8 @@ check_err(const std::vector<T>& out,
err_count
++
;
err_count
++
;
if
(
err_count
<
5
)
if
(
err_count
<
5
)
{
{
std
::
cout
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"out["
<<
i
<<
"] != ref["
std
::
cout
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
<<
msg
<<
std
::
endl
;
}
}
res
=
false
;
res
=
false
;
}
}
...
@@ -163,9 +157,8 @@ check_err(const std::vector<T>& out,
...
@@ -163,9 +157,8 @@ check_err(const std::vector<T>& out,
{
{
if
(
out
.
size
()
!=
ref
.
size
())
if
(
out
.
size
()
!=
ref
.
size
())
{
{
std
::
cout
<<
"out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
std
::
cout
<<
msg
<<
" out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
<<
std
::
endl
<<
std
::
endl
;
<<
msg
<<
std
::
endl
;
return
false
;
return
false
;
}
}
...
@@ -185,9 +178,9 @@ check_err(const std::vector<T>& out,
...
@@ -185,9 +178,9 @@ check_err(const std::vector<T>& out,
err_count
++
;
err_count
++
;
if
(
err_count
<
5
)
if
(
err_count
<
5
)
{
{
std
::
cout
<<
"
out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
static_cast
<
int
>
(
out
[
i
])
std
::
cout
<<
msg
<<
"
out["
<<
i
<<
"] != ref["
<<
i
<<
"
!=
"
<<
static_cast
<
int
>
(
ref
[
i
])
<<
std
::
endl
<<
"
]:
"
<<
static_cast
<
int
>
(
out
[
i
])
<<
" != "
<<
static_cast
<
int
>
(
ref
[
i
])
<<
msg
<<
std
::
endl
;
<<
std
::
endl
;
}
}
res
=
false
;
res
=
false
;
}
}
...
...
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp
View file @
926008bc
...
@@ -23,6 +23,8 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
...
@@ -23,6 +23,8 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
DsType
=
ck
::
Tuple
<>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
...
@@ -30,23 +32,40 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
...
@@ -30,23 +32,40 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using
device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances
=
std
::
tuple
<
using
device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
2
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
2
,
2
,
32
,
32
,
2
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
2
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
2
,
2
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
2
,
2
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
// clang-format on
>
;
>
;
void
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
void
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGroupedGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemm
<
Col
,
Row
,
Row
,
F16
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
{
add_device_operation_instances
(
instances
,
add_device_operation_instances
(
instances
,
device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances
{});
device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances
{});
...
...
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp
View file @
926008bc
...
@@ -23,30 +23,48 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
...
@@ -23,30 +23,48 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
DsType
=
ck
::
Tuple
<>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using
device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
=
std
::
tuple
<
using
device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
2
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
2
,
8
,
32
,
32
,
2
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
2
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
2
,
8
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
2
,
8
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
// clang-format on
>
;
>
;
void
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
void
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGroupedGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemm
<
Col
,
Col
,
Row
,
F16
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
{
add_device_operation_instances
(
instances
,
add_device_operation_instances
(
instances
,
device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
{});
device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
{});
...
...
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
View file @
926008bc
This diff is collapsed.
Click to expand it.
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
View file @
926008bc
This diff is collapsed.
Click to expand it.
profiler/include/profile_batched_gemm_reduce_impl.hpp
View file @
926008bc
...
@@ -318,13 +318,16 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
...
@@ -318,13 +318,16 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
reduce0_device_buf
.
FromDevice
(
d0_g_m_device_result
.
mData
.
data
());
reduce0_device_buf
.
FromDevice
(
d0_g_m_device_result
.
mData
.
data
());
reduce1_device_buf
.
FromDevice
(
d1_g_m_device_result
.
mData
.
data
());
reduce1_device_buf
.
FromDevice
(
d1_g_m_device_result
.
mData
.
data
());
float
c_error
=
check_error
(
c_g_m_n_host_result
,
c_g_m_n_device_result
);
bool
c_error
=
float
d0_error
=
check_error
(
d0_g_m_host_result
,
d0_g_m_device_result
);
ck
::
utils
::
check_err
(
c_g_m_n_host_result
.
mData
,
c_g_m_n_device_result
.
mData
);
float
d1_error
=
check_error
(
d1_g_m_host_result
,
d1_g_m_device_result
);
bool
d0_error
=
ck
::
utils
::
check_err
(
d0_g_m_host_result
.
mData
,
d0_g_m_device_result
.
mData
);
pass
=
pass
&&
(
c_error
<
1E-6
);
bool
d1_error
=
pass
=
pass
&&
(
d0_error
<
1E-6
);
ck
::
utils
::
check_err
(
d1_g_m_host_result
.
mData
,
d1_g_m_device_result
.
mData
);
pass
=
pass
&&
(
d1_error
<
1E-6
);
pass
=
pass
&&
(
c_error
==
true
);
pass
=
pass
&&
(
d0_error
==
true
);
pass
=
pass
&&
(
d1_error
==
true
);
if
(
do_log
)
if
(
do_log
)
{
{
...
...
profiler/include/profile_conv_bwd_weight_impl.hpp
View file @
926008bc
...
@@ -250,11 +250,11 @@ bool profile_conv_bwd_weight_impl(int do_verification,
...
@@ -250,11 +250,11 @@ bool profile_conv_bwd_weight_impl(int do_verification,
{
{
wei_device_buf
.
FromDevice
(
wei_k_c_y_x_device_result
.
mData
.
data
());
wei_device_buf
.
FromDevice
(
wei_k_c_y_x_device_result
.
mData
.
data
());
float
max_error
=
check_error
(
wei_k_c_y_x_host_result
,
wei_k_c_y_x_device_result
);
pass
=
ck
::
utils
::
check_err
(
wei_k_c_y_x_host_result
.
mData
,
wei_k_c_y_x_device_result
.
mData
);
if
(
max_error
>
8
)
if
(
pass
==
false
)
{
{
pass
=
false
;
std
::
cout
<<
"Fail info:"
<<
conv_ptr
->
GetTypeString
()
<<
std
::
endl
;
std
::
cout
<<
"Fail info:"
<<
conv_ptr
->
GetTypeString
()
<<
std
::
endl
;
}
}
...
...
profiler/include/profile_convnd_bwd_data_impl.hpp
View file @
926008bc
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/conv_util.hpp"
#include "ck/library/utility/conv_util.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
...
@@ -452,7 +453,7 @@ bool profile_convnd_bwd_data_impl(int do_verification,
...
@@ -452,7 +453,7 @@ bool profile_convnd_bwd_data_impl(int do_verification,
std
::
cout
<<
"Pass Info: "
<<
conv_ptr
->
GetTypeString
()
<<
std
::
endl
;
std
::
cout
<<
"Pass Info: "
<<
conv_ptr
->
GetTypeString
()
<<
std
::
endl
;
}
}
check_err
or
(
input_host_result
,
input_device_result
);
success
=
ck
::
utils
::
check_err
(
input_host_result
.
mData
,
input_device_result
.
mData
);
if
(
do_log
)
if
(
do_log
)
{
{
...
...
profiler/include/profile_convnd_bwd_weight_impl.hpp
View file @
926008bc
...
@@ -433,21 +433,17 @@ bool profile_convnd_bwd_weight_impl(int do_verification,
...
@@ -433,21 +433,17 @@ bool profile_convnd_bwd_weight_impl(int do_verification,
{
{
wei_device_buf
.
FromDevice
(
weights_device_result
.
mData
.
data
());
wei_device_buf
.
FromDevice
(
weights_device_result
.
mData
.
data
());
float
max_error
=
check_err
or
(
weights_host_result
,
weights_device_result
);
success
=
ck
::
utils
::
check_err
(
weights_host_result
.
mData
,
weights_device_result
.
mData
);
if
(
max_error
>
8
)
if
(
success
==
false
)
{
{
std
::
cout
<<
"Fail Info: "
<<
conv_ptr
->
GetTypeString
()
<<
std
::
endl
;
std
::
cout
<<
"Fail Info: "
<<
conv_ptr
->
GetTypeString
()
<<
std
::
endl
;
success
=
false
;
}
}
else
else
{
{
std
::
cout
<<
"Pass Info: "
<<
conv_ptr
->
GetTypeString
()
<<
std
::
endl
;
std
::
cout
<<
"Pass Info: "
<<
conv_ptr
->
GetTypeString
()
<<
std
::
endl
;
}
}
check_error
(
weights_host_result
,
weights_device_result
);
if
(
do_log
)
if
(
do_log
)
{
{
std
::
cout
<<
"in : "
;
std
::
cout
<<
"in : "
;
...
...
profiler/include/profile_grouped_gemm_impl.hpp
View file @
926008bc
...
@@ -7,9 +7,11 @@
...
@@ -7,9 +7,11 @@
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_
grouped_
gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/conv_util.hpp"
#include "ck/library/utility/conv_util.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
...
@@ -17,41 +19,17 @@
...
@@ -17,41 +19,17 @@
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
DeviceGroupedGemmNoOpPtr
=
ck
::
tensor_operation
::
device
::
DeviceGroupedGemmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
void
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGroupedGemmNoOpPtr
>&
);
void
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGroupedGemmNoOpPtr
>&
);
void
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGroupedGemmNoOpPtr
>&
);
void
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGroupedGemmNoOpPtr
>&
);
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
namespace
ck
{
namespace
ck
{
namespace
profiler
{
namespace
profiler
{
template
<
typename
ADataType
,
template
<
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
typename
C
DataType
,
typename
E
DataType
,
typename
AccDataType
,
typename
AccDataType
,
typename
ALayout
,
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
CLayout
>
typename
CLayout
>
void
profile_grouped_gemm_impl
(
int
do_verification
,
bool
profile_grouped_gemm_impl
(
int
do_verification
,
int
init_method
,
int
init_method
,
bool
do_log
,
bool
do_log
,
bool
time_kernel
,
bool
time_kernel
,
...
@@ -62,6 +40,9 @@ void profile_grouped_gemm_impl(int do_verification,
...
@@ -62,6 +40,9 @@ void profile_grouped_gemm_impl(int do_verification,
const
std
::
vector
<
int
>&
StrideBs
,
const
std
::
vector
<
int
>&
StrideBs
,
const
std
::
vector
<
int
>&
StrideCs
)
const
std
::
vector
<
int
>&
StrideCs
)
{
{
bool
pass
=
true
;
auto
f_host_tensor_descriptor
=
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
is_same
<
decltype
(
layout
),
tensor_layout
::
gemm
::
RowMajor
>::
value
)
if
(
is_same
<
decltype
(
layout
),
tensor_layout
::
gemm
::
RowMajor
>::
value
)
...
@@ -86,7 +67,7 @@ void profile_grouped_gemm_impl(int do_verification,
...
@@ -86,7 +67,7 @@ void profile_grouped_gemm_impl(int do_verification,
std
::
vector
<
Tensor
<
ADataType
>>
a_m_k
;
std
::
vector
<
Tensor
<
ADataType
>>
a_m_k
;
std
::
vector
<
Tensor
<
BDataType
>>
b_k_n
;
std
::
vector
<
Tensor
<
BDataType
>>
b_k_n
;
std
::
vector
<
Tensor
<
C
DataType
>>
c_m_n_device_results
;
std
::
vector
<
Tensor
<
E
DataType
>>
c_m_n_device_results
;
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
{
...
@@ -96,7 +77,7 @@ void profile_grouped_gemm_impl(int do_verification,
...
@@ -96,7 +77,7 @@ void profile_grouped_gemm_impl(int do_verification,
Tensor
<
BDataType
>
(
f_host_tensor_descriptor
(
Ks
[
i
],
Ns
[
i
],
StrideBs
[
i
],
BLayout
{})));
Tensor
<
BDataType
>
(
f_host_tensor_descriptor
(
Ks
[
i
],
Ns
[
i
],
StrideBs
[
i
],
BLayout
{})));
c_m_n_device_results
.
push_back
(
c_m_n_device_results
.
push_back
(
Tensor
<
C
DataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{})));
Tensor
<
E
DataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{})));
std
::
cout
<<
"group: "
<<
i
<<
" a_m_k["
<<
i
<<
"]:"
<<
a_m_k
[
i
].
mDesc
<<
", b_k_n["
<<
i
std
::
cout
<<
"group: "
<<
i
<<
" a_m_k["
<<
i
<<
"]:"
<<
a_m_k
[
i
].
mDesc
<<
", b_k_n["
<<
i
<<
"]:"
<<
b_k_n
[
i
].
mDesc
<<
", c_m_n_device_results["
<<
i
<<
"]:"
<<
b_k_n
[
i
].
mDesc
<<
", c_m_n_device_results["
<<
i
...
@@ -115,7 +96,7 @@ void profile_grouped_gemm_impl(int do_verification,
...
@@ -115,7 +96,7 @@ void profile_grouped_gemm_impl(int do_verification,
b_k_n
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
},
num_thread
);
b_k_n
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
},
num_thread
);
}
}
c_m_n_device_results
[
i
].
GenerateTensorValue
(
GeneratorTensor_0
<
C
DataType
>
{},
num_thread
);
c_m_n_device_results
[
i
].
GenerateTensorValue
(
GeneratorTensor_0
<
E
DataType
>
{},
num_thread
);
}
}
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
...
@@ -145,9 +126,9 @@ void profile_grouped_gemm_impl(int do_verification,
...
@@ -145,9 +126,9 @@ void profile_grouped_gemm_impl(int do_verification,
p_b
.
reserve
(
group_count
);
p_b
.
reserve
(
group_count
);
p_c
.
reserve
(
group_count
);
p_c
.
reserve
(
group_count
);
std
::
vector
<
ck
::
tensor_operation
::
device
::
Gemm
Shape
>
gemm_
shape
s
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
Gemm
Desc
>
gemm_
desc
s
;
gemm_
shape
s
.
reserve
(
group_count
);
gemm_
desc
s
.
reserve
(
group_count
);
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
{
...
@@ -157,56 +138,34 @@ void profile_grouped_gemm_impl(int do_verification,
...
@@ -157,56 +138,34 @@ void profile_grouped_gemm_impl(int do_verification,
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
BDataType
)
*
b_k_n
[
i
].
mDesc
.
GetElementSpace
()));
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
BDataType
)
*
b_k_n
[
i
].
mDesc
.
GetElementSpace
()));
c_device_buf
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
c_device_buf
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
C
DataType
)
*
c_m_n_device_results
[
i
].
mDesc
.
GetElementSpace
()));
sizeof
(
E
DataType
)
*
c_m_n_device_results
[
i
].
mDesc
.
GetElementSpace
()));
a_device_buf
[
i
]
->
ToDevice
(
a_m_k
[
i
].
mData
.
data
());
a_device_buf
[
i
]
->
ToDevice
(
a_m_k
[
i
].
mData
.
data
());
b_device_buf
[
i
]
->
ToDevice
(
b_k_n
[
i
].
mData
.
data
());
b_device_buf
[
i
]
->
ToDevice
(
b_k_n
[
i
].
mData
.
data
());
c_device_buf
[
i
]
->
ToDevice
(
c_m_n_device_results
[
i
].
mData
.
data
());
c_device_buf
[
i
]
->
ToDevice
(
c_m_n_device_results
[
i
].
mData
.
data
());
gemm_
shape
s
.
push_back
({
Ms
[
i
],
Ns
[
i
],
Ks
[
i
],
StrideAs
[
i
],
StrideBs
[
i
],
StrideCs
[
i
]});
gemm_
desc
s
.
push_back
({
Ms
[
i
],
Ns
[
i
],
Ks
[
i
],
StrideAs
[
i
],
StrideBs
[
i
],
StrideCs
[
i
]
,
{}
});
p_a
.
push_back
(
a_device_buf
[
i
]
->
GetDeviceBuffer
());
p_a
.
push_back
(
a_device_buf
[
i
]
->
GetDeviceBuffer
());
p_b
.
push_back
(
b_device_buf
[
i
]
->
GetDeviceBuffer
());
p_b
.
push_back
(
b_device_buf
[
i
]
->
GetDeviceBuffer
());
p_c
.
push_back
(
c_device_buf
[
i
]
->
GetDeviceBuffer
());
p_c
.
push_back
(
c_device_buf
[
i
]
->
GetDeviceBuffer
());
}
}
// add device GEMM instances
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedGemm
<
ALayout
,
std
::
vector
<
ck
::
tensor_operation
::
device
::
instance
::
DeviceGroupedGemmNoOpPtr
>
gemm_ptrs
;
BLayout
,
CLayout
,
ADataType
,
BDataType
,
ck
::
Tuple
<>
,
EDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
if
constexpr
(
is_same
<
ADataType
,
half_t
>::
value
&&
is_same
<
BDataType
,
half_t
>::
value
&&
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
is_same
<
CDataType
,
half_t
>::
value
)
DeviceOp
>::
GetInstances
();
{
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
instance
::
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
instance
::
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
instance
::
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
instance
::
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
gemm_ptrs
);
}
}
if
(
gemm
_ptrs
.
size
()
<=
0
)
if
(
op
_ptrs
.
size
()
<=
0
)
{
{
throw
std
::
runtime_error
(
"wrong! no device GEMM instance found"
);
throw
std
::
runtime_error
(
"wrong! no device GEMM instance found"
);
}
}
...
@@ -216,14 +175,17 @@ void profile_grouped_gemm_impl(int do_verification,
...
@@ -216,14 +175,17 @@ void profile_grouped_gemm_impl(int do_verification,
float
best_tflops
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
float
best_gb_per_sec
=
0
;
auto
p_ds
=
std
::
vector
<
std
::
array
<
const
void
*
,
0
>>
{};
// profile device GEMM instances
// profile device GEMM instances
for
(
auto
&
gemm_ptr
:
gemm
_ptrs
)
for
(
auto
&
gemm_ptr
:
op
_ptrs
)
{
{
auto
argument_ptr
=
auto
argument_ptr
=
gemm_ptr
->
MakeArgumentPointer
(
p_a
,
gemm_ptr
->
MakeArgumentPointer
(
p_a
,
p_b
,
p_b
,
p_ds
,
p_c
,
p_c
,
gemm_
shape
s
,
gemm_
desc
s
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
...
@@ -242,12 +204,12 @@ void profile_grouped_gemm_impl(int do_verification,
...
@@ -242,12 +204,12 @@ void profile_grouped_gemm_impl(int do_verification,
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
0
,
num_btype
=
0
;
std
::
size_t
flop
=
0
,
num_btype
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
shape
s
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
desc
s
.
size
();
i
++
)
{
{
flop
+=
std
::
size_t
(
2
)
*
Ms
[
i
]
*
Ns
[
i
]
*
Ks
[
i
];
flop
+=
std
::
size_t
(
2
)
*
Ms
[
i
]
*
Ns
[
i
]
*
Ks
[
i
];
num_btype
+=
sizeof
(
ADataType
)
*
Ms
[
i
]
*
Ks
[
i
]
+
sizeof
(
BDataType
)
*
Ks
[
i
]
*
Ns
[
i
]
+
num_btype
+=
sizeof
(
ADataType
)
*
Ms
[
i
]
*
Ks
[
i
]
+
sizeof
(
BDataType
)
*
Ks
[
i
]
*
Ns
[
i
]
+
sizeof
(
C
DataType
)
*
Ms
[
i
]
*
Ns
[
i
];
sizeof
(
E
DataType
)
*
Ms
[
i
]
*
Ns
[
i
];
}
}
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
...
@@ -266,18 +228,18 @@ void profile_grouped_gemm_impl(int do_verification,
...
@@ -266,18 +228,18 @@ void profile_grouped_gemm_impl(int do_verification,
if
(
do_verification
)
if
(
do_verification
)
{
{
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
shape
s
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
desc
s
.
size
();
i
++
)
{
{
c_device_buf
[
i
]
->
FromDevice
(
c_m_n_device_results
[
i
].
mData
.
data
());
c_device_buf
[
i
]
->
FromDevice
(
c_m_n_device_results
[
i
].
mData
.
data
());
Tensor
<
C
DataType
>
c_m_n_host_result
(
Tensor
<
E
DataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{}));
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{}));
using
ReferenceGemmInstance
=
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
BDataType
,
C
DataType
,
E
DataType
,
AccDataType
,
AccDataType
,
AElementOp
,
AElementOp
,
BElementOp
,
BElementOp
,
...
@@ -294,7 +256,8 @@ void profile_grouped_gemm_impl(int do_verification,
...
@@ -294,7 +256,8 @@ void profile_grouped_gemm_impl(int do_verification,
c_element_op
);
c_element_op
);
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
ck
::
utils
::
check_err
(
c_m_n_device_results
[
i
].
mData
,
c_m_n_host_result
.
mData
);
pass
=
pass
&&
ck
::
utils
::
check_err
(
c_m_n_device_results
[
i
].
mData
,
c_m_n_host_result
.
mData
);
if
(
do_log
)
if
(
do_log
)
{
{
...
@@ -319,6 +282,8 @@ void profile_grouped_gemm_impl(int do_verification,
...
@@ -319,6 +282,8 @@ void profile_grouped_gemm_impl(int do_verification,
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_gemm_name
<<
std
::
endl
;
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_gemm_name
<<
std
::
endl
;
return
pass
;
}
// namespace profiler
}
// namespace profiler
}
// namespace profiler
}
// namespace profiler
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment