Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
926008bc
Unverified
Commit
926008bc
authored
Jul 21, 2022
by
Chao Liu
Committed by
GitHub
Jul 21, 2022
Browse files
Merge branch 'develop' into batched_gemm_multiD
parents
e9652cbc
d8415a96
Changes
30
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1352 additions
and
663 deletions
+1352
-663
Jenkinsfile
Jenkinsfile
+183
-82
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
+51
-41
example/28_grouped_gemm_bias/CMakeLists.txt
example/28_grouped_gemm_bias/CMakeLists.txt
+1
-0
example/28_grouped_gemm_bias/grouped_gemm_bias_xdl_fp16.cpp
example/28_grouped_gemm_bias/grouped_gemm_bias_xdl_fp16.cpp
+278
-0
example/CMakeLists.txt
example/CMakeLists.txt
+2
-1
include/ck/tensor_operation/gpu/device/device_gemm.hpp
include/ck/tensor_operation/gpu/device/device_gemm.hpp
+0
-29
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
...de/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
+69
-0
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
+430
-258
library/include/ck/library/host_tensor/host_tensor.hpp
library/include/ck/library/host_tensor/host_tensor.hpp
+0
-49
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp
...ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp
+134
-0
library/include/ck/library/utility/check_err.hpp
library/include/ck/library/utility/check_err.hpp
+17
-24
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp
...device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp
+32
-13
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp
...device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp
+32
-14
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
...device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
+32
-22
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
...device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
+30
-34
profiler/include/profile_batched_gemm_reduce_impl.hpp
profiler/include/profile_batched_gemm_reduce_impl.hpp
+10
-7
profiler/include/profile_conv_bwd_weight_impl.hpp
profiler/include/profile_conv_bwd_weight_impl.hpp
+3
-3
profiler/include/profile_convnd_bwd_data_impl.hpp
profiler/include/profile_convnd_bwd_data_impl.hpp
+2
-1
profiler/include/profile_convnd_bwd_weight_impl.hpp
profiler/include/profile_convnd_bwd_weight_impl.hpp
+2
-6
profiler/include/profile_grouped_gemm_impl.hpp
profiler/include/profile_grouped_gemm_impl.hpp
+44
-79
No files found.
Jenkinsfile
View file @
926008bc
...
...
@@ -11,6 +11,12 @@ def show_node_info() {
"""
}
def
runShell
(
String
command
){
def
responseCode
=
sh
returnStatus:
true
,
script:
"${command} &> tmp.txt"
def
output
=
readFile
(
file:
"tmp.txt"
)
return
(
output
!=
""
)
}
def
cmake_build
(
Map
conf
=[:]){
def
compiler
=
conf
.
get
(
"compiler"
,
"/opt/rocm/bin/hipcc"
)
...
...
@@ -60,7 +66,7 @@ def cmake_build(Map conf=[:]){
"""
def
setup_cmd
=
conf
.
get
(
"setup_cmd"
,
"${cmake_envs} cmake ${setup_args} .. "
)
// reduce parallelism when compiling, clang uses too much memory
def
build_cmd
=
conf
.
get
(
"build_cmd"
,
"${build_envs} dumb-init make -j\$(( \$(nproc) /
1
)) ${config_targets}"
)
def
build_cmd
=
conf
.
get
(
"build_cmd"
,
"${build_envs} dumb-init make -j\$(( \$(nproc) /
2
)) ${config_targets}"
)
def
execute_cmd
=
conf
.
get
(
"execute_cmd"
,
""
)
def
cmd
=
conf
.
get
(
"cmd"
,
"""
...
...
@@ -113,7 +119,14 @@ def buildHipClangJob(Map conf=[:]){
retimage
=
docker
.
build
(
"${image}"
,
dockerArgs
+
'.'
)
withDockerContainer
(
image:
image
,
args:
dockerOpts
)
{
timeout
(
time:
5
,
unit:
'MINUTES'
){
sh
'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
sh
'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo | tee clinfo.log'
if
(
runShell
(
'grep -n "Number of devices:.*. 0" clinfo.log'
)
){
echo
"GPU not found"
throw
e
}
else
{
echo
"GPU is OK"
}
}
}
}
...
...
@@ -125,7 +138,14 @@ def buildHipClangJob(Map conf=[:]){
retimage
=
docker
.
build
(
"${image}"
,
dockerArgs
+
" --no-cache ."
)
withDockerContainer
(
image:
image
,
args:
dockerOpts
)
{
timeout
(
time:
5
,
unit:
'MINUTES'
){
sh
'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
sh
'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo |tee clinfo.log'
if
(
runShell
(
'grep -n "Number of devices:.*. 0" clinfo.log'
)
){
echo
"GPU not found"
throw
e
}
else
{
echo
"GPU is OK"
}
}
}
}
...
...
@@ -133,7 +153,14 @@ def buildHipClangJob(Map conf=[:]){
withDockerContainer
(
image:
image
,
args:
dockerOpts
+
' -v=/var/jenkins/:/var/jenkins'
)
{
timeout
(
time:
5
,
unit:
'HOURS'
)
{
sh
'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
sh
'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo | tee clinfo.log'
if
(
runShell
(
'grep -n "Number of devices:.*. 0" clinfo.log'
)
){
echo
"GPU not found"
throw
e
}
else
{
echo
"GPU is OK"
}
cmake_build
(
conf
)
}
}
...
...
@@ -145,7 +172,6 @@ def reboot(){
build
job:
'reboot-slaves'
,
propagate:
false
,
parameters:
[
string
(
name:
'server'
,
value:
"${env.NODE_NAME}"
),]
}
def
buildHipClangJobAndReboot
(
Map
conf
=[:]){
try
{
buildHipClangJob
(
conf
)
...
...
@@ -162,7 +188,6 @@ def buildHipClangJobAndReboot(Map conf=[:]){
}
}
def
runCKProfiler
(
Map
conf
=[:]){
show_node_info
()
...
...
@@ -189,7 +214,6 @@ def runCKProfiler(Map conf=[:]){
}
def
variant
=
env
.
STAGE_NAME
def
retimage
gitStatusWrapper
(
credentialsId:
"${status_wrapper_creds}"
,
gitHubContext:
"Jenkins - ${variant}"
,
account:
'ROCmSoftwarePlatform'
,
repo:
'composable_kernel'
)
{
...
...
@@ -197,7 +221,14 @@ def runCKProfiler(Map conf=[:]){
retimage
=
docker
.
build
(
"${image}"
,
dockerArgs
+
'.'
)
withDockerContainer
(
image:
image
,
args:
dockerOpts
)
{
timeout
(
time:
5
,
unit:
'MINUTES'
){
sh
'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
sh
'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo | tee clinfo.log'
if
(
runShell
(
'grep -n "Number of devices:.*. 0" clinfo.log'
)
){
echo
"GPU not found"
throw
e
}
else
{
echo
"GPU is OK"
}
}
}
}
...
...
@@ -209,89 +240,69 @@ def runCKProfiler(Map conf=[:]){
retimage
=
docker
.
build
(
"${image}"
,
dockerArgs
+
" --no-cache ."
)
withDockerContainer
(
image:
image
,
args:
dockerOpts
)
{
timeout
(
time:
5
,
unit:
'MINUTES'
){
sh
'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
sh
'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo | tee clinfo.log'
if
(
runShell
(
'grep -n "Number of devices:.*. 0" clinfo.log'
)
){
echo
"GPU not found"
throw
e
}
else
{
echo
"GPU is OK"
}
}
}
}
withDockerContainer
(
image:
image
,
args:
dockerOpts
+
' -v=/var/jenkins/:/var/jenkins'
)
{
timeout
(
time:
5
,
unit:
'HOURS'
)
timeout
(
time:
24
,
unit:
'HOURS'
)
{
cmake_build
(
conf
)
dir
(
"script"
){
//run gemm performance tests
def
gemm_log
=
"perf_gemm_${gpu_arch}.log"
sh
"rm -f ${gemm_log}"
sh
"echo Branch name: ${env.BRANCH_NAME} > ${gemm_log}"
sh
"echo Node name: ${NODE_NAME} >> ${gemm_log}"
sh
"echo GPU_arch name: ${gpu_arch} >> ${gemm_log}"
sh
"rocminfo | grep 'Compute Unit:' >> ${gemm_log} "
sh
"hipcc --version | grep -e 'HIP version' >> ${gemm_log}"
if
(
params
.
RUN_FULL_QA
){
def
qa_log
=
"qa_${gpu_arch}.log"
if
(
params
.
USE_9110
){
sh
"echo Environment type: CI_9110 >> ${gemm_log
}"
sh
"./run_full_performance_tests.sh 1 QA_9110 ${gpu_arch} ${env.BRANCH_NAME} ${NODE_NAME
}"
}
else
{
sh
"echo Environment type: CI_release >> ${gemm_log}"
}
sh
"/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${gemm_log}"
sh
"./profile_gemm.sh gemm 0 0 0 1 0 5 | tee -a ${gemm_log}"
sh
"./profile_gemm.sh gemm 1 0 0 1 0 5 | tee -a ${gemm_log}"
sh
"./profile_gemm.sh gemm 2 0 0 1 0 5 | tee -a ${gemm_log}"
sh
"./profile_gemm.sh gemm 3 0 0 1 0 5 | tee -a ${gemm_log}"
sh
"./profile_gemm.sh gemm 0 1 0 1 0 5 | tee -a ${gemm_log}"
sh
"./profile_gemm.sh gemm 1 1 0 1 0 5 | tee -a ${gemm_log}"
sh
"./profile_gemm.sh gemm 2 1 0 1 0 5 | tee -a ${gemm_log}"
sh
"./profile_gemm.sh gemm 3 1 0 1 0 5 | tee -a ${gemm_log}"
sh
"./profile_gemm.sh gemm 0 2 0 1 0 5 | tee -a ${gemm_log}"
sh
"./profile_gemm.sh gemm 1 2 0 1 0 5 | tee -a ${gemm_log}"
sh
"./profile_gemm.sh gemm 2 2 0 1 0 5 | tee -a ${gemm_log}"
sh
"./profile_gemm.sh gemm 3 2 0 1 0 5 | tee -a ${gemm_log}"
sh
"./profile_gemm.sh gemm 0 3 0 1 0 5 | tee -a ${gemm_log}"
sh
"./profile_gemm.sh gemm 1 3 0 1 0 5 | tee -a ${gemm_log}"
sh
"./profile_gemm.sh gemm 2 3 0 1 0 5 | tee -a ${gemm_log}"
sh
"./profile_gemm.sh gemm 3 3 0 1 0 5 | tee -a ${gemm_log}"
//results will be parsed, stored, and analyzed within the python script
//the script will return 0 if the performance criteria are met
//or return 1 if the criteria are not met
archiveArtifacts
"${gemm_log}"
sh
"python3 process_perf_data.py ${gemm_log} "
//run resnet50 test
def
resnet256_log
=
"perf_resnet50_N256_${gpu_arch}.log"
sh
"rm -f ${resnet256_log}"
sh
"echo Branch name: ${env.BRANCH_NAME} > ${resnet256_log}"
sh
"echo Node name: ${NODE_NAME} >> ${resnet256_log}"
sh
"echo GPU_arch name: ${gpu_arch} >> ${resnet256_log}"
sh
"rocminfo | grep 'Compute Unit:' >> ${resnet256_log} "
sh
"hipcc --version | grep -e 'HIP version' >> ${resnet256_log}"
if
(
params
.
USE_9110
){
sh
"echo Environment type: CI_9110 >> ${resnet256_log}"
sh
"./run_full_performance_tests.sh 1 QA_release ${gpu_arch} ${env.BRANCH_NAME} ${NODE_NAME}"
}
archiveArtifacts
"perf_gemm_${gpu_arch}.log"
archiveArtifacts
"perf_resnet50_N256_${gpu_arch}.log"
archiveArtifacts
"perf_resnet50_N4_${gpu_arch}.log"
archiveArtifacts
"perf_bathced_gemm_${gpu_arch}.log"
archiveArtifacts
"perf_grouped_gemm_${gpu_arch}.log"
archiveArtifacts
"perf_fwd_conv_${gpu_arch}.log"
archiveArtifacts
"perf_bwd_conv_${gpu_arch}.log"
archiveArtifacts
"perf_fusion_${gpu_arch}.log"
archiveArtifacts
"perf_reduction_${gpu_arch}.log"
// stash perf files to master
stash
name:
"perf_gemm_${gpu_arch}.log"
stash
name:
"perf_resnet50_N256_${gpu_arch}.log"
stash
name:
"perf_resnet50_N4_${gpu_arch}.log"
stash
name:
"perf_bathced_gemm_${gpu_arch}.log"
stash
name:
"perf_grouped_gemm_${gpu_arch}.log"
stash
name:
"perf_fwd_conv_${gpu_arch}.log"
stash
name:
"perf_bwd_conv_${gpu_arch}.log"
stash
name:
"perf_fusion_${gpu_arch}.log"
stash
name:
"perf_reduction_${gpu_arch}.log"
//we will process results on the master node
}
else
{
sh
"echo Environment type: CI_release >> ${resnet256_log}"
}
sh
"/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${resnet256_log}"
//first run tests with N=256
sh
"./profile_resnet50.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 256 | tee -a ${resnet256_log}"
archiveArtifacts
"${resnet256_log}"
sh
"python3 process_perf_data.py ${resnet256_log} "
//then run with N=4
def
resnet4_log
=
"perf_resnet50_N4_${gpu_arch}.log"
sh
"rm -f ${resnet4_log}"
sh
"echo Branch name: ${env.BRANCH_NAME} > ${resnet4_log}"
sh
"echo Node name: ${NODE_NAME} >> ${resnet4_log}"
sh
"echo GPU_arch name: ${gpu_arch} >> ${resnet4_log}"
sh
"rocminfo | grep 'Compute Unit:' >> ${resnet4_log} "
sh
"hipcc --version | grep -e 'HIP version' >> ${resnet4_log}"
if
(
params
.
USE_9110
){
sh
"echo Environment type: CI_9110 >> ${resnet4_log
}"
sh
"./run_performance_tests.sh 0 CI_9110 ${gpu_arch} ${env.BRANCH_NAME} ${NODE_NAME
}"
}
else
{
sh
"echo Environment type: CI_release >> ${resnet4_log
}"
sh
"./run_performance_tests.sh 0 CI_release ${gpu_arch} ${env.BRANCH_NAME} ${NODE_NAME
}"
}
sh
"/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${resnet4_log}"
sh
"./profile_resnet50.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 4 | tee -a ${resnet4_log}"
archiveArtifacts
"${resnet4_log}"
sh
"python3 process_perf_data.py ${resnet4_log} "
archiveArtifacts
"perf_gemm_${gpu_arch}.log"
archiveArtifacts
"perf_resnet50_N256_${gpu_arch}.log"
archiveArtifacts
"perf_resnet50_N4_${gpu_arch}.log"
// stash perf files to master
stash
name:
"perf_gemm_${gpu_arch}.log"
stash
name:
"perf_resnet50_N256_${gpu_arch}.log"
stash
name:
"perf_resnet50_N4_${gpu_arch}.log"
//we will process the results on the master node
}
}
}
}
...
...
@@ -299,7 +310,6 @@ def runCKProfiler(Map conf=[:]){
return
retimage
}
def
runPerfTest
(
Map
conf
=[:]){
try
{
runCKProfiler
(
conf
)
...
...
@@ -316,8 +326,76 @@ def runPerfTest(Map conf=[:]){
}
}
def
process_results
(
Map
conf
=[:]){
env
.
HSA_ENABLE_SDMA
=
0
checkout
scm
def
image
=
"composable_kernels"
def
prefixpath
=
"/opt/rocm"
def
gpu_arch
=
conf
.
get
(
"gpu_arch"
,
"gfx908"
)
// Jenkins is complaining about the render group
def
dockerOpts
=
"--cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
if
(
conf
.
get
(
"enforce_xnack_on"
,
false
))
{
dockerOpts
=
dockerOpts
+
" --env HSA_XNACK=1"
}
def
dockerArgs
=
"--build-arg PREFIX=${prefixpath} --build-arg GPU_ARCH='${gpu_arch}' --build-arg compiler_version='release' "
def
variant
=
env
.
STAGE_NAME
def
retimage
gitStatusWrapper
(
credentialsId:
"${status_wrapper_creds}"
,
gitHubContext:
"Jenkins - ${variant}"
,
account:
'ROCmSoftwarePlatform'
,
repo:
'composable_kernel'
)
{
try
{
retimage
=
docker
.
build
(
"${image}"
,
dockerArgs
+
'.'
)
}
catch
(
org
.
jenkinsci
.
plugins
.
workflow
.
steps
.
FlowInterruptedException
e
){
echo
"The job was cancelled or aborted"
throw
e
}
}
withDockerContainer
(
image:
image
,
args:
dockerOpts
+
' -v=/var/jenkins/:/var/jenkins'
)
{
timeout
(
time:
1
,
unit:
'HOURS'
){
try
{
dir
(
"script"
){
if
(
params
.
RUN_FULL_QA
){
// unstash perf files to master
unstash
"perf_gemm_${gpu_arch}.log"
unstash
"perf_resnet50_N256_${gpu_arch}.log"
unstash
"perf_resnet50_N4_${gpu_arch}.log"
unstash
"perf_bathced_gemm_${gpu_arch}.log"
unstash
"perf_grouped_gemm_${gpu_arch}.log"
unstash
"perf_fwd_conv_${gpu_arch}.log"
unstash
"perf_bwd_conv_${gpu_arch}.log"
unstash
"perf_fusion_${gpu_arch}.log"
unstash
"perf_reduction_${gpu_arch}.log"
sh
"./process_qa_data.sh ${gpu_arch}"
}
else
{
// unstash perf files to master
unstash
"perf_gemm_${gpu_arch}.log"
unstash
"perf_resnet50_N256_${gpu_arch}.log"
unstash
"perf_resnet50_N4_${gpu_arch}.log"
sh
"./process_perf_data.sh ${gpu_arch}"
}
}
}
catch
(
e
){
echo
"throwing error exception while processing performance test results"
echo
'Exception occurred: '
+
e
.
toString
()
throw
e
}
}
}
}
//launch develop branch daily at 23:00 in FULL_QA mode
CRON_SETTINGS
=
BRANCH_NAME
==
"develop"
?
'''0 23 * * * % RUN_FULL_QA=true;USE_9110=true'''
:
""
pipeline
{
agent
none
triggers
{
cron
(
CRON_SETTINGS
)
}
options
{
parallelsAlwaysFailFast
()
}
...
...
@@ -325,7 +403,11 @@ pipeline {
booleanParam
(
name:
"USE_9110"
,
defaultValue:
true
,
description:
""
)
description:
"Select compiler version: 9110 (default) or release"
)
booleanParam
(
name:
"RUN_FULL_QA"
,
defaultValue:
false
,
description:
"Select whether to run small set of performance tests (default) or full QA"
)
}
environment
{
dbuser
=
"${dbuser}"
...
...
@@ -438,6 +520,25 @@ pipeline {
}
}
}
stage
(
"Process Performance Test Results"
)
{
parallel
{
stage
(
"Process results for gfx908"
){
agent
{
label
'mici'
}
steps
{
process_results
(
gpu_arch:
"gfx908"
)
}
}
stage
(
"Process results for gfx90a"
){
agent
{
label
'mici'
}
steps
{
process_results
(
gpu_arch:
"gfx90a"
)
}
}
}
}
/* enable after the cmake file supports packaging
stage("Packages") {
when {
...
...
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
View file @
926008bc
...
...
@@ -29,34 +29,39 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F16
;
using
DsDataType
=
ck
::
Tuple
<>
;
using
EDataType
=
F16
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
C
Layout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
E
Layout
=
Row
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
C
DE
ElementOp
=
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// static constexpr auto GemmMNPadding =
// ck::tensor_operation::device::GemmSpecialization::MNPadding;
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedGemmXdl
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
,
1
>
;
// clang-format off
//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ALayout
,
BLayout
,
ELayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
EDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
>
;
int
main
(
int
argc
,
char
*
argv
[])
{
...
...
@@ -81,11 +86,11 @@ int main(int argc, char* argv[])
int
group_count
=
rand
()
%
16
+
1
;
// GEMM shape
std
::
vector
<
ck
::
tensor_operation
::
device
::
Gemm
Shape
>
gemm_
shape
s
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
Gemm
Desc
>
gemm_
desc
s
;
std
::
vector
<
const
void
*>
p_a
,
p_b
;
std
::
vector
<
void
*>
p_c
;
gemm_
shape
s
.
reserve
(
group_count
);
gemm_
desc
s
.
reserve
(
group_count
);
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
...
...
@@ -93,7 +98,11 @@ int main(int argc, char* argv[])
int
N
=
128
+
128
*
i
;
int
K
=
64
+
64
*
i
;
gemm_shapes
.
push_back
({
M
,
N
,
K
,
K
,
K
,
N
});
int
stride_A
=
K
;
int
stride_B
=
K
;
int
stride_C
=
N
;
gemm_descs
.
push_back
({
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
,
{}});
}
auto
f_host_tensor_descriptor
=
...
...
@@ -111,10 +120,9 @@ int main(int argc, char* argv[])
};
std
::
vector
<
Tensor
<
ADataType
>>
a_tensors
;
;
std
::
vector
<
Tensor
<
BDataType
>>
b_tensors
;
std
::
vector
<
Tensor
<
C
DataType
>>
c_host_tensors
;
std
::
vector
<
Tensor
<
C
DataType
>>
c_device_tensors
;
std
::
vector
<
Tensor
<
E
DataType
>>
c_host_tensors
;
std
::
vector
<
Tensor
<
E
DataType
>>
c_device_tensors
;
a_tensors
.
reserve
(
group_count
);
b_tensors
.
reserve
(
group_count
);
...
...
@@ -131,25 +139,25 @@ int main(int argc, char* argv[])
std
::
size_t
flop
=
0
,
num_btype
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
shape
s
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
desc
s
.
size
();
i
++
)
{
a_tensors
.
push_back
(
Tensor
<
ADataType
>
(
f_host_tensor_descriptor
(
gemm_
shape
s
[
i
].
M
,
gemm_
shape
s
[
i
].
K
,
gemm_
shape
s
[
i
].
S
tride
A
,
ALayout
{})));
gemm_
desc
s
[
i
].
M
_
,
gemm_
desc
s
[
i
].
K
_
,
gemm_
desc
s
[
i
].
s
tride
_A_
,
ALayout
{})));
b_tensors
.
push_back
(
Tensor
<
BDataType
>
(
f_host_tensor_descriptor
(
gemm_
shape
s
[
i
].
K
,
gemm_
shape
s
[
i
].
N
,
gemm_
shape
s
[
i
].
S
tride
B
,
BLayout
{})));
c_host_tensors
.
push_back
(
Tensor
<
C
DataType
>
(
f_host_tensor_descriptor
(
gemm_
shape
s
[
i
].
M
,
gemm_
shape
s
[
i
].
N
,
gemm_
shape
s
[
i
].
S
tride
C
,
C
Layout
{})));
c_device_tensors
.
push_back
(
Tensor
<
C
DataType
>
(
f_host_tensor_descriptor
(
gemm_
shape
s
[
i
].
M
,
gemm_
shape
s
[
i
].
N
,
gemm_
shape
s
[
i
].
S
tride
C
,
C
Layout
{})));
gemm_
desc
s
[
i
].
K
_
,
gemm_
desc
s
[
i
].
N
_
,
gemm_
desc
s
[
i
].
s
tride
_B_
,
BLayout
{})));
c_host_tensors
.
push_back
(
Tensor
<
E
DataType
>
(
f_host_tensor_descriptor
(
gemm_
desc
s
[
i
].
M
_
,
gemm_
desc
s
[
i
].
N
_
,
gemm_
desc
s
[
i
].
s
tride
_C_
,
E
Layout
{})));
c_device_tensors
.
push_back
(
Tensor
<
E
DataType
>
(
f_host_tensor_descriptor
(
gemm_
desc
s
[
i
].
M
_
,
gemm_
desc
s
[
i
].
N
_
,
gemm_
desc
s
[
i
].
s
tride
_C_
,
E
Layout
{})));
std
::
cout
<<
"gemm["
<<
i
<<
"] a_m_k: "
<<
a_tensors
[
i
].
mDesc
<<
" b_k_n: "
<<
b_tensors
[
i
].
mDesc
<<
" c_m_n: "
<<
c_device_tensors
[
i
].
mDesc
<<
std
::
endl
;
flop
+=
std
::
size_t
(
2
)
*
gemm_
shape
s
[
i
].
M
*
gemm_
shape
s
[
i
].
K
*
gemm_
shape
s
[
i
].
N
;
flop
+=
std
::
size_t
(
2
)
*
gemm_
desc
s
[
i
].
M
_
*
gemm_
desc
s
[
i
].
K
_
*
gemm_
desc
s
[
i
].
N
_
;
num_btype
+=
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSize
()
+
sizeof
(
BDataType
)
*
b_tensors
[
i
].
mDesc
.
GetElementSize
()
+
sizeof
(
C
DataType
)
*
c_device_tensors
[
i
].
mDesc
.
GetElementSize
();
sizeof
(
E
DataType
)
*
c_device_tensors
[
i
].
mDesc
.
GetElementSize
();
switch
(
init_method
)
{
...
...
@@ -168,14 +176,14 @@ int main(int argc, char* argv[])
}
}
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
shape
s
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
desc
s
.
size
();
i
++
)
{
a_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSpace
()));
b_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
BDataType
)
*
b_tensors
[
i
].
mDesc
.
GetElementSpace
()));
c_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
C
DataType
)
*
c_device_tensors
[
i
].
mDesc
.
GetElementSpace
()));
sizeof
(
E
DataType
)
*
c_device_tensors
[
i
].
mDesc
.
GetElementSpace
()));
a_tensors_device
[
i
]
->
ToDevice
(
a_tensors
[
i
].
mData
.
data
());
b_tensors_device
[
i
]
->
ToDevice
(
b_tensors
[
i
].
mData
.
data
());
...
...
@@ -187,14 +195,16 @@ int main(int argc, char* argv[])
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
c_element_op
=
CElementOp
{};
auto
c_element_op
=
C
DE
ElementOp
{};
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
std
::
vector
<
std
::
array
<
const
void
*
,
0
>>
p_Ds
=
{};
// do GEMM
auto
argument
=
gemm
.
MakeArgument
(
p_a
,
p_b
,
p_c
,
gemm_
shape
s
,
a_element_op
,
b_element_op
,
c_element_op
);
auto
argument
=
gemm
.
MakeArgument
(
p_a
,
p_b
,
p_Ds
,
p_c
,
gemm_
desc
s
,
a_element_op
,
b_element_op
,
c_element_op
);
DeviceMem
gemm_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
...
...
@@ -219,7 +229,7 @@ int main(int argc, char* argv[])
bool
pass
=
true
;
if
(
do_verification
)
{
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
shape
s
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
desc
s
.
size
();
i
++
)
{
c_tensors_device
[
i
]
->
FromDevice
(
c_device_tensors
[
i
].
mData
.
data
());
auto
ref_gemm
=
ReferenceGemmInstance
{};
...
...
example/28_grouped_gemm_bias/CMakeLists.txt
0 → 100644
View file @
926008bc
add_example_executable
(
example_grouped_gemm_bias_xdl_fp16 grouped_gemm_bias_xdl_fp16.cpp
)
example/28_grouped_gemm_bias/grouped_gemm_bias_xdl_fp16.cpp
0 → 100644
View file @
926008bc
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F16
;
using
D0DataType
=
F16
;
using
DsDataType
=
ck
::
Tuple
<
D0DataType
>
;
using
EDataType
=
F16
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
ELayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CDEElementOp
=
Add
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedGemmXdl
// clang-format off
//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ALayout
,
BLayout
,
ELayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
// clang-format on
int
main
(
int
argc
,
char
*
argv
[])
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
if
(
argc
==
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=n0, 1=yes)
\n
"
);
exit
(
0
);
}
int
group_count
=
rand
()
%
16
+
1
;
// GEMM shape
std
::
vector
<
ck
::
tensor_operation
::
device
::
GemmDesc
>
gemm_descs
;
std
::
vector
<
const
void
*>
p_a
,
p_b
;
std
::
vector
<
std
::
array
<
const
void
*
,
1
>>
p_ds
;
std
::
vector
<
void
*>
p_c
;
gemm_descs
.
reserve
(
group_count
);
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
int
M
=
256
+
256
*
i
;
int
N
=
128
+
128
*
i
;
int
K
=
64
+
64
*
i
;
int
stride_A
=
K
;
int
stride_B
=
K
;
int
stride_C
=
N
;
std
::
vector
<
ck
::
index_t
>
stride_Ds
=
{
0
};
gemm_descs
.
push_back
({
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
,
stride_Ds
});
}
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
std
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
stride
,
1
}));
}
else
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
1
,
stride
}));
}
};
std
::
vector
<
Tensor
<
ADataType
>>
a_tensors
;
std
::
vector
<
Tensor
<
BDataType
>>
b_tensors
;
std
::
vector
<
Tensor
<
D0DataType
>>
d0_tensors
;
std
::
vector
<
Tensor
<
EDataType
>>
e_host_tensors
;
std
::
vector
<
Tensor
<
EDataType
>>
e_device_tensors
;
a_tensors
.
reserve
(
group_count
);
b_tensors
.
reserve
(
group_count
);
d0_tensors
.
reserve
(
group_count
);
e_host_tensors
.
reserve
(
group_count
);
e_device_tensors
.
reserve
(
group_count
);
using
DeviceMemPtr
=
std
::
unique_ptr
<
DeviceMem
>
;
std
::
vector
<
DeviceMemPtr
>
a_tensors_device
,
b_tensors_device
,
d0_tensors_device
,
e_tensors_device
;
a_tensors_device
.
reserve
(
group_count
);
b_tensors_device
.
reserve
(
group_count
);
d0_tensors_device
.
reserve
(
group_count
);
e_tensors_device
.
reserve
(
group_count
);
std
::
size_t
flop
=
0
,
num_btype
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
gemm_descs
.
size
();
i
++
)
{
a_tensors
.
push_back
(
Tensor
<
ADataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
M_
,
gemm_descs
[
i
].
K_
,
gemm_descs
[
i
].
stride_A_
,
ALayout
{})));
b_tensors
.
push_back
(
Tensor
<
BDataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
K_
,
gemm_descs
[
i
].
N_
,
gemm_descs
[
i
].
stride_B_
,
BLayout
{})));
d0_tensors
.
push_back
(
Tensor
<
D0DataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
M_
,
gemm_descs
[
i
].
N_
,
gemm_descs
[
i
].
stride_Ds_
[
0
],
ELayout
{})));
e_host_tensors
.
push_back
(
Tensor
<
EDataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
M_
,
gemm_descs
[
i
].
N_
,
gemm_descs
[
i
].
stride_C_
,
ELayout
{})));
e_device_tensors
.
push_back
(
Tensor
<
EDataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
M_
,
gemm_descs
[
i
].
N_
,
gemm_descs
[
i
].
stride_C_
,
ELayout
{})));
std
::
cout
<<
"gemm["
<<
i
<<
"] a_m_k: "
<<
a_tensors
[
i
].
mDesc
<<
" b_k_n: "
<<
b_tensors
[
i
].
mDesc
<<
" c_m_n: "
<<
e_device_tensors
[
i
].
mDesc
<<
std
::
endl
;
flop
+=
std
::
size_t
(
2
)
*
gemm_descs
[
i
].
M_
*
gemm_descs
[
i
].
K_
*
gemm_descs
[
i
].
N_
;
num_btype
+=
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSize
()
+
sizeof
(
BDataType
)
*
b_tensors
[
i
].
mDesc
.
GetElementSize
()
+
sizeof
(
EDataType
)
*
e_device_tensors
[
i
].
mDesc
.
GetElementSize
();
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
d0_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
break
;
case
2
:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
d0_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
break
;
default:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
0
>
{});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
d0_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
0
>
{});
}
}
for
(
std
::
size_t
i
=
0
;
i
<
gemm_descs
.
size
();
i
++
)
{
a_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSpace
()));
b_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
BDataType
)
*
b_tensors
[
i
].
mDesc
.
GetElementSpace
()));
d0_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
D0DataType
)
*
d0_tensors
[
i
].
mDesc
.
GetElementSpace
()));
e_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
EDataType
)
*
e_device_tensors
[
i
].
mDesc
.
GetElementSpace
()));
a_tensors_device
[
i
]
->
ToDevice
(
a_tensors
[
i
].
mData
.
data
());
b_tensors_device
[
i
]
->
ToDevice
(
b_tensors
[
i
].
mData
.
data
());
d0_tensors_device
[
i
]
->
ToDevice
(
d0_tensors
[
i
].
mData
.
data
());
p_a
.
push_back
(
a_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_b
.
push_back
(
b_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_ds
.
push_back
({
d0_tensors_device
[
i
]
->
GetDeviceBuffer
()});
p_c
.
push_back
(
e_tensors_device
[
i
]
->
GetDeviceBuffer
());
}
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
cde_element_op
=
CDEElementOp
{};
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
// do GEMM
auto
argument
=
gemm
.
MakeArgument
(
p_a
,
p_b
,
p_ds
,
p_c
,
gemm_descs
,
a_element_op
,
b_element_op
,
cde_element_op
);
DeviceMem
gemm_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
gemm
.
SetWorkSpacePointer
(
&
argument
,
gemm_desc_workspace
.
GetDeviceBuffer
());
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
throw
std
::
runtime_error
(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"
);
}
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
bool
pass
=
true
;
if
(
do_verification
)
{
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
EDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
PassThrough
>
;
for
(
std
::
size_t
i
=
0
;
i
<
gemm_descs
.
size
();
i
++
)
{
e_tensors_device
[
i
]
->
FromDevice
(
e_device_tensors
[
i
].
mData
.
data
());
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_tensors
[
i
],
b_tensors
[
i
],
e_host_tensors
[
i
],
a_element_op
,
b_element_op
,
PassThrough
{});
ref_invoker
.
Run
(
ref_argument
);
for
(
int
m
=
0
;
m
<
gemm_descs
[
i
].
M_
;
++
m
)
{
for
(
int
n
=
0
;
n
<
gemm_descs
[
i
].
N_
;
++
n
)
{
cde_element_op
(
e_host_tensors
[
i
](
m
,
n
),
e_host_tensors
[
i
](
m
,
n
),
d0_tensors
[
i
](
m
,
n
));
}
}
pass
&=
ck
::
utils
::
check_err
(
e_device_tensors
[
i
].
mData
,
e_host_tensors
[
i
].
mData
);
}
}
return
pass
?
0
:
1
;
}
example/CMakeLists.txt
View file @
926008bc
...
...
@@ -46,4 +46,5 @@ add_subdirectory(24_batched_gemm_c_permute)
add_subdirectory
(
25_gemm_bias_c_permute
)
add_subdirectory
(
26_contraction
)
add_subdirectory
(
27_layernorm
)
add_subdirectory
(
28_grouped_gemm_bias
)
add_subdirectory
(
29_batched_gemm_multi_d
)
\ No newline at end of file
include/ck/tensor_operation/gpu/device/device_gemm.hpp
View file @
926008bc
...
...
@@ -12,12 +12,6 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
struct
GemmShape
{
ck
::
index_t
M
,
N
,
K
;
ck
::
index_t
StrideA
,
StrideB
,
StrideC
;
};
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
...
...
@@ -65,29 +59,6 @@ using DeviceGemmPtr = std::unique_ptr<DeviceGemm<ALayout,
BElementwiseOperation
,
CElementwiseOperation
>>
;
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGroupedGemm
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
vector
<
const
void
*>&
p_a
,
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
GemmShape
>&
gemm_shapes
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
ck
::
index_t
KBatch
=
1
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
using
DeviceGroupedGemmPtr
=
std
::
unique_ptr
<
DeviceGroupedGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
0 → 100644
View file @
926008bc
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
struct
GemmDesc
{
ck
::
index_t
M_
,
N_
,
K_
;
ck
::
index_t
stride_A_
,
stride_B_
,
stride_C_
;
std
::
vector
<
ck
::
index_t
>
stride_Ds_
;
};
template
<
typename
ALayout
,
typename
BLayout
,
typename
DELayout
,
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGroupedGemm
:
public
BaseOperator
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
vector
<
const
void
*>&
p_a
,
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
p_ds
,
std
::
vector
<
void
*>&
p_e
,
std
::
vector
<
GemmDesc
>&
gemm_desc
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
ALayout
,
typename
BLayout
,
typename
DELayout
,
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
using
DeviceGroupedGemmPtr
=
std
::
unique_ptr
<
DeviceGroupedGemm
<
ALayout
,
BLayout
,
DELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
View file @
926008bc
#pragma once
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
...
...
@@ -10,9 +11,9 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_
grouped_
gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_
xdlops_v2r3
.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_
multiple_d_xdl_cshuffle
.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
...
...
@@ -21,22 +22,20 @@ namespace tensor_operation {
namespace
device
{
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
typename
GemmDesc
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
C
DE
ElementwiseOperation
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_grouped_gemm_xdl
ops_v2r3
(
const
void
CK_CONSTANT_ADDRESS_SPACE
*
gemm_descs_const
,
kernel_grouped_gemm_xdl
(
const
void
CK_CONSTANT_ADDRESS_SPACE
*
gemm_descs_const
,
const
index_t
group_count
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
)
const
C
DE
ElementwiseOperation
c_element_op
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
...
@@ -65,42 +64,48 @@ __global__ void
}
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
gemm_desc_ptr
[
group_id
].
a_ptr
,
gemm_desc_ptr
[
group_id
].
b_ptr
,
gemm_desc_ptr
[
group_id
].
c_ptr
,
gemm_desc_ptr
[
group_id
].
a_ptr_
,
gemm_desc_ptr
[
group_id
].
b_ptr_
,
gemm_desc_ptr
[
group_id
].
ds_ptr_
,
gemm_desc_ptr
[
group_id
].
e_ptr_
,
p_shared
,
gemm_desc_ptr
[
group_id
].
a_grid_desc_k0_m_k1_
,
gemm_desc_ptr
[
group_id
].
b_grid_desc_k0_n_k1_
,
gemm_desc_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
a_element_op
,
b_element_op
,
c_element_op
,
gemm_desc_ptr
[
group_id
].
grouped_gemm_block_2_ctile_map_
);
gemm_desc_ptr
[
group_id
].
a_grid_desc_k0_m_k1_
,
gemm_desc_ptr
[
group_id
].
b_grid_desc_k0_n_k1_
,
gemm_desc_ptr
[
group_id
].
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
gemm_desc_ptr
[
group_id
].
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
gemm_desc_ptr
[
group_id
].
block_2_ctile_map_
);
#else
ignore
=
gemm_descs_const
;
ignore
=
group_count
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
#endif
// end of if (defined(__gfx908__) || defined(__gfx90a__))
#endif
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
ALayout
,
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
DELayout
,
typename
ADataType
,
typename
BDataType
,
typename
GemmAccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
C
DE
ElementwiseOperation
,
GemmSpecialization
GemmSpec
,
ck
::
index_t
NumPrefetch
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
K0PerBlock
,
ck
::
index_t
K1
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
AK1
,
ck
::
index_t
BK1
,
ck
::
index_t
MPerXDL
,
ck
::
index_t
NPerXDL
,
ck
::
index_t
MXdlPerWave
,
...
...
@@ -119,155 +124,319 @@ template <typename ADataType,
ck
::
index_t
BBlockTransferSrcScalarPerVector
,
ck
::
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BBlockLdsAddExtraN
,
ck
::
index_t
CThreadTransferSrcDstVectorDim
,
ck
::
index_t
CThreadTransferDstScalarPerVector
,
ck
::
index_t
NumPrefetch
=
1
,
ck
::
index_t
MaxGroupCount
=
10
>
struct
DeviceGroupedGemmXdl
:
public
DeviceGroupedGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceGroupedGemmXdl
:
public
DeviceGroupedGemm
<
ALayout
,
BLayout
,
DELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
static
auto
MakeAGridDescriptor_K0_M_K1
(
index_t
M
,
index_t
K
,
index_t
StrideA
)
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
{
assert
(
K
%
K1
==
0
);
const
index_t
K0
=
K
/
K1
;
const
auto
a_grid_desc_m_k
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
StrideA
,
I1
));
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
make_tuple
(
StrideA
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
::
value
)
else
if
constexpr
(
is_same
_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
I1
,
StrideA
));
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
make_tuple
(
I1
,
StrideA
));
}
}();
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
// pad both M and K
assert
(
K
%
AK1
==
0
);
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_right_pad_transform
(
M
,
PadM
)),
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
M
)),
// pad M, but not K
assert
(
KRaw
%
AK1
==
0
);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad K, but not M
assert
(
K
%
AK1
==
0
);
static
auto
MakeBGridDescriptor_K0_N_K1
(
index_t
K
,
index_t
N
,
index_t
StrideB
)
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
MRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
{
assert
(
K
%
K1
==
0
);
// not pad M or K
assert
(
KRaw
%
AK1
==
0
);
const
auto
AK0
=
KRaw
/
AK1
;
const
index_t
K0
=
K
/
K1
;
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
MRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
}
const
auto
b_grid_desc_k_n
=
[
&
]()
{
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
StrideB
,
I1
));
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
I1
,
StrideB
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
I1
,
StrideB
));
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
StrideB
,
I1
));
}
}();
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
NPad
=
N
-
NRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
const
auto
PadN
=
(
NPerBlock
-
N
%
NPerBlock
)
%
NPerBlock
;
// pad both N and K
assert
(
K
%
BK1
==
0
);
return
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_right_pad_transform
(
N
,
PadN
)),
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_right_pad_transform
(
NRaw
,
NPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
return
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
N
)),
// pad N, but not K
assert
(
KRaw
%
BK1
==
0
);
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad K, but not N
assert
(
K
%
BK1
==
0
);
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_pass_through_transform
(
NRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
{
// not pad N or K
assert
(
KRaw
%
BK1
==
0
);
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
}
static
auto
Make
C
GridDescriptor_M_N
(
index_t
M
,
index_t
N
,
index_t
Stride
C
)
static
auto
Make
E
GridDescriptor_M_N
(
index_t
M
Raw
,
index_t
N
Raw
,
index_t
Stride
E
)
{
const
auto
c_grid_desc_m
_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
C
Layout
>::
value
)
const
auto
c_grid_desc_m
raw_nraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
DE
Layout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
StrideC
,
I1
));
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
NRaw
),
make_tuple
(
StrideE
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
C
Layout
>::
value
)
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
DE
Layout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
StrideC
));
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
NRaw
),
make_tuple
(
I1
,
StrideE
));
}
}();
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
const
auto
PadN
=
(
NPerBlock
-
N
%
NPerBlock
)
%
NPerBlock
;
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
NPad
=
N
-
NRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad M and N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad M, but not N
return
transform_tensor_descriptor
(
c_grid_desc_m
_n
,
make_tuple
(
make_right_pad_transform
(
M
,
Pad
M
),
make_
right_pad
_transform
(
N
,
PadN
)),
c_grid_desc_m
raw_nraw
,
make_tuple
(
make_right_pad_transform
(
M
Raw
,
M
Pad
),
make_
pass_through
_transform
(
N
Raw
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad N, but not M
return
transform_tensor_descriptor
(
c_grid_desc_m
_n
,
make_tuple
(
make_pass_through_transform
(
M
),
make_
pass_through
_transform
(
N
)),
c_grid_desc_m
raw_nraw
,
make_tuple
(
make_pass_through_transform
(
M
Raw
),
make_
right_pad
_transform
(
N
Raw
,
NPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
// not pad M or N
return
c_grid_desc_mraw_nraw
;
}
}
using
AGridDesc_K0_M_K1
=
decltype
(
MakeAGridDescriptor_K0_M_K1
(
1
,
1
,
1
));
using
BGridDesc_K0_N_K1
=
decltype
(
MakeBGridDescriptor_K0_N_K1
(
1
,
1
,
1
));
using
C
GridDesc_M_N
=
decltype
(
Make
C
GridDescriptor_M_N
(
1
,
1
,
1
));
using
AGridDesc_
A
K0_M_
A
K1
=
decltype
(
MakeAGridDescriptor_
A
K0_M_
A
K1
(
1
,
1
,
1
));
using
BGridDesc_
B
K0_N_
B
K1
=
decltype
(
MakeBGridDescriptor_
B
K0_N_
B
K1
(
1
,
1
,
1
));
using
E
GridDesc_M_N
=
decltype
(
Make
E
GridDescriptor_M_N
(
1
,
1
,
1
));
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<
BlockSize
,
using
GridwiseGemm
=
GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
AccDataType
,
CDataType
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
GemmAccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
CDEElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
EGridDesc_M_N
,
NumPrefetch
,
// NumGemmKPrefetchStage
BlockSize
,
MPerBlock
,
NPerBlock
,
K0PerBlock
,
KPerBlock
,
AK1
,
BK1
,
MPerXDL
,
NPerXDL
,
K1
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
...
...
@@ -286,30 +455,28 @@ struct DeviceGroupedGemmXdl
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN
,
Sequence
<
0
,
2
,
4
,
5
,
6
,
1
,
3
,
7
>
,
// CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
NumPrefetch
>
;
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopSched
>
;
struct
GroupedGemmBlock2
C
TileMap
struct
GroupedGemmBlock2
E
TileMap
{
using
UnderlyingBlock2CTileMap
=
typename
GridwiseGemm
::
DefaultBlock2
C
TileMap
;
using
UnderlyingBlock2CTileMap
=
typename
GridwiseGemm
::
DefaultBlock2
E
TileMap
;
static_assert
(
std
::
is_same
<
decltype
(
GridwiseGemm
::
MakeDefaultBlock2
C
TileMap
(
C
GridDesc_M_N
{}
,
1
,
1
)),
typename
GridwiseGemm
::
DefaultBlock2
C
TileMap
>::
value
,
std
::
is_same
<
decltype
(
GridwiseGemm
::
MakeDefaultBlock2
E
TileMap
(
E
GridDesc_M_N
{})),
typename
GridwiseGemm
::
DefaultBlock2
E
TileMap
>::
value
,
"Wrong! Should be the same type name"
);
GroupedGemmBlock2
C
TileMap
()
GroupedGemmBlock2
E
TileMap
()
{
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2
C
TileMap
(
C
GridDesc_M_N
{}
,
1
,
1
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2
E
TileMap
(
E
GridDesc_M_N
{});
BlockStart_
=
-
1
;
}
GroupedGemmBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
,
index_t
N01
,
ck
::
index_t
BlockStart
)
GroupedGemmBlock2ETileMap
(
const
EGridDesc_M_N
&
c_grid_desc_m_n
,
ck
::
index_t
BlockStart
)
{
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2
C
TileMap
(
c_grid_desc_m_n
,
M01
,
N01
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2
E
TileMap
(
c_grid_desc_m_n
);
BlockStart_
=
BlockStart
;
}
...
...
@@ -327,29 +494,35 @@ struct DeviceGroupedGemmXdl
return
block_2_ctile_map_
.
ValidCTileIndex
(
c_tile_idx
,
c_tile_dim
);
}
__host__
bool
CheckValidity
(
const
C
GridDesc_M_N
&
c_grid_desc_m_n
)
const
__host__
bool
CheckValidity
(
const
E
GridDesc_M_N
&
c_grid_desc_m_n
)
const
{
return
block_2_ctile_map_
.
CheckValidity
(
c_grid_desc_m_n
);
}
typename
GridwiseGemm
::
DefaultBlock2
C
TileMap
block_2_ctile_map_
;
typename
GridwiseGemm
::
DefaultBlock2
E
TileMap
block_2_ctile_map_
;
ck
::
index_t
BlockStart_
;
};
struct
Gemm
Desc
KernelArg
struct
Gemm
BiasTrans
KernelArg
{
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
C
GridDesc_M_N
c
_grid_desc_m_n_
;
AGridDesc_
A
K0_M_
A
K1
a_grid_desc_k0_m_k1_
;
BGridDesc_
B
K0_N_
B
K1
b_grid_desc_k0_n_k1_
;
E
GridDesc_M_N
e
_grid_desc_m_n_
;
typename
GridwiseGemm
::
C
GridDesc
_M0_N0_M1_N1_M2_M3_M4_N2
c
_grid_desc_m
0_n0_m1_n1_m2_m3_m4_n2
_
;
typename
GridwiseGemm
::
E
GridDesc
riptor_MBlock_MPerBlock_NBlock_NPerBlock
e
_grid_desc_m
block_mperblock_nblock_nperblock
_
;
GroupedGemmBlock2CTileMap
grouped_gemm_block_2_ctile_map_
;
StaticallyIndexedArray
<
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
NumDTensor
>
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
// FIXME: Ds desc may be of different
const
ADataType
*
a_ptr
;
const
BDataType
*
b_ptr
;
CDataType
*
c_ptr
;
GroupedGemmBlock2ETileMap
block_2_ctile_map_
;
const
ADataType
*
a_ptr_
;
const
BDataType
*
b_ptr_
;
typename
GridwiseGemm
::
DsGridPointer
ds_ptr_
;
EDataType
*
e_ptr_
;
ck
::
index_t
BlockStart_
,
BlockEnd_
;
};
...
...
@@ -357,82 +530,99 @@ struct DeviceGroupedGemmXdl
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
std
::
vector
<
const
void
*>&
p_a
,
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
GemmShape
>&
gemm_shapes
,
index_t
M01
,
index_t
N01
,
Argument
(
std
::
vector
<
const
void
*>&
p_As
,
std
::
vector
<
const
void
*>&
p_Bs
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
p_Ds
,
std
::
vector
<
void
*>&
p_Es
,
std
::
vector
<
GemmDesc
>&
gemm_descs
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
:
M01_
{
M01
},
N01_
{
N01
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
}
CDEElementwiseOperation
c_element_op
)
:
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
}
{
grid_size_
=
0
;
p_workspace_
=
nullptr
;
group_count_
=
ck
::
type_convert
<
ck
::
index_t
>
(
gemm_shapes
.
size
());
group_count_
=
ck
::
type_convert
<
ck
::
index_t
>
(
gemm_descs
.
size
());
if
(
!
(
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_
a
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_
b
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_
c
.
size
())))
if
(
!
(
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_
As
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_
Bs
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_
Es
.
size
())))
{
throw
std
::
runtime_error
(
"wrong! group_count_ !=
P_a
/b/c.size"
);
throw
std
::
runtime_error
(
"wrong! group_count_ !=
p_As
/b/c.size"
);
}
gemm_desc_kernel_arg_
.
reserve
(
group_count_
);
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
shape
s
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
desc
s
.
size
();
i
++
)
{
const
index_t
M
=
gemm_
shape
s
[
i
].
M
;
const
index_t
N
=
gemm_
shape
s
[
i
].
N
;
const
index_t
K
=
gemm_
shape
s
[
i
].
K
;
const
index_t
M
=
gemm_
desc
s
[
i
].
M
_
;
const
index_t
N
=
gemm_
desc
s
[
i
].
N
_
;
const
index_t
K
=
gemm_
desc
s
[
i
].
K
_
;
const
index_t
StrideA
=
gemm_
shape
s
[
i
].
S
tride
A
;
const
index_t
StrideB
=
gemm_
shape
s
[
i
].
S
tride
B
;
const
index_t
StrideC
=
gemm_
shape
s
[
i
].
S
tride
C
;
const
index_t
StrideA
=
gemm_
desc
s
[
i
].
s
tride
_A_
;
const
index_t
StrideB
=
gemm_
desc
s
[
i
].
s
tride
_B_
;
const
index_t
StrideC
=
gemm_
desc
s
[
i
].
s
tride
_C_
;
const
auto
a_grid_desc_k0_m_k1_
=
DeviceGroupedGemmXdl
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
DeviceGroupedGemmXdl
::
MakeAGridDescriptor_
A
K0_M_
A
K1
(
M
,
K
,
StrideA
);
const
auto
b_grid_desc_k0_n_k1_
=
DeviceGroupedGemmXdl
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
);
const
auto
c_grid_desc_m_n_
=
DeviceGroupedGemmXdl
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
DeviceGroupedGemmXdl
::
MakeBGridDescriptor_BK0_N_BK1
(
K
,
N
,
StrideB
);
const
auto
e_grid_desc_m_n_
=
DeviceGroupedGemmXdl
::
MakeEGridDescriptor_M_N
(
M
,
N
,
StrideC
);
const
index_t
grid_size_grp
=
GroupedGemmBlock2
C
TileMap
(
c
_grid_desc_m_n_
,
M01
,
N01
,
0
)
.
block_2_ctile_map_
.
CalculateGridSize
(
c
_grid_desc_m_n_
);
GroupedGemmBlock2
E
TileMap
(
e
_grid_desc_m_n_
,
0
)
.
block_2_ctile_map_
.
CalculateGridSize
(
e
_grid_desc_m_n_
);
const
index_t
BlockStart
=
grid_size_
;
const
index_t
BlockEnd
=
grid_size_
+
grid_size_grp
;
grid_size_
+=
grid_size_grp
;
const
auto
grouped_gemm_
block_2_ctile_map_
=
GroupedGemmBlock2
C
TileMap
(
c
_grid_desc_m_n_
,
M01
,
N01
,
BlockStart
);
const
auto
block_2_ctile_map_
=
GroupedGemmBlock2
E
TileMap
(
e
_grid_desc_m_n_
,
BlockStart
);
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c
_grid_desc_m_n_
,
grouped_gemm_
block_2_ctile_map_
))
e
_grid_desc_m_n_
,
block_2_ctile_map_
))
{
const
auto
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n_
);
auto
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
StaticallyIndexedArray
<
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
NumDTensor
>
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
// FIXME: Ds desc may be of
// different
typename
GridwiseGemm
::
DsGridPointer
p_ds_grid_
{};
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
j
)
{
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
j
.
value
,
DsDataType
>>
;
p_ds_grid_
(
j
)
=
static_cast
<
const
DDataType
*>
(
p_Ds
[
i
][
j
]);
const
auto
d_grid_desc_m_n
=
DeviceGroupedGemmXdl
::
MakeEGridDescriptor_M_N
(
M
,
N
,
gemm_descs
[
i
].
stride_Ds_
[
j
]);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
(
j
)
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
d_grid_desc_m_n
);
});
gemm_desc_kernel_arg_
.
push_back
(
Gemm
Desc
KernelArg
{
a_grid_desc_k0_m_k1_
,
Gemm
BiasTrans
KernelArg
{
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m_n_
,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
grouped_gemm_block_2_ctile_map_
,
static_cast
<
const
ADataType
*>
(
p_a
[
i
]),
static_cast
<
const
BDataType
*>
(
p_b
[
i
]),
static_cast
<
CDataType
*>
(
p_c
[
i
]),
e_grid_desc_m_n_
,
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
block_2_ctile_map_
,
static_cast
<
const
ADataType
*>
(
p_As
[
i
]),
static_cast
<
const
BDataType
*>
(
p_Bs
[
i
]),
p_ds_grid_
,
static_cast
<
EDataType
*>
(
p_Es
[
i
]),
BlockStart
,
BlockEnd
});
}
...
...
@@ -440,14 +630,12 @@ struct DeviceGroupedGemmXdl
}
// private:
index_t
M01_
;
index_t
N01_
;
index_t
group_count_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
C
DE
ElementwiseOperation
c_element_op_
;
std
::
vector
<
Gemm
Desc
KernelArg
>
gemm_desc_kernel_arg_
;
std
::
vector
<
Gemm
BiasTrans
KernelArg
>
gemm_desc_kernel_arg_
;
index_t
grid_size_
;
};
...
...
@@ -473,16 +661,15 @@ struct DeviceGroupedGemmXdl
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I2
)
<<
"}"
;
std
::
cout
<<
", arg.
c
_grid_desc_m_n_{ "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
c
_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
c
_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
std
::
cout
<<
", arg.
e
_grid_desc_m_n_{ "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
e
_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
e
_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_k0_m_k1_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_k0_m_k1_
,
arg
.
gemm_desc_kernel_arg_
[
i
].
b_grid_desc_k0_n_k1_
,
arg
.
gemm_desc_kernel_arg_
[
i
].
c
_grid_desc_m_n_
,
arg
.
gemm_desc_kernel_arg_
[
i
].
grouped_gemm_
block_2_ctile_map_
))
arg
.
gemm_desc_kernel_arg_
[
i
].
e
_grid_desc_m_n_
,
arg
.
gemm_desc_kernel_arg_
[
i
].
block_2_ctile_map_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"
);
...
...
@@ -500,25 +687,21 @@ struct DeviceGroupedGemmXdl
hipGetErrorString
(
hipMemcpy
(
arg
.
p_workspace_
,
arg
.
gemm_desc_kernel_arg_
.
data
(),
arg
.
gemm_desc_kernel_arg_
.
size
()
*
sizeof
(
Gemm
Desc
KernelArg
),
arg
.
gemm_desc_kernel_arg_
.
size
()
*
sizeof
(
Gemm
BiasTrans
KernelArg
),
hipMemcpyHostToDevice
));
float
ave_time
=
0
;
if
(
has_main_k_block_loop
)
{
const
auto
kernel
=
kernel_grouped_gemm_xdlops_v2r3
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
GemmDescKernelArg
,
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
const
auto
kernel
=
kernel_grouped_gemm_xdl
<
GridwiseGemm
,
GemmBiasTransKernelArg
,
AElementwiseOperation
,
BElementwiseOperation
,
C
ElementwiseOperation
,
true
>
;
CDE
ElementwiseOperation
,
has_main_k_block_loop_
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
arg
.
grid_size_
),
dim3
(
BlockSize
),
...
...
@@ -528,30 +711,15 @@ struct DeviceGroupedGemmXdl
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
);
};
if
(
has_main_k_block_loop
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
const
auto
kernel
=
kernel_grouped_gemm_xdlops_v2r3
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
GemmDescKernelArg
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
false
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
arg
.
grid_size_
),
dim3
(
BlockSize
),
0
,
cast_pointer_to_constant_address_space
(
arg
.
p_workspace_
),
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
);
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
return
ave_time
;
...
...
@@ -585,31 +753,34 @@ struct DeviceGroupedGemmXdl
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
std
::
vector
<
const
void
*>&
p_a
,
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
GemmShape
>
gemm_shapes
,
static
auto
MakeArgument
(
std
::
vector
<
const
void
*>&
p_As
,
std
::
vector
<
const
void
*>&
p_Bs
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
p_Ds
,
std
::
vector
<
void
*>&
p_Es
,
std
::
vector
<
GemmDesc
>
gemm_descs
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
C
DE
ElementwiseOperation
c_element_op
)
{
return
Argument
{
p_a
,
p_b
,
p_c
,
gemm_shapes
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
};
return
Argument
{
p_As
,
p_Bs
,
p_Ds
,
p_Es
,
gemm_descs
,
a_element_op
,
b_element_op
,
c_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
vector
<
const
void
*>&
p_a
,
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
GemmShape
>&
gemm_shapes
,
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
vector
<
const
void
*>&
p_As
,
std
::
vector
<
const
void
*>&
p_Bs
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
p_Ds
,
std
::
vector
<
void
*>&
p_Es
,
std
::
vector
<
GemmDesc
>&
gemm_descs
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
index_t
/* KBatch */
=
1
)
override
CDEElementwiseOperation
c_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
p_
a
,
p_
b
,
p_
c
,
gemm_shapes
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
);
p_
As
,
p_
Bs
,
p_
Ds
,
p_Es
,
gemm_descs
,
a_element_op
,
b_element_op
,
c_element_op
);
}
// polymorphic
...
...
@@ -629,8 +800,9 @@ struct DeviceGroupedGemmXdl
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K0PerBlock
<<
", "
<<
K1
<<
", "
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
", "
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
...
...
@@ -643,7 +815,7 @@ struct DeviceGroupedGemmXdl
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
p_arg
)
const
override
{
return
dynamic_cast
<
const
Argument
*>
(
p_arg
)
->
group_count_
*
sizeof
(
Gemm
Desc
KernelArg
);
return
dynamic_cast
<
const
Argument
*>
(
p_arg
)
->
group_count_
*
sizeof
(
Gemm
BiasTrans
KernelArg
);
}
};
...
...
library/include/ck/library/host_tensor/host_tensor.hpp
View file @
926008bc
...
...
@@ -381,52 +381,3 @@ HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens,
:
mLens
(
lens
.
begin
(),
lens
.
end
()),
mStrides
(
strides
.
begin
(),
strides
.
end
())
{
}
#if 1
// FIXME: remove
template
<
typename
T
>
float
check_error
(
const
Tensor
<
T
>&
ref
,
const
Tensor
<
T
>&
result
)
{
float
l1_error
=
0
;
float
linf_error
=
-
1
;
float
linf_rel_error
=
-
1
;
float
linf_ref_value
=
0
,
linf_result_value
=
0
;
float
linf_rel_ref_value
=
0
,
linf_rel_result_value
=
0
;
constexpr
float
eps
=
1e-10
;
for
(
std
::
size_t
i
=
0
;
i
<
ref
.
mData
.
size
();
++
i
)
{
float
ref_v
=
ck
::
type_convert
<
float
>
(
ref
.
mData
[
i
]);
float
result_v
=
ck
::
type_convert
<
float
>
(
result
.
mData
[
i
]);
float
diff
=
std
::
abs
(
ref_v
-
result_v
);
float
rel_diff
=
diff
/
std
::
max
(
std
::
abs
(
ref_v
),
eps
);
l1_error
+=
diff
;
if
(
linf_error
<
diff
)
{
linf_error
=
diff
;
linf_ref_value
=
ref_v
;
linf_result_value
=
result_v
;
}
if
(
linf_rel_error
<
rel_diff
)
{
linf_rel_error
=
rel_diff
;
linf_rel_ref_value
=
ref_v
;
linf_rel_result_value
=
result_v
;
}
}
std
::
cout
<<
"Absolute Error L1 Norm (sum of abs diff): "
<<
l1_error
<<
std
::
endl
;
std
::
cout
<<
"Absolute Error L-inf Norm (max abs diff): "
<<
linf_error
<<
", ref "
<<
linf_ref_value
<<
", result "
<<
linf_result_value
<<
std
::
endl
;
std
::
cout
<<
"Relative Error L-inf Norm (max relative abs diff): "
<<
linf_rel_error
<<
", ref "
<<
linf_rel_ref_value
<<
", result "
<<
linf_rel_result_value
<<
std
::
endl
;
return
linf_error
;
}
#endif
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp
0 → 100644
View file @
926008bc
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
DsType
=
Tuple
<>
;
void
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemm
<
Row
,
Row
,
Row
,
F16
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemm
<
Row
,
Col
,
Row
,
F16
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemm
<
Col
,
Row
,
Row
,
F16
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemm
<
Col
,
Col
,
Row
,
F16
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
EDataType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGroupedGemm
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>>
{
using
DeviceOp
=
DeviceGroupedGemm
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
EDataType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/utility/check_err.hpp
View file @
926008bc
...
...
@@ -29,9 +29,8 @@ check_err(const std::vector<T>& out,
{
if
(
out
.
size
()
!=
ref
.
size
())
{
std
::
cout
<<
"out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
<<
std
::
endl
<<
msg
<<
std
::
endl
;
std
::
cout
<<
msg
<<
" out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
<<
std
::
endl
;
return
false
;
}
...
...
@@ -48,9 +47,8 @@ check_err(const std::vector<T>& out,
err_count
++
;
if
(
err_count
<
5
)
{
std
::
cout
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
out
[
i
]
<<
" != "
<<
ref
[
i
]
<<
std
::
endl
<<
msg
<<
std
::
endl
;
std
::
cout
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
out
[
i
]
<<
" != "
<<
ref
[
i
]
<<
std
::
endl
;
}
res
=
false
;
}
...
...
@@ -72,9 +70,8 @@ check_err(const std::vector<T>& out,
{
if
(
out
.
size
()
!=
ref
.
size
())
{
std
::
cout
<<
"out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
<<
std
::
endl
<<
msg
<<
std
::
endl
;
std
::
cout
<<
msg
<<
" out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
<<
std
::
endl
;
return
false
;
}
...
...
@@ -94,9 +91,8 @@ check_err(const std::vector<T>& out,
err_count
++
;
if
(
err_count
<
5
)
{
std
::
cout
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
<<
msg
<<
std
::
endl
;
std
::
cout
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
}
res
=
false
;
}
...
...
@@ -118,9 +114,8 @@ check_err(const std::vector<T>& out,
{
if
(
out
.
size
()
!=
ref
.
size
())
{
std
::
cout
<<
"out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
<<
std
::
endl
<<
msg
<<
std
::
endl
;
std
::
cout
<<
msg
<<
" out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
<<
std
::
endl
;
return
false
;
}
...
...
@@ -139,9 +134,8 @@ check_err(const std::vector<T>& out,
err_count
++
;
if
(
err_count
<
5
)
{
std
::
cout
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
<<
msg
<<
std
::
endl
;
std
::
cout
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
}
res
=
false
;
}
...
...
@@ -163,9 +157,8 @@ check_err(const std::vector<T>& out,
{
if
(
out
.
size
()
!=
ref
.
size
())
{
std
::
cout
<<
"out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
<<
std
::
endl
<<
msg
<<
std
::
endl
;
std
::
cout
<<
msg
<<
" out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
<<
std
::
endl
;
return
false
;
}
...
...
@@ -185,9 +178,9 @@ check_err(const std::vector<T>& out,
err_count
++
;
if
(
err_count
<
5
)
{
std
::
cout
<<
"
out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
static_cast
<
int
>
(
out
[
i
])
<<
"
!=
"
<<
static_cast
<
int
>
(
ref
[
i
])
<<
std
::
endl
<<
msg
<<
std
::
endl
;
std
::
cout
<<
msg
<<
"
out["
<<
i
<<
"] != ref["
<<
i
<<
"
]:
"
<<
static_cast
<
int
>
(
out
[
i
])
<<
" != "
<<
static_cast
<
int
>
(
ref
[
i
])
<<
std
::
endl
;
}
res
=
false
;
}
...
...
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp
View file @
926008bc
...
...
@@ -23,6 +23,8 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
DsType
=
ck
::
Tuple
<>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
...
...
@@ -30,23 +32,40 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using
device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances
=
std
::
tuple
<
// clang-format off
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
//##################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
2
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
2
,
2
,
32
,
32
,
2
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
2
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
2
,
2
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
2
,
2
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
void
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGroupedGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemm
<
Col
,
Row
,
Row
,
F16
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances
{});
...
...
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp
View file @
926008bc
...
...
@@ -23,30 +23,48 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
DsType
=
ck
::
Tuple
<>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using
device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
=
std
::
tuple
<
// clang-format off
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
//##################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
2
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
2
,
8
,
32
,
32
,
2
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
2
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
2
,
8
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
2
,
8
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
void
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGroupedGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemm
<
Col
,
Col
,
Row
,
F16
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
{});
...
...
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
View file @
926008bc
...
...
@@ -23,6 +23,8 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
DsType
=
ck
::
Tuple
<>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
...
...
@@ -30,32 +32,40 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using
device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
=
std
::
tuple
<
// clang-format off
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
32
,
256
,
4
,
8
,
32
,
32
,
1
,
4
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
32
,
32
,
4
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
16
,
256
,
4
,
8
,
16
,
16
,
1
,
8
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
16
,
128
,
4
,
8
,
16
,
16
,
1
,
4
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
16
,
64
,
4
,
8
,
16
,
16
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
16
,
32
,
4
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
16
,
16
,
4
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
7
,
1
>
//##################| ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| DsType| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmXdl
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
2
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
2
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
2
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
void
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGroupedGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemm
<
Row
,
Row
,
Row
,
F16
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
{});
...
...
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
View file @
926008bc
...
...
@@ -23,53 +23,49 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
DsType
=
Tuple
<>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmMNPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using
device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
=
std
::
tuple
<
// clang-format off
//##################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
// clang-format on
>
;
// irregular tile size
using
device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances
=
std
::
tuple
<
// clang-format off
//##################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
128
,
144
,
8
,
8
,
16
,
16
,
2
,
9
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
8
,
4
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
true
,
7
,
1
>
,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
128
,
144
,
4
,
8
,
16
,
16
,
2
,
9
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
true
,
7
,
1
>
//##################| ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmXdl
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
128
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdl
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
// clang-format on
>
;
void
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGroupedGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemm
<
Row
,
Col
,
Row
,
F16
,
F16
,
DsType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
{});
add_device_operation_instances
(
instances
,
device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances
{});
}
}
// namespace instance
...
...
profiler/include/profile_batched_gemm_reduce_impl.hpp
View file @
926008bc
...
...
@@ -318,13 +318,16 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
reduce0_device_buf
.
FromDevice
(
d0_g_m_device_result
.
mData
.
data
());
reduce1_device_buf
.
FromDevice
(
d1_g_m_device_result
.
mData
.
data
());
float
c_error
=
check_error
(
c_g_m_n_host_result
,
c_g_m_n_device_result
);
float
d0_error
=
check_error
(
d0_g_m_host_result
,
d0_g_m_device_result
);
float
d1_error
=
check_error
(
d1_g_m_host_result
,
d1_g_m_device_result
);
pass
=
pass
&&
(
c_error
<
1E-6
);
pass
=
pass
&&
(
d0_error
<
1E-6
);
pass
=
pass
&&
(
d1_error
<
1E-6
);
bool
c_error
=
ck
::
utils
::
check_err
(
c_g_m_n_host_result
.
mData
,
c_g_m_n_device_result
.
mData
);
bool
d0_error
=
ck
::
utils
::
check_err
(
d0_g_m_host_result
.
mData
,
d0_g_m_device_result
.
mData
);
bool
d1_error
=
ck
::
utils
::
check_err
(
d1_g_m_host_result
.
mData
,
d1_g_m_device_result
.
mData
);
pass
=
pass
&&
(
c_error
==
true
);
pass
=
pass
&&
(
d0_error
==
true
);
pass
=
pass
&&
(
d1_error
==
true
);
if
(
do_log
)
{
...
...
profiler/include/profile_conv_bwd_weight_impl.hpp
View file @
926008bc
...
...
@@ -250,11 +250,11 @@ bool profile_conv_bwd_weight_impl(int do_verification,
{
wei_device_buf
.
FromDevice
(
wei_k_c_y_x_device_result
.
mData
.
data
());
float
max_error
=
check_error
(
wei_k_c_y_x_host_result
,
wei_k_c_y_x_device_result
);
pass
=
ck
::
utils
::
check_err
(
wei_k_c_y_x_host_result
.
mData
,
wei_k_c_y_x_device_result
.
mData
);
if
(
max_error
>
8
)
if
(
pass
==
false
)
{
pass
=
false
;
std
::
cout
<<
"Fail info:"
<<
conv_ptr
->
GetTypeString
()
<<
std
::
endl
;
}
...
...
profiler/include/profile_convnd_bwd_data_impl.hpp
View file @
926008bc
...
...
@@ -8,6 +8,7 @@
#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/conv_util.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
...
...
@@ -452,7 +453,7 @@ bool profile_convnd_bwd_data_impl(int do_verification,
std
::
cout
<<
"Pass Info: "
<<
conv_ptr
->
GetTypeString
()
<<
std
::
endl
;
}
check_err
or
(
input_host_result
,
input_device_result
);
success
=
ck
::
utils
::
check_err
(
input_host_result
.
mData
,
input_device_result
.
mData
);
if
(
do_log
)
{
...
...
profiler/include/profile_convnd_bwd_weight_impl.hpp
View file @
926008bc
...
...
@@ -433,21 +433,17 @@ bool profile_convnd_bwd_weight_impl(int do_verification,
{
wei_device_buf
.
FromDevice
(
weights_device_result
.
mData
.
data
());
float
max_error
=
check_err
or
(
weights_host_result
,
weights_device_result
);
success
=
ck
::
utils
::
check_err
(
weights_host_result
.
mData
,
weights_device_result
.
mData
);
if
(
max_error
>
8
)
if
(
success
==
false
)
{
std
::
cout
<<
"Fail Info: "
<<
conv_ptr
->
GetTypeString
()
<<
std
::
endl
;
success
=
false
;
}
else
{
std
::
cout
<<
"Pass Info: "
<<
conv_ptr
->
GetTypeString
()
<<
std
::
endl
;
}
check_error
(
weights_host_result
,
weights_device_result
);
if
(
do_log
)
{
std
::
cout
<<
"in : "
;
...
...
profiler/include/profile_grouped_gemm_impl.hpp
View file @
926008bc
...
...
@@ -7,9 +7,11 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_
grouped_
gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/conv_util.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
...
...
@@ -17,41 +19,17 @@
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
DeviceGroupedGemmNoOpPtr
=
ck
::
tensor_operation
::
device
::
DeviceGroupedGemmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
void
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGroupedGemmNoOpPtr
>&
);
void
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGroupedGemmNoOpPtr
>&
);
void
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGroupedGemmNoOpPtr
>&
);
void
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGroupedGemmNoOpPtr
>&
);
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
namespace
ck
{
namespace
profiler
{
template
<
typename
ADataType
,
typename
BDataType
,
typename
C
DataType
,
typename
E
DataType
,
typename
AccDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
void
profile_grouped_gemm_impl
(
int
do_verification
,
bool
profile_grouped_gemm_impl
(
int
do_verification
,
int
init_method
,
bool
do_log
,
bool
time_kernel
,
...
...
@@ -62,6 +40,9 @@ void profile_grouped_gemm_impl(int do_verification,
const
std
::
vector
<
int
>&
StrideBs
,
const
std
::
vector
<
int
>&
StrideCs
)
{
bool
pass
=
true
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
is_same
<
decltype
(
layout
),
tensor_layout
::
gemm
::
RowMajor
>::
value
)
...
...
@@ -86,7 +67,7 @@ void profile_grouped_gemm_impl(int do_verification,
std
::
vector
<
Tensor
<
ADataType
>>
a_m_k
;
std
::
vector
<
Tensor
<
BDataType
>>
b_k_n
;
std
::
vector
<
Tensor
<
C
DataType
>>
c_m_n_device_results
;
std
::
vector
<
Tensor
<
E
DataType
>>
c_m_n_device_results
;
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
...
...
@@ -96,7 +77,7 @@ void profile_grouped_gemm_impl(int do_verification,
Tensor
<
BDataType
>
(
f_host_tensor_descriptor
(
Ks
[
i
],
Ns
[
i
],
StrideBs
[
i
],
BLayout
{})));
c_m_n_device_results
.
push_back
(
Tensor
<
C
DataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{})));
Tensor
<
E
DataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{})));
std
::
cout
<<
"group: "
<<
i
<<
" a_m_k["
<<
i
<<
"]:"
<<
a_m_k
[
i
].
mDesc
<<
", b_k_n["
<<
i
<<
"]:"
<<
b_k_n
[
i
].
mDesc
<<
", c_m_n_device_results["
<<
i
...
...
@@ -115,7 +96,7 @@ void profile_grouped_gemm_impl(int do_verification,
b_k_n
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
},
num_thread
);
}
c_m_n_device_results
[
i
].
GenerateTensorValue
(
GeneratorTensor_0
<
C
DataType
>
{},
num_thread
);
c_m_n_device_results
[
i
].
GenerateTensorValue
(
GeneratorTensor_0
<
E
DataType
>
{},
num_thread
);
}
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
...
...
@@ -145,9 +126,9 @@ void profile_grouped_gemm_impl(int do_verification,
p_b
.
reserve
(
group_count
);
p_c
.
reserve
(
group_count
);
std
::
vector
<
ck
::
tensor_operation
::
device
::
Gemm
Shape
>
gemm_
shape
s
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
Gemm
Desc
>
gemm_
desc
s
;
gemm_
shape
s
.
reserve
(
group_count
);
gemm_
desc
s
.
reserve
(
group_count
);
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
...
...
@@ -157,56 +138,34 @@ void profile_grouped_gemm_impl(int do_verification,
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
BDataType
)
*
b_k_n
[
i
].
mDesc
.
GetElementSpace
()));
c_device_buf
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
C
DataType
)
*
c_m_n_device_results
[
i
].
mDesc
.
GetElementSpace
()));
sizeof
(
E
DataType
)
*
c_m_n_device_results
[
i
].
mDesc
.
GetElementSpace
()));
a_device_buf
[
i
]
->
ToDevice
(
a_m_k
[
i
].
mData
.
data
());
b_device_buf
[
i
]
->
ToDevice
(
b_k_n
[
i
].
mData
.
data
());
c_device_buf
[
i
]
->
ToDevice
(
c_m_n_device_results
[
i
].
mData
.
data
());
gemm_
shape
s
.
push_back
({
Ms
[
i
],
Ns
[
i
],
Ks
[
i
],
StrideAs
[
i
],
StrideBs
[
i
],
StrideCs
[
i
]});
gemm_
desc
s
.
push_back
({
Ms
[
i
],
Ns
[
i
],
Ks
[
i
],
StrideAs
[
i
],
StrideBs
[
i
],
StrideCs
[
i
]
,
{}
});
p_a
.
push_back
(
a_device_buf
[
i
]
->
GetDeviceBuffer
());
p_b
.
push_back
(
b_device_buf
[
i
]
->
GetDeviceBuffer
());
p_c
.
push_back
(
c_device_buf
[
i
]
->
GetDeviceBuffer
());
}
// add device GEMM instances
std
::
vector
<
ck
::
tensor_operation
::
device
::
instance
::
DeviceGroupedGemmNoOpPtr
>
gemm_ptrs
;
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedGemm
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
ck
::
Tuple
<>
,
EDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
if
constexpr
(
is_same
<
ADataType
,
half_t
>::
value
&&
is_same
<
BDataType
,
half_t
>::
value
&&
is_same
<
CDataType
,
half_t
>::
value
)
{
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
instance
::
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
instance
::
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
instance
::
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
instance
::
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
gemm_ptrs
);
}
}
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
if
(
gemm
_ptrs
.
size
()
<=
0
)
if
(
op
_ptrs
.
size
()
<=
0
)
{
throw
std
::
runtime_error
(
"wrong! no device GEMM instance found"
);
}
...
...
@@ -216,14 +175,17 @@ void profile_grouped_gemm_impl(int do_verification,
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
auto
p_ds
=
std
::
vector
<
std
::
array
<
const
void
*
,
0
>>
{};
// profile device GEMM instances
for
(
auto
&
gemm_ptr
:
gemm
_ptrs
)
for
(
auto
&
gemm_ptr
:
op
_ptrs
)
{
auto
argument_ptr
=
gemm_ptr
->
MakeArgumentPointer
(
p_a
,
p_b
,
p_ds
,
p_c
,
gemm_
shape
s
,
gemm_
desc
s
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
...
...
@@ -242,12 +204,12 @@ void profile_grouped_gemm_impl(int do_verification,
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
0
,
num_btype
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
shape
s
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
desc
s
.
size
();
i
++
)
{
flop
+=
std
::
size_t
(
2
)
*
Ms
[
i
]
*
Ns
[
i
]
*
Ks
[
i
];
num_btype
+=
sizeof
(
ADataType
)
*
Ms
[
i
]
*
Ks
[
i
]
+
sizeof
(
BDataType
)
*
Ks
[
i
]
*
Ns
[
i
]
+
sizeof
(
C
DataType
)
*
Ms
[
i
]
*
Ns
[
i
];
sizeof
(
E
DataType
)
*
Ms
[
i
]
*
Ns
[
i
];
}
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
...
...
@@ -266,18 +228,18 @@ void profile_grouped_gemm_impl(int do_verification,
if
(
do_verification
)
{
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
shape
s
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
desc
s
.
size
();
i
++
)
{
c_device_buf
[
i
]
->
FromDevice
(
c_m_n_device_results
[
i
].
mData
.
data
());
Tensor
<
C
DataType
>
c_m_n_host_result
(
Tensor
<
E
DataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{}));
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
C
DataType
,
E
DataType
,
AccDataType
,
AElementOp
,
BElementOp
,
...
...
@@ -294,7 +256,8 @@ void profile_grouped_gemm_impl(int do_verification,
c_element_op
);
ref_invoker
.
Run
(
ref_argument
);
ck
::
utils
::
check_err
(
c_m_n_device_results
[
i
].
mData
,
c_m_n_host_result
.
mData
);
pass
=
pass
&&
ck
::
utils
::
check_err
(
c_m_n_device_results
[
i
].
mData
,
c_m_n_host_result
.
mData
);
if
(
do_log
)
{
...
...
@@ -319,6 +282,8 @@ void profile_grouped_gemm_impl(int do_verification,
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_gemm_name
<<
std
::
endl
;
return
pass
;
}
// namespace profiler
}
// namespace profiler
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment