gaoqiong / composable_kernel_ROCM · Commits
Commit a4fe62ed, authored Sep 25, 2024 by Mirza Halilcevic

    Merge remote-tracking branch 'upstream/develop' into ck_migraphx_integration

Parents: 08255e1b, 3528a523
Changes: showing 20 of 39 changed files, with 1272 additions and 343 deletions (+1272 / −343).
- Jenkinsfile (+92 / −19)
- docs/sphinx/requirements.in (+1 / −1)
- docs/sphinx/requirements.txt (+1 / −1)
- example/ck_tile/03_gemm/gemm_basic.cpp (+5 / −5)
- include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp (+42 / −185)
- include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp (+354 / −70)
- include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp (+326 / −50)
- include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp (+17 / −0)
- include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp (+236 / −0)
- include/ck/utility/reduction_operator.hpp (+1 / −1)
- include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp (+2 / −2)
- include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp (+1 / −1)
- include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp (+21 / −5)
- include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp (+4 / −0)
- include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp (+3 / −3)
- library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp (+34 / −0)
- library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc (+33 / −0)
- library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc (+33 / −0)
- library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc (+33 / −0)
- library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc (+33 / −0)
Jenkinsfile (view file @ a4fe62ed)

@@ -100,7 +100,15 @@ def getDockerImage(Map conf=[:]){
         dockerArgs = dockerArgs + " --no-cache "
     }
     echo "Docker Args: ${dockerArgs}"
-    def image = getDockerImageName()
+    def image
+    if (params.BUILD_LEGACY_OS && conf.get("docker_name", "") != ""){
+        image = conf.get("docker_name", "")
+        echo "Using legacy docker: ${image}"
+    }
+    else{
+        image = getDockerImageName()
+        echo "Using default docker: ${image}"
+    }
     //Check if image exists
     def retimage
     try

@@ -125,7 +133,9 @@ def buildDocker(install_prefix){
    def image_name = getDockerImageName()
    echo "Building Docker for ${image_name}"
    def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' --build-arg DISABLE_CACHE='git rev-parse ${params.COMPILER_VERSION}' "
+    if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
+        dockerArgs = dockerArgs + " --no-cache "
+    }
    echo "Build Args: ${dockerArgs}"
    try{
        if(params.BUILD_DOCKER){

@@ -259,6 +269,7 @@ def cmake_build(Map conf=[:]){
        """)
        sh cmd3
    }
    // reduce parallelism when compiling, clang uses too much memory
    def nt = nthreads()
    def cmd

@@ -273,7 +284,7 @@ def cmake_build(Map conf=[:]){
    }
    else{
        setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ")
-        build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j${nt} ${config_targets}")
+        build_cmd = conf.get("build_cmd", "${build_envs} make -j${nt} ${config_targets}")
    }
    cmd = conf.get("cmd", """
            ${setup_cmd}

@@ -292,8 +303,8 @@ def cmake_build(Map conf=[:]){
    dir("build"){
        //build CK
        sh cmd
-        //run tests
-        if (!setup_args.contains("NO_CK_BUILD")){
+        //run tests except when NO_CK_BUILD or BUILD_LEGACY_OS are set
+        if (!setup_args.contains("NO_CK_BUILD") && !params.BUILD_LEGACY_OS){
            if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){
                sh "/ninjatracing/ninjatracing .ninja_log > ck_build_trace.json"
                archiveArtifacts "ck_build_trace.json"

@@ -330,7 +341,15 @@ def buildHipClangJob(Map conf=[:]){
        env.HSA_ENABLE_SDMA = 0
        checkout scm
-        def image = getDockerImageName()
+        def image
+        if (params.BUILD_LEGACY_OS && conf.get("docker_name", "") != ""){
+            image = conf.get("docker_name", "")
+            echo "Using legacy docker: ${image}"
+        }
+        else{
+            image = getDockerImageName()
+            echo "Using default docker: ${image}"
+        }
        def prefixpath = conf.get("prefixpath", "/opt/rocm")
        // Jenkins is complaining about the render group

@@ -512,7 +531,16 @@ def Build_CK(Map conf=[:]){
        env.DOCKER_BUILDKIT = 1
        checkout scm
-        def image = getDockerImageName()
+        def image
+        if (params.BUILD_LEGACY_OS && conf.get("docker_name", "") != ""){
+            image = conf.get("docker_name", "")
+            echo "Using legacy docker: ${image}"
+        }
+        else{
+            image = getDockerImageName()
+            echo "Using default docker: ${image}"
+        }
        def prefixpath = conf.get("prefixpath", "/opt/rocm")
        // Jenkins is complaining about the render group

@@ -524,6 +552,9 @@ def Build_CK(Map conf=[:]){
        if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
            dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
        }
+        if (params.BUILD_LEGACY_OS){
+            dockerOpts = dockerOpts + " --env LD_LIBRARY_PATH='/opt/Python-3.8.13/lib' "
+        }
        def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3')
        def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3')
        dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} "

@@ -707,7 +738,8 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCM
0 21 * * * % ROCMVERSION=6.2;hipTensor_test=true
0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
-0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_CODEGEN_TESTS=false;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false''' : ""
+0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_CODEGEN_TESTS=false;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false
+0 13 * * * % BUILD_LEGACY_OS=true ''' : ""

pipeline {
    agent none

@@ -794,6 +826,10 @@ pipeline {
            name: "NINJA_BUILD_TRACE",
            defaultValue: false,
            description: "Generate a ninja build trace (default: OFF)")
+        booleanParam(
+            name: "BUILD_LEGACY_OS",
+            defaultValue: false,
+            description: "Try building CK with legacy OS dockers: RHEL8 and SLES15 (default: OFF)")
    }
    environment{
        dbuser = "${dbuser}"

@@ -946,7 +982,6 @@ pipeline {
        {
            parallel
            {
                stage("Run CK_TILE_GEMM Tests on gfx90a")
                {
                    when {

@@ -965,7 +1000,6 @@ pipeline {
                        buildHipClangJobAndReboot(setup_args: setup_args, no_reboot: true, build_type: 'Release', execute_cmd: execute_args)
                        cleanWs()
                    }
                }
                stage("Run CK_TILE_GEMM Tests on gfx942")
                {

@@ -988,15 +1022,54 @@ pipeline {
                    }
                }
        }
        stage("Build CK and run Tests")
        {
            parallel
            {
+                stage("Build CK with RHEL8")
+                {
+                    when {
+                        beforeAgent true
+                        expression { params.BUILD_LEGACY_OS.toBoolean() }
+                    }
+                    agent{ label rocmnode("gfx90a") }
+                    environment{
+                        def docker_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_rhel8_rocm6.3"
+                        setup_args = """ -DGPU_TARGETS="gfx942" \
+                                         -DCMAKE_CXX_FLAGS=" -O3 " \
+                                         -DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """
+                        execute_args = " "
+                    }
+                    steps{
+                        Build_CK_and_Reboot(setup_args: setup_args, config_targets: " ", no_reboot: true, build_type: 'Release', docker_name: docker_name)
+                        cleanWs()
+                    }
+                }
+                stage("Build CK with SLES15")
+                {
+                    when {
+                        beforeAgent true
+                        expression { params.BUILD_LEGACY_OS.toBoolean() }
+                    }
+                    agent{ label rocmnode("gfx90a") }
+                    environment{
+                        def docker_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_sles15_rocm6.3"
+                        setup_args = """ -DGPU_TARGETS="gfx942" \
+                                         -DCMAKE_CXX_FLAGS=" -O3 " \
+                                         -DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """
+                        execute_args = " "
+                    }
+                    steps{
+                        Build_CK_and_Reboot(setup_args: setup_args, config_targets: " ", no_reboot: true, build_type: 'Release', docker_name: docker_name)
+                        cleanWs()
+                    }
+                }
                stage("Build CK for all gfx9 targets")
                {
                    when {
                        beforeAgent true
-                        expression { params.RUN_FULL_QA.toBoolean() }
+                        expression { params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
                    }
                    agent{ label rocmnode("gfx90a") }
                    environment{

@@ -1018,7 +1091,7 @@ pipeline {
                {
                    when {
                        beforeAgent true
-                        expression { params.RUN_FULL_QA.toBoolean() }
+                        expression { params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
                    }
                    agent{ label rocmnode("gfx942") }
                    environment{

@@ -1038,7 +1111,7 @@ pipeline {
                {
                    when {
                        beforeAgent true
-                        expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
+                        expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
                    }
                    agent{ label rocmnode("gfx90a") }
                    environment{

@@ -1058,7 +1131,7 @@ pipeline {
                {
                    when {
                        beforeAgent true
-                        expression { params.BUILD_INSTANCES_ONLY.toBoolean() && !params.RUN_FULL_QA.toBoolean() }
+                        expression { params.BUILD_INSTANCES_ONLY.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
                    }
                    agent{ label rocmnode("gfx90a") }
                    environment{

@@ -1077,7 +1150,7 @@ pipeline {
                {
                    when {
                        beforeAgent true
-                        expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
+                        expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
                    }
                    agent{ label rocmnode("gfx1030") }
                    environment{

@@ -1097,7 +1170,7 @@ pipeline {
                {
                    when {
                        beforeAgent true
-                        expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
+                        expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
                    }
                    agent{ label rocmnode("gfx1101") }
                    environment{

@@ -1117,7 +1190,7 @@ pipeline {
                {
                    when {
                        beforeAgent true
-                        expression { params.BUILD_GFX12.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
+                        expression { params.BUILD_GFX12.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
                    }
                    agent{ label rocmnode("gfx1201") }
                    environment{

@@ -1144,7 +1217,7 @@ pipeline {
                {
                    when {
                        beforeAgent true
-                        expression { params.RUN_PERFORMANCE_TESTS.toBoolean() }
+                        expression { params.RUN_PERFORMANCE_TESTS.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
                    }
                    options { retry(1) }
                    agent{ label rocmnode("gfx90a")}

@@ -1165,7 +1238,7 @@ pipeline {
        stage("Process results"){
            when {
                beforeAgent true
-                expression { params.RUN_PERFORMANCE_TESTS.toBoolean() }
+                expression { params.RUN_PERFORMANCE_TESTS.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
            }
            agent { label 'mici' }
            steps{
...
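The BUILD_LEGACY_OS plumbing above repeats one selection rule in three functions (getDockerImage, buildHipClangJob, Build_CK). The sketch below restates that rule in C++ for clarity (the pipeline itself is Groovy); select_image and its parameters are hypothetical names introduced only for illustration.

```cpp
// Minimal sketch of the image-selection rule this diff adds: prefer a per-stage
// docker_name override when BUILD_LEGACY_OS is set, else the default image.
#include <iostream>
#include <string>

std::string select_image(bool build_legacy_os,
                         const std::string& docker_name,   // conf.get("docker_name", "")
                         const std::string& default_image) // getDockerImageName()
{
    if(build_legacy_os && !docker_name.empty())
        return docker_name; // legacy-OS image, e.g. the RHEL8 or SLES15 docker
    return default_image;   // regular pipeline docker
}

int main()
{
    std::cout << select_image(true, "registry/ck_rhel8_rocm6.3", "registry/ck_default")
              << '\n'; // prints the legacy image
}
```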
docs/sphinx/requirements.in (view file @ a4fe62ed)

-rocm-docs-core==1.8.0
+rocm-docs-core==1.8.1
 sphinxcontrib-bibtex==2.6.3
docs/sphinx/requirements.txt (view file @ a4fe62ed)

@@ -103,7 +103,7 @@ requests==2.32.3
     # via
     #   pygithub
     #   sphinx
-rocm-docs-core==1.8.0
+rocm-docs-core==1.8.1
     # via -r requirements.in
 six==1.16.0
     # via pybtex
...
example/ck_tile/03_gemm/gemm_basic.cpp (view file @ a4fe62ed)

@@ -179,9 +179,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_buf,
     std::cout << "The overall perfomance of the GEMM with "
               << "[" << data_type << "]"
-              << "batch size: " << batch_size << ". m:" << M << ",n:" << N << ", k:" << K
-              << "is: \n";
-    std::cout << "Running time :" << ave_time << "ms, Throughput" << gb_per_sec << "GB/s \n"
+              << "batch size: " << batch_size << ". m:" << M << ", n:" << N << ", k:" << K
+              << " is: \n";
+    std::cout << "Running time: " << ave_time << "ms, Throughput " << gb_per_sec << "GB/s \n"
               << std::flush;

     return ave_time;

@@ -235,7 +235,7 @@ int main(int argc, char* argv[])
     // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
     constexpr bool kPadA = true;
     constexpr bool kPadB = true;
-    constexpr bool kPadC = false;
+    constexpr bool kPadC = true;

     // This part comes from the Codegen
     constexpr ck_tile::index_t M_Tile = 128;

@@ -348,7 +348,7 @@ int main(int argc, char* argv[])
         pass_gpu = ck_tile::check_err(c_host_dev, c_host_gpu_ref);
-        std::cout << "The GPU veification result is:" << (pass_gpu ? "correct" : "fail")
+        std::cout << "The GPU veification result is: " << (pass_gpu ? "correct" : "fail")
                   << std::flush;
     }
...
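The kPadC flip above enables padding on the output dimension. As a rough intuition for what the kPad* flags buy, here is a minimal, self-contained C++ sketch of rounding a dimension up to whole tiles; pad_to_multiple is an illustrative helper, not CK's API.

```cpp
// Each enabled kPad* flag lets the kernel handle a dimension that is not a
// whole multiple of its tile size by treating it as padded up.
#include <cstdint>
#include <iostream>

std::int64_t pad_to_multiple(std::int64_t len, std::int64_t tile)
{
    return ((len + tile - 1) / tile) * tile; // round up to whole tiles
}

int main()
{
    // e.g. N = 1000 with N_Tile = 128: without padding the last tile would
    // step out of bounds; with padding the dimension is treated as 1024.
    std::cout << pad_to_multiple(1000, 128) << '\n'; // prints 1024
}
```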
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp (view file @ a4fe62ed)

@@ -15,6 +15,7 @@
 #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
 #include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp"
 #include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp"
+#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
 #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"

@@ -22,7 +23,6 @@
 #include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
 #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"

@@ -257,6 +257,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                                KPerBlock / K1Number,
                                ConvBackwardWeightSpecialization>{};

+    static constexpr index_t ClusterLengthMPerBlock =
+        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
+    static constexpr index_t ClusterLengthNPerBlock =
+        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
+
+    static constexpr auto conv_ngchw_to_nhwgc_transformer =
+        TransformConvNGCHWToNHWGC<InLayout,
+                                  WeiLayout,
+                                  OutLayout,
+                                  NDimSpatial,
+                                  MPerBlock / ClusterLengthMPerBlock,
+                                  NPerBlock / ClusterLengthNPerBlock>{};
+
     static constexpr GemmSpecialization GemmSpec = GemmSpecialization::Default;

     template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>

@@ -359,141 +372,12 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
             batch)[I2];
     }

-    static constexpr index_t ClusterLengthMPerBlock =
-        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
-    static constexpr index_t ClusterLengthNPerBlock =
-        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
-
-    template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
-    static auto MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
-                                       std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
-    {
-        const index_t& G  = g_n_c_wis_lengths[0];
-        const index_t& N  = g_n_c_wis_lengths[1];
-        const index_t& C  = g_n_c_wis_lengths[2];
-        const index_t& Hi = g_n_c_wis_lengths[3];
-        const index_t& Wi = g_n_c_wis_lengths[4];
-
-        const index_t& GStride  = g_n_c_wis_strides[0];
-        const index_t& NStride  = g_n_c_wis_strides[1];
-        const index_t& CStride  = g_n_c_wis_strides[2];
-        const index_t& HiStride = g_n_c_wis_strides[3];
-        const index_t& WiStride = g_n_c_wis_strides[4];
-
-        const auto desc = make_naive_tensor_descriptor(
-            make_tuple(N, G, C, Hi, Wi),
-            make_tuple(NStride, GStride, CStride, HiStride, WiStride));
-        const auto merged_desc = transform_tensor_descriptor(
-            desc,
-            make_tuple(make_merge_transform(make_tuple(N, G, C)),
-                       make_merge_transform(make_tuple(Hi, Wi))),
-            make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
-        return PadTensorDescriptor(
-            merged_desc,
-            make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
-            Sequence<true, true>{});
-    }
-
-    template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
-    static auto MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
-                                        std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
-    {
-        const index_t& G  = g_n_c_wis_lengths[0];
-        const index_t& N  = g_n_c_wis_lengths[1];
-        const index_t& C  = g_n_c_wis_lengths[2];
-        const index_t& Hi = g_n_c_wis_lengths[3];
-        const index_t& Wi = g_n_c_wis_lengths[4];
-
-        const index_t& NStride = g_n_c_wis_strides[1];
-        const index_t HiStride = Wi * G * C;
-        const index_t WiStride = G * C;
-        const index_t GStride  = C;
-        const index_t CStride  = 1;
-
-        const auto desc = make_naive_tensor_descriptor(
-            make_tuple(N, G, C, Hi, Wi),
-            make_tuple(NStride, GStride, CStride, HiStride, WiStride));
-        const auto merged_desc = transform_tensor_descriptor(
-            desc,
-            make_tuple(make_merge_transform(make_tuple(N, G, C)),
-                       make_merge_transform(make_tuple(Hi, Wi))),
-            make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
-        return PadTensorDescriptor(
-            merged_desc,
-            make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
-            Sequence<true, true>{});
-    }
-
-    template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
-    static auto MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
-                                       std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
-    {
-        const index_t& G  = g_n_c_wis_lengths[0];
-        const index_t& N  = g_n_c_wis_lengths[1];
-        const index_t& C  = g_n_c_wis_lengths[2];
-        const index_t& Di = g_n_c_wis_lengths[3];
-        const index_t& Hi = g_n_c_wis_lengths[4];
-        const index_t& Wi = g_n_c_wis_lengths[5];
-
-        const index_t& GStride  = g_n_c_wis_strides[0];
-        const index_t& NStride  = g_n_c_wis_strides[1];
-        const index_t& CStride  = g_n_c_wis_strides[2];
-        const index_t& DiStride = g_n_c_wis_strides[3];
-        const index_t& HiStride = g_n_c_wis_strides[4];
-        const index_t& WiStride = g_n_c_wis_strides[5];
-
-        const auto desc = make_naive_tensor_descriptor(
-            make_tuple(N, G, C, Di, Hi, Wi),
-            make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
-        const auto merged_desc = transform_tensor_descriptor(
-            desc,
-            make_tuple(make_merge_transform(make_tuple(N, G, C)),
-                       make_merge_transform(make_tuple(Di, Hi, Wi))),
-            make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
-        return PadTensorDescriptor(
-            merged_desc,
-            make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
-            Sequence<true, true>{});
-    }
-
-    template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
-    static auto MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
-                                        std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
-    {
-        const index_t& G  = g_n_c_wis_lengths[0];
-        const index_t& N  = g_n_c_wis_lengths[1];
-        const index_t& C  = g_n_c_wis_lengths[2];
-        const index_t& Di = g_n_c_wis_lengths[3];
-        const index_t& Hi = g_n_c_wis_lengths[4];
-        const index_t& Wi = g_n_c_wis_lengths[5];
-
-        const index_t& NStride = g_n_c_wis_strides[1];
-        const index_t DiStride = Hi * Wi * G * C;
-        const index_t HiStride = Wi * G * C;
-        const index_t WiStride = G * C;
-        const index_t GStride  = C;
-        const index_t CStride  = 1;
-
-        const auto desc = make_naive_tensor_descriptor(
-            make_tuple(N, G, C, Di, Hi, Wi),
-            make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
-        const auto merged_desc = transform_tensor_descriptor(
-            desc,
-            make_tuple(make_merge_transform(make_tuple(N, G, C)),
-                       make_merge_transform(make_tuple(Di, Hi, Wi))),
-            make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
-        return PadTensorDescriptor(
-            merged_desc,
-            make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
-            Sequence<true, true>{});
-    }
-
-    using InputTransposeDescType =
-        remove_cvref_t<decltype(MakeInputTransposeDesc<NDimSpatial>({}, {}))>;
-    using OutputTransposeDescType =
-        remove_cvref_t<decltype(MakeOutputTransposeDesc<NDimSpatial>({}, {}))>;
+    using NGCHWTransposeDescType =
+        remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
+                                    .template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
+    using NHWGCTransposeDescType =
+        remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
+                                    .template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;

     using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());

@@ -572,8 +456,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                                                                   I1>;

     using GridwiseElementwiseTranspose =
-        GridwiseElementwise<Tuple<InputTransposeDescType>,
-                            Tuple<OutputTransposeDescType>,
+        GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
+                            Tuple<NHWGCTransposeDescType>,
                             Tuple<const ADataType*>,
                             Tuple<ADataType*>,
                             Block2TileMapElementwise,

@@ -652,43 +536,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                       begin(output_spatial_lengths_));

             std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
-                b_g_n_c_wis_strides;
+                conv_ngchw_to_nhwgc_transformer.TransposeStrides(b_g_n_c_wis_lengths,
+                                                                 b_g_n_c_wis_strides);
             std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
-                a_g_n_k_wos_strides;
-            // NGKHW - transpose needed
-            if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
-                         is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
-            {
-                b_g_n_c_wis_strides_transposed[I0] = Conv_C_;
-                b_g_n_c_wis_strides_transposed[I2] = I1;
-                a_g_n_k_wos_strides_transposed[I0] = Conv_K_;
-                a_g_n_k_wos_strides_transposed[I2] = I1;
-                if constexpr(NDimSpatial == 2)
-                {
-                    b_g_n_c_wis_strides_transposed[I3] =
-                        input_spatial_lengths_[I1] * Conv_G_ * Conv_C_;
-                    b_g_n_c_wis_strides_transposed[I4] = Conv_G_ * Conv_C_;
-                    a_g_n_k_wos_strides_transposed[I3] =
-                        output_spatial_lengths_[I1] * Conv_G_ * Conv_K_;
-                    a_g_n_k_wos_strides_transposed[I4] = Conv_G_ * Conv_K_;
-                }
-                else if constexpr(NDimSpatial == 3)
-                {
-                    b_g_n_c_wis_strides_transposed[I3] =
-                        input_spatial_lengths_[I1] * input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
-                    b_g_n_c_wis_strides_transposed[I4] =
-                        input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
-                    b_g_n_c_wis_strides_transposed[I5] = Conv_G_ * Conv_C_;
-                    a_g_n_k_wos_strides_transposed[I3] =
-                        output_spatial_lengths_[I1] * input_spatial_lengths_[I2] * Conv_G_ * Conv_K_;
-                    a_g_n_k_wos_strides_transposed[I4] =
-                        input_spatial_lengths_[I2] * Conv_G_ * Conv_K_;
-                    a_g_n_k_wos_strides_transposed[I5] = Conv_G_ * Conv_K_;
-                }
-            }
+                conv_ngchw_to_nhwgc_transformer.TransposeStrides(a_g_n_k_wos_lengths,
+                                                                 a_g_n_k_wos_strides);

             const auto descs =
                 conv_to_gemm_transformer_v2

@@ -755,14 +607,18 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                      is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
         {
             a_in_transpose_desc_ =
-                MakeInputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
+                conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
+                    a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
             a_out_transpose_desc_ =
-                MakeOutputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
+                conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
+                    a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
             b_in_transpose_desc_ =
-                MakeInputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
+                conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
+                    b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
             b_out_transpose_desc_ =
-                MakeOutputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
+                conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
+                    b_g_n_c_wis_lengths, b_g_n_c_wis_strides);

             elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
                 a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};

@@ -816,8 +672,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
         Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
             elementwise_block_2_ctile_map_transpose_b_;

-        InputTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
-        OutputTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;
+        NGCHWTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
+        NHWGCTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;

         // for computing batch offset
         ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_;

@@ -1569,13 +1425,14 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                 (arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) /
                 sizeof(BDataType);

+            // Different data type for A and B is not supported
             auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseTranspose,
-                                                            ck::Tuple<InputTransposeDescType>,
-                                                            ck::Tuple<InputTransposeDescType>,
-                                                            ck::Tuple<OutputTransposeDescType>,
-                                                            ck::Tuple<OutputTransposeDescType>,
+                                                            ck::Tuple<NGCHWTransposeDescType>,
+                                                            ck::Tuple<NGCHWTransposeDescType>,
+                                                            ck::Tuple<NHWGCTransposeDescType>,
+                                                            ck::Tuple<NHWGCTransposeDescType>,
                                                             ck::Tuple<const ADataType*>,
                                                             ck::Tuple<const ADataType*>,
-                                                            ck::Tuple<BDataType*>,
+                                                            ck::Tuple<ADataType*>,
                                                             Block2TileMapElementwise,
                                                             Block2TileMapElementwise,
                                                             element_wise::PassThrough>;
...
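Most of this file's churn replaces hand-written Make{Input,Output}TransposeDesc helpers with the shared TransformConvNGCHWToNHWGC transformer. For intuition, the following hedged host-side sketch performs the same NGCHW to NHWGC relayout with plain index arithmetic; it is not CK's descriptor-based implementation.

```cpp
// N = batch, G = groups, C = channels per group, H/W = spatial dims.
// After the relayout, G and C become the fastest-varying dimensions.
#include <cstddef>
#include <vector>

std::vector<float> ngchw_to_nhwgc(const std::vector<float>& in, std::size_t N,
                                  std::size_t G, std::size_t C, std::size_t H,
                                  std::size_t W)
{
    std::vector<float> out(in.size());
    for(std::size_t n = 0; n < N; ++n)
        for(std::size_t g = 0; g < G; ++g)
            for(std::size_t c = 0; c < C; ++c)
                for(std::size_t h = 0; h < H; ++h)
                    for(std::size_t w = 0; w < W; ++w)
                    {
                        const std::size_t src = (((n * G + g) * C + c) * H + h) * W + w;
                        const std::size_t dst = (((n * H + h) * W + w) * G + g) * C + c;
                        out[dst] = in[src];
                    }
    return out;
}
```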
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp (view file @ a4fe62ed)

@@ -15,9 +15,11 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
 #include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
+#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
 #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"

@@ -307,6 +309,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
     static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
     static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;

+    // NGCHW is not supported for multiAB
+    static_assert(!(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
+                    is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>()) ||
+                  !(isMultiA || isMultiB));
+
     static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
     static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
     static constexpr index_t NumDTensor = DsDataType::Size();

@@ -315,6 +322,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};
+    static constexpr auto I4 = Number<4>{};
+    static constexpr auto I5 = Number<5>{};

     using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
                                                             ConvForwardSpecialization,

@@ -323,14 +332,33 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                                                             EDataType,
                                                             NumGroupsToMerge>;

+    static constexpr index_t ClusterLengthNPerBlock =
+        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
+
+    static constexpr auto conv_ngchw_to_nhwgc_transformer =
+        TransformConvNGCHWToNHWGC<ALayout,
+                                  BLayout,
+                                  ELayout,
+                                  NDimSpatial,
+                                  NPerBlock / ClusterLengthNPerBlock,
+                                  NPerBlock / ClusterLengthNPerBlock>{};
+
     static constexpr auto matrix_padder =
         MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};

     template <typename ALay>
     static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
     {
+        namespace ctc = tensor_layout::convolution;
+        using Layout = std::conditional_t<
+            is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
+            ctc::NHWGC,
+            std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
+                               ctc::NDHWGC,
+                               ALay>>;
+
         const auto in_gemmmraw_gemmkraw_desc =
-            conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
+            conv_to_gemm_transformer.template MakeADescriptor_M_K<Layout>();

         const auto in_gemmm_gemmk_desc =
             matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);

@@ -353,8 +381,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
     template <typename ELay>
     static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
     {
+        namespace ctc = tensor_layout::convolution;
+        using Layout = std::conditional_t<
+            is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
+            ctc::NHWGK,
+            std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
+                               ctc::NDHWGK,
+                               ELay>>;
+
         const auto out_gemmmraw_gemmnraw_desc =
-            conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
+            conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();

         const auto out_gemmm_gemmn_desc =
             matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
...
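The two hunks above remap the descriptor layout at compile time: for NGCHW/NGCDHW user layouts, the GEMM descriptors are built for NHWGC/NDHWGC (and NHWGK/NDHWGK on the output side), because that is what the runtime transpose produces in the workspace. Below is a minimal sketch of that std::conditional_t pattern; the tag types are stand-ins, not CK's.

```cpp
#include <type_traits>

struct NGCHW {};           // user-facing layout that needs a runtime transpose
struct NHWGC {};           // layout the transpose produces in the workspace
struct SomeOtherLayout {}; // anything else passes through unchanged

template <typename UserLayout>
using DescriptorLayout =
    std::conditional_t<std::is_same_v<UserLayout, NGCHW>, NHWGC, UserLayout>;

static_assert(std::is_same_v<DescriptorLayout<NGCHW>, NHWGC>);
static_assert(std::is_same_v<DescriptorLayout<SomeOtherLayout>, SomeOtherLayout>);
```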
@@ -442,6 +478,52 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -442,6 +478,52 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// block-to-e-tile map
// block-to-e-tile map
using
Block2ETileMap
=
using
Block2ETileMap
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
EGridDesc_M_N
{}))
>
;
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
EGridDesc_M_N
{}))
>
;
using
Block2TileMapElementwise
=
BlockToCTileMap_M00_N0_M01Adapt
<
NPerBlock
,
NPerBlock
>
;
using
NGCHWTransposeDescType
=
remove_cvref_t
<
decltype
(
conv_ngchw_to_nhwgc_transformer
.
template
MakeNGCHWTransposeDesc
<
NDimSpatial
>({},
{}))
>
;
using
NHWGCTransposeDescType
=
remove_cvref_t
<
decltype
(
conv_ngchw_to_nhwgc_transformer
.
template
MakeNHWGCTransposeDesc
<
NDimSpatial
>({},
{}))
>
;
static
constexpr
index_t
ElementwiseBlocksize
=
ClusterLengthNPerBlock
*
ClusterLengthNPerBlock
;
using
GridwiseElementwiseInputTranspose
=
GridwiseElementwise
<
Tuple
<
NGCHWTransposeDescType
>
,
Tuple
<
NHWGCTransposeDescType
>
,
Tuple
<
const
ADataType
*>
,
Tuple
<
ADataType
*>
,
Block2TileMapElementwise
,
element_wise
::
PassThrough
,
ElementwiseBlocksize
,
NPerBlock
,
NPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
,
Sequence
<
1
,
0
>
,
Sequence
<
CDEBlockTransferScalarPerVector_NPerBlock
>
,
Sequence
<
CDEBlockTransferScalarPerVector_NPerBlock
>
,
I1
,
I0
>
;
using
GridwiseElementwiseOutputTranspose
=
GridwiseElementwise
<
Tuple
<
NHWGCTransposeDescType
>
,
Tuple
<
NGCHWTransposeDescType
>
,
Tuple
<
const
EDataType
*>
,
Tuple
<
EDataType
*>
,
Block2TileMapElementwise
,
element_wise
::
PassThrough
,
ElementwiseBlocksize
,
NPerBlock
,
NPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
,
Sequence
<
1
,
0
>
,
Sequence
<
CDEBlockTransferScalarPerVector_NPerBlock
>
,
Sequence
<
CDEBlockTransferScalarPerVector_NPerBlock
>
,
I0
,
I1
>
;
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
...
@@ -471,17 +553,31 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -471,17 +553,31 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
p_bs_grid_
{},
p_bs_grid_
{},
p_ds_grid_
{},
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e
)},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e
)},
num_group_
{
a_g_n_c_wis_lengths
[
0
]},
a_g_n_c_wis_lengths_
{
a_g_n_c_wis_lengths
},
conv_to_gemm_transformer_
{
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides_
{
conv_ngchw_to_nhwgc_transformer
.
TransposeStrides
(
a_g_n_c_wis_strides
,
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
)},
b_g_k_c_xs_lengths
,
b_g_k_c_xs_lengths_
{
b_g_k_c_xs_lengths
},
b_g_k_c_xs_strides
,
b_g_k_c_xs_strides_
{
b_g_k_c_xs_strides
},
e_g_n_k_wos_lengths
,
ds_g_n_k_wos_lengths_
{
ds_g_n_k_wos_lengths
},
e_g_n_k_wos_strides
,
ds_g_n_k_wos_strides_
{
ds_g_n_k_wos_strides
},
conv_filter_strides
,
e_g_n_k_wos_lengths_
{
e_g_n_k_wos_lengths
},
conv_filter_dilations
,
e_g_n_k_wos_strides_
{
conv_ngchw_to_nhwgc_transformer
.
TransposeStrides
(
input_left_pads
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
)},
input_right_pads
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
},
num_group_
{
a_g_n_c_wis_lengths_
[
0
]},
conv_to_gemm_transformer_
{
a_g_n_c_wis_lengths_
,
a_g_n_c_wis_strides_
,
b_g_k_c_xs_lengths_
,
b_g_k_c_xs_strides_
,
e_g_n_k_wos_lengths_
,
e_g_n_k_wos_strides_
,
conv_filter_strides_
,
conv_filter_dilations_
,
input_left_pads_
,
input_right_pads_
},
conv_N_per_block_
{
conv_to_gemm_transformer_
.
N_
},
conv_N_per_block_
{
conv_to_gemm_transformer_
.
N_
},
a_grid_desc_m_k_
{
a_grid_desc_m_k_
{
DeviceOp
::
MakeAGridDescriptor_M_K
<
ALayout
>
(
conv_to_gemm_transformer_
)},
DeviceOp
::
MakeAGridDescriptor_M_K
<
ALayout
>
(
conv_to_gemm_transformer_
)},
...
@@ -501,19 +597,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -501,19 +597,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
compute_ptr_offset_of_n_
{},
compute_ptr_offset_of_n_
{},
a_element_op_
{
a_element_op
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
cde_element_op_
{
cde_element_op
}
a_g_n_c_wis_lengths_
{
a_g_n_c_wis_lengths
},
a_g_n_c_wis_strides_
{
a_g_n_c_wis_strides
},
b_g_k_c_xs_lengths_
{
b_g_k_c_xs_lengths
},
b_g_k_c_xs_strides_
{
b_g_k_c_xs_strides
},
ds_g_n_k_wos_lengths_
{
ds_g_n_k_wos_lengths
},
ds_g_n_k_wos_strides_
{
ds_g_n_k_wos_strides
},
e_g_n_k_wos_lengths_
{
e_g_n_k_wos_lengths
},
e_g_n_k_wos_strides_
{
e_g_n_k_wos_strides
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
{
{
// A/B/E Batch Stride
// A/B/E Batch Stride
if
constexpr
(
isMultiA
||
isMultiB
)
if
constexpr
(
isMultiA
||
isMultiB
)
...
@@ -521,7 +605,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -521,7 +605,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static_for
<
0
,
NumATensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumATensor
,
1
>
{}([
&
](
auto
i
)
{
// Init compute_ptr_offset_of_groups_ for multiple AB
// Init compute_ptr_offset_of_groups_ for multiple AB
compute_ptr_offset_of_groups_
.
BatchStrideA_
(
i
)
=
compute_ptr_offset_of_groups_
.
BatchStrideA_
(
i
)
=
a_g_n_c_wis_strides
[
0
]
*
NumGroupsToMerge
;
a_g_n_c_wis_strides
_
[
0
]
*
NumGroupsToMerge
;
// Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
// Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
// type is not tuple)
// type is not tuple)
...
@@ -537,20 +621,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -537,20 +621,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// in case of MultiA is false but isMultiB is true
// in case of MultiA is false but isMultiB is true
// BatchStrideA_ is not tuple.
// BatchStrideA_ is not tuple.
compute_ptr_offset_of_n_
.
BatchStrideA_
(
i
)
=
compute_ptr_offset_of_n_
.
BatchStrideA_
(
i
)
=
a_g_n_c_wis_strides
[
1
]
*
conv_N_per_block_
;
a_g_n_c_wis_strides
_
[
1
]
*
conv_N_per_block_
;
}
}
else
else
{
{
// if MultiB and not MultiA then p_as is single pointer
// if MultiB and not MultiA then p_as is single pointer
p_as_grid_
(
i
)
=
static_cast
<
const
DataType
*>
(
p_as
);
p_as_grid_
(
i
)
=
static_cast
<
const
DataType
*>
(
p_as
);
compute_ptr_offset_of_n_
.
BatchStrideA_
=
compute_ptr_offset_of_n_
.
BatchStrideA_
=
a_g_n_c_wis_strides
[
1
]
*
conv_N_per_block_
;
a_g_n_c_wis_strides
_
[
1
]
*
conv_N_per_block_
;
}
}
});
});
static_for
<
0
,
NumBTensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumBTensor
,
1
>
{}([
&
](
auto
i
)
{
// Init compute_ptr_offset_of_groups_ for multiple AB
// Init compute_ptr_offset_of_groups_ for multiple AB
compute_ptr_offset_of_groups_
.
BatchStrideB_
(
i
)
=
compute_ptr_offset_of_groups_
.
BatchStrideB_
(
i
)
=
b_g_k_c_xs_strides
[
0
]
*
NumGroupsToMerge
;
b_g_k_c_xs_strides
_
[
0
]
*
NumGroupsToMerge
;
using
DataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
GemmBDataType
>>
;
using
DataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
GemmBDataType
>>
;
// It is possible that one of the AB is a pointer and one is a tuple.
// It is possible that one of the AB is a pointer and one is a tuple.
...
@@ -571,10 +655,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -571,10 +655,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
else
else
{
{
compute_ptr_offset_of_groups_
.
BatchStrideA_
=
compute_ptr_offset_of_groups_
.
BatchStrideA_
=
a_g_n_c_wis_strides
[
0
]
*
NumGroupsToMerge
;
a_g_n_c_wis_strides
_
[
0
]
*
NumGroupsToMerge
;
compute_ptr_offset_of_groups_
.
BatchStrideB_
=
compute_ptr_offset_of_groups_
.
BatchStrideB_
=
b_g_k_c_xs_strides
[
0
]
*
NumGroupsToMerge
;
b_g_k_c_xs_strides_
[
0
]
*
NumGroupsToMerge
;
compute_ptr_offset_of_n_
.
BatchStrideA_
=
a_g_n_c_wis_strides
[
1
]
*
conv_N_per_block_
;
compute_ptr_offset_of_n_
.
BatchStrideA_
=
a_g_n_c_wis_strides_
[
1
]
*
conv_N_per_block_
;
// p_as and p_bs are pointers
// p_as and p_bs are pointers
p_as_grid_
(
I0
)
=
static_cast
<
const
ADataType
*>
(
p_as
);
p_as_grid_
(
I0
)
=
static_cast
<
const
ADataType
*>
(
p_as
);
...
@@ -591,27 +676,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -591,27 +676,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// D batch stride
// D batch stride
compute_ptr_offset_of_groups_
.
BatchStrideDs_
(
i
)
=
compute_ptr_offset_of_groups_
.
BatchStrideDs_
(
i
)
=
ds_g_n_k_wos_strides
[
i
][
0
]
*
NumGroupsToMerge
;
ds_g_n_k_wos_strides
_
[
i
][
0
]
*
NumGroupsToMerge
;
compute_ptr_offset_of_n_
.
BatchStrideDs_
(
i
)
=
compute_ptr_offset_of_n_
.
BatchStrideDs_
(
i
)
=
ds_g_n_k_wos_strides
[
i
][
1
]
*
conv_N_per_block_
;
ds_g_n_k_wos_strides
_
[
i
][
1
]
*
conv_N_per_block_
;
ConvToGemmFwdTransformer
conv_to_gemm_transformer_d
{
a_g_n_c_wis_lengths
,
ConvToGemmFwdTransformer
conv_to_gemm_transformer_d
{
a_g_n_c_wis_lengths
_
,
a_g_n_c_wis_strides
,
a_g_n_c_wis_strides
_
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_lengths
_
,
b_g_k_c_xs_strides
,
b_g_k_c_xs_strides
_
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_lengths
_
,
ds_g_n_k_wos_strides
[
i
],
ds_g_n_k_wos_strides
_
[
i
],
conv_filter_strides
,
conv_filter_strides
_
,
conv_filter_dilations
,
conv_filter_dilations
_
,
input_left_pads
,
input_left_pads
_
,
input_right_pads
};
input_right_pads
_
};
// D desc
// D desc
ds_grid_desc_m_n_
(
i
)
=
ds_grid_desc_m_n_
(
i
)
=
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
conv_to_gemm_transformer_d
);
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
conv_to_gemm_transformer_d
);
});
});
compute_ptr_offset_of_groups_
.
BatchStrideE_
=
e_g_n_k_wos_strides
[
0
]
*
NumGroupsToMerge
;
compute_ptr_offset_of_groups_
.
BatchStrideE_
=
compute_ptr_offset_of_n_
.
BatchStrideE_
=
e_g_n_k_wos_strides
[
1
]
*
conv_N_per_block_
;
e_g_n_k_wos_strides_
[
0
]
*
NumGroupsToMerge
;
compute_ptr_offset_of_n_
.
BatchStrideE_
=
e_g_n_k_wos_strides_
[
1
]
*
conv_N_per_block_
;
// populate desc for Ds/E
// populate desc for Ds/E
if
constexpr
(
isMultiA
||
isMultiB
)
if
constexpr
(
isMultiA
||
isMultiB
)
...
@@ -653,6 +739,54 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -653,6 +739,54 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
ds_grid_desc_m_n_
);
ds_grid_desc_m_n_
);
}
}
}
}
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
// Use not modified base strides
a_in_transpose_desc_
=
conv_ngchw_to_nhwgc_transformer
.
template
MakeNGCHWTransposeDesc
<
NDimSpatial
>(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
);
a_out_transpose_desc_
=
conv_ngchw_to_nhwgc_transformer
.
template
MakeNHWGCTransposeDesc
<
NDimSpatial
>(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
);
e_in_transpose_desc_
=
conv_ngchw_to_nhwgc_transformer
.
template
MakeNHWGCTransposeDesc
<
NDimSpatial
>(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
);
e_out_transpose_desc_
=
conv_ngchw_to_nhwgc_transformer
.
template
MakeNGCHWTransposeDesc
<
NDimSpatial
>(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
);
elementwise_block_2_ctile_map_transpose_a_
=
Block2TileMapElementwise
{
a_in_transpose_desc_
.
GetLength
(
I0
),
a_in_transpose_desc_
.
GetLength
(
I1
)};
elementwise_block_2_ctile_map_transpose_e_
=
Block2TileMapElementwise
{
e_in_transpose_desc_
.
GetLength
(
I0
),
e_in_transpose_desc_
.
GetLength
(
I1
)};
}
}
std
::
size_t
GetWorkspaceATensorSizeBytes
()
const
{
return
sizeof
(
ADataType
)
*
a_in_transpose_desc_
.
GetElementSpaceSize
();
}
std
::
size_t
GetWorkspaceETensorSizeBytes
()
const
{
return
sizeof
(
EDataType
)
*
e_out_transpose_desc_
.
GetElementSpaceSize
();
}
std
::
size_t
GetWorkspaceSizeBytes
()
const
{
// Transpose require workspace for A and B
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
return
GetWorkspaceATensorSizeBytes
()
+
GetWorkspaceETensorSizeBytes
();
}
else
{
return
0
;
}
}
}
void
Print
()
const
void
Print
()
const
...
@@ -671,6 +805,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -671,6 +805,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
typename
GridwiseGemm
::
DsGridPointer
p_ds_grid_
;
typename
GridwiseGemm
::
DsGridPointer
p_ds_grid_
;
EDataType
*
p_e_grid_
;
EDataType
*
p_e_grid_
;
// for checking IsSupportedArgument()
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_
;
// tensor descriptors for problem definiton
// tensor descriptors for problem definiton
index_t
num_group_
;
index_t
num_group_
;
...
@@ -692,6 +840,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -692,6 +840,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// block-to-e-tile map
// block-to-e-tile map
Block2ETileMap
block_2_etile_map_
;
Block2ETileMap
block_2_etile_map_
;
Block2TileMapElementwise
elementwise_block_2_ctile_map_transpose_a_
,
elementwise_block_2_ctile_map_transpose_e_
;
NGCHWTransposeDescType
a_in_transpose_desc_
,
e_out_transpose_desc_
;
NHWGCTransposeDescType
a_out_transpose_desc_
,
e_in_transpose_desc_
;
// for computing batch offset
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
NumATensor
,
NumBTensor
,
NumDTensor
>
ComputePtrOffsetOfStridedBatch
<
NumATensor
,
NumBTensor
,
NumDTensor
>
...
@@ -702,20 +855,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
     AElementwiseOperation a_element_op_;
     BElementwiseOperation b_element_op_;
     CDEElementwiseOperation cde_element_op_;
-
-    // for checking IsSupportedArgument()
-    std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
-    std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
-    std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
-    std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
-    std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
-    std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
-    std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
-    std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
-    std::array<index_t, NDimSpatial> conv_filter_strides_;
-    std::array<index_t, NDimSpatial> conv_filter_dilations_;
-    std::array<index_t, NDimSpatial> input_left_pads_;
-    std::array<index_t, NDimSpatial> input_right_pads_;
 };

 // Invoker
...
@@ -723,7 +862,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
 {
     using Argument = DeviceOp::Argument;

-    float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+    float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
     {
         if(stream_config.log_level_ > 0)
         {
...
@@ -794,6 +933,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
         }
         else
         {
+            const ADataType* p_a_grid = arg.p_as_grid_.At(I0);
+            EDataType* p_e_grid       = arg.p_e_grid_;
+            if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
+                         is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
+            {
+                p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
+                p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
+                           arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
+            }
+
             const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
                 GridwiseGemm,
                 const ADataType*,
...
@@ -820,10 +970,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                 dim3(gdx, gdy, gdz),
                 dim3(BlockSize),
                 0,
-                arg.p_as_grid_.At(I0), // Pass just A descriptor instead of tuple
+                p_a_grid,              // Pass just A descriptor instead of tuple
                 arg.p_bs_grid_.At(I0), // Pass just B descriptor instead of tuple
                 arg.p_ds_grid_,
-                arg.p_e_grid_,
+                p_e_grid,
                 arg.a_element_op_,
                 arg.b_element_op_,
                 arg.cde_element_op_,
...
@@ -847,6 +997,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
         }
     }

+    float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+    {
+        float avg_time = 0.f;
+        if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
+                     is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
+        {
+            const index_t grid_size =
+                arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
+                    arg.a_in_transpose_desc_);
+
+            ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
+
+            auto kernel_transpose = kernel_elementwise<GridwiseElementwiseInputTranspose,
+                                                       ck::Tuple<NGCHWTransposeDescType>,
+                                                       ck::Tuple<NHWGCTransposeDescType>,
+                                                       ck::Tuple<const ADataType*>,
+                                                       ck::Tuple<ADataType*>,
+                                                       Block2TileMapElementwise,
+                                                       element_wise::PassThrough>;
+
+            avg_time += launch_and_time_kernel(stream_config,
+                                               kernel_transpose,
+                                               dim3(grid_size),
+                                               dim3(ElementwiseBlocksize),
+                                               0,
+                                               make_tuple(arg.a_in_transpose_desc_),
+                                               make_tuple(arg.a_out_transpose_desc_),
+                                               make_tuple(arg.p_as_grid_.At(I0)),
+                                               make_tuple(p_a_out_grid),
+                                               arg.elementwise_block_2_ctile_map_transpose_a_,
+                                               element_wise::PassThrough{});
+        }
+
+        avg_time += RunGemm(arg, stream_config);
+
+        if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
+                     is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
+        {
+            const index_t grid_size =
+                arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
+                    arg.e_in_transpose_desc_);
+
+            const EDataType* p_e_out_grid = type_convert<EDataType*>(arg.p_workspace_) +
+                                            arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
+
+            EDataType* p_e_in_grid = arg.p_e_grid_;
+
+            auto kernel_transpose = kernel_elementwise<GridwiseElementwiseOutputTranspose,
+                                                       ck::Tuple<NHWGCTransposeDescType>,
+                                                       ck::Tuple<NGCHWTransposeDescType>,
+                                                       ck::Tuple<const EDataType*>,
+                                                       ck::Tuple<EDataType*>,
+                                                       Block2TileMapElementwise,
+                                                       element_wise::PassThrough>;
+
+            avg_time += launch_and_time_kernel(stream_config,
+                                               kernel_transpose,
+                                               dim3(grid_size),
+                                               dim3(ElementwiseBlocksize),
+                                               0,
+                                               make_tuple(arg.e_in_transpose_desc_),
+                                               make_tuple(arg.e_out_transpose_desc_),
+                                               make_tuple(p_e_out_grid),
+                                               make_tuple(p_e_in_grid),
+                                               arg.elementwise_block_2_ctile_map_transpose_e_,
+                                               element_wise::PassThrough{});
+        }
+
+        return avg_time;
+    }
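Conceptually, the added input transpose rewrites NGCHW data into NHWGC order (G and C become the fastest-varying dimensions), the GEMM runs on the workspace copies, and the output transpose converts the NHWGK result back to NGKHW. A plain CPU reference of the input permutation, illustrative only and not part of the commit, could look like:

    // CPU reference of the NGCHW -> NHWGC permutation done by the transpose kernel.
    #include <cstddef>
    #include <vector>

    template <typename T>
    void transpose_ngchw_to_nhwgc(const std::vector<T>& in, std::vector<T>& out,
                                  int N, int G, int C, int H, int W)
    {
        for(int n = 0; n < N; ++n)
            for(int g = 0; g < G; ++g)
                for(int c = 0; c < C; ++c)
                    for(int h = 0; h < H; ++h)
                        for(int w = 0; w < W; ++w)
                        {
                            // packed NGCHW source index and packed NHWGC destination index
                            const std::size_t src = ((((std::size_t(n) * G + g) * C + c) * H + h) * W + w);
                            const std::size_t dst = ((((std::size_t(n) * H + h) * W + w) * G + g) * C + c);
                            out[dst] = in[src];
                        }
    }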
     float Run(const BaseArgument* p_arg,
               const StreamConfig& stream_config = StreamConfig{}) override
     {
...
@@ -941,7 +1164,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
         {
             return false;
         }

-        if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
+        if constexpr(!(is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() ||
+                       is_NGCSpatial_GKSpatial_NGKSpatial<ALayout, BLayout, ELayout>()))
         {
             return false;
         }
@@ -953,14 +1177,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -953,14 +1177,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v
<
ALayout
,
ctc
::
G_NDHW_C
>
||
is_same_v
<
ALayout
,
ctc
::
GNWC
>
||
is_same_v
<
ALayout
,
ctc
::
G_NDHW_C
>
||
is_same_v
<
ALayout
,
ctc
::
GNWC
>
||
is_same_v
<
ALayout
,
ctc
::
GNHWC
>
||
is_same_v
<
ALayout
,
ctc
::
GNDHWC
>
||
is_same_v
<
ALayout
,
ctc
::
GNHWC
>
||
is_same_v
<
ALayout
,
ctc
::
GNDHWC
>
||
is_same_v
<
ALayout
,
ctc
::
NWGC
>
||
is_same_v
<
ALayout
,
ctc
::
NHWGC
>
||
is_same_v
<
ALayout
,
ctc
::
NWGC
>
||
is_same_v
<
ALayout
,
ctc
::
NHWGC
>
||
is_same_v
<
ALayout
,
ctc
::
NDHWGC
>
)
is_same_v
<
ALayout
,
ctc
::
NDHWGC
>
||
is_same_v
<
ALayout
,
ctc
::
NGCW
>
||
is_same_v
<
ALayout
,
ctc
::
NGCHW
>
||
is_same_v
<
ALayout
,
ctc
::
NGCDHW
>
)
{
{
// Check access per C
// Check access per C
if
(
!
(
ABlockTransferSrcVectorDim
==
2
&&
C
%
ABlockTransferSrcScalarPerVector
==
0
))
if
(
!
(
ABlockTransferSrcVectorDim
==
2
&&
C
%
ABlockTransferSrcScalarPerVector
==
0
))
{
{
// If not possible, check access per G
// If not possible, check access per G
if
(
!
(
ABlockTransferSrcVectorDim
==
1
&&
C
==
1
&&
if
(
!
(
ABlockTransferSrcVectorDim
==
1
&&
(
C
==
1
||
NumGroupsToMerge
==
1
)
&&
is_NSpatialGC_GKSpatial_NSpatialGK
<
ALayout
,
BLayout
,
ELayout
>
()
&&
(
is_NSpatialGC_GKSpatial_NSpatialGK
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCSpatial_GKSpatial_NGKSpatial
<
ALayout
,
BLayout
,
ELayout
>
())
&&
G
%
ABlockTransferSrcScalarPerVector
==
0
))
G
%
ABlockTransferSrcScalarPerVector
==
0
))
{
{
return
false
;
return
false
;
...
@@ -1036,6 +1262,35 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
             }
         });

+        if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
+                     is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
+        {
+            if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
+            {
+                return false;
+            }
+
+            if((G * K) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
+            {
+                return false;
+            }
+
+            const index_t input_spatial_acum = ck::accumulate_n<index_t>(
+                arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
+            const index_t output_spatial_acum = ck::accumulate_n<index_t>(
+                arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
+
+            if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
+            {
+                return false;
+            }
+
+            if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
+            {
+                return false;
+            }
+        }
         if(!valid)
         {
             return false;
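As a concrete check of the divisibility gates added above, take a hypothetical NGCHW problem (all numbers illustrative, not from the commit):

    // Illustrative numbers only: the NGCHW transpose-path support gates.
    constexpr int G = 4, C = 8, K = 16, ScalarPerVec = 8;
    constexpr int in_spatial  = 28 * 28; // 784
    constexpr int out_spatial = 27 * 27; // 729
    static_assert((G * C) % ScalarPerVec == 0, "G*C must be vector-divisible");       // 32 % 8 == 0
    static_assert((G * K) % ScalarPerVec == 0, "G*K must be vector-divisible");       // 64 % 8 == 0
    static_assert(in_spatial % ScalarPerVec == 0, "input spatial product divisible"); // 784 % 8 == 0
    // out_spatial % 8 == 1, so IsSupportedArgument() would return false for this shape.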
...
@@ -1046,7 +1301,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                      is_same_v<ELayout, ctc::G_NDHW_K> || is_same_v<ELayout, ctc::GNWK> ||
                      is_same_v<ELayout, ctc::GNHWK> || is_same_v<ELayout, ctc::GNDHWK> ||
                      is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
-                     is_same_v<ELayout, ctc::NDHWGK>)
+                     is_same_v<ELayout, ctc::NDHWGK> || is_same_v<ELayout, ctc::NGKW> ||
+                     is_same_v<ELayout, ctc::NGKHW> || is_same_v<ELayout, ctc::NGKDHW>)
         {
             if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
             {
...
@@ -1352,6 +1608,34 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
         return str.str();
     }

+    size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
+    {
+        auto arg = dynamic_cast<const Argument*>(p_arg);
+        if(arg)
+        {
+            return arg->GetWorkspaceSizeBytes();
+        }
+        else
+            throw std::runtime_error(
+                "The argument pointer is not an object of "
+                "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
+    }
+
+    void SetWorkSpacePointer(BaseArgument* p_arg,
+                             void* p_workspace,
+                             const StreamConfig& = StreamConfig{}) const override
+    {
+        auto p_arg_ = dynamic_cast<Argument*>(p_arg);
+        if(p_arg_)
+        {
+            p_arg_->p_workspace_ = p_workspace;
+        }
+        else
+            throw std::runtime_error(
+                "The argument pointer is not an object of "
+                "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
+    }
 };

 } // namespace device
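A hedged usage sketch of the two overrides added above; MakeArgument/MakeInvoker follow the usual CK device-op interface, and hip_check_error plus the StreamConfig arguments are assumptions for illustration:

    // Sketch: query and bind the transpose workspace before running the op.
    auto argument = conv_op.MakeArgument(/* pointers, lengths, strides, element ops */);
    void* p_ws    = nullptr;
    hip_check_error(hipMalloc(&p_ws, conv_op.GetWorkSpaceSize(&argument)));
    conv_op.SetWorkSpacePointer(&argument, p_ws);
    auto invoker        = conv_op.MakeInvoker();
    const float time_ms = invoker.Run(argument, StreamConfig{nullptr, true});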
...

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
View file @ a4fe62ed
...
@@ -15,10 +15,12 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
 #include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
+#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
 #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
...
@@ -292,6 +294,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};
+    static constexpr auto I4 = Number<4>{};
+    static constexpr auto I5 = Number<5>{};

     using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
                                                             ConvForwardSpecialization,
...
@@ -302,13 +306,32 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
     static constexpr auto matrix_padder =
         MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};

+    static constexpr index_t ClusterLengthNPerBlock =
+        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
+
+    static constexpr auto conv_ngchw_to_nhwgc_transformer =
+        TransformConvNGCHWToNHWGC<ALayout,
+                                  BLayout,
+                                  ELayout,
+                                  NDimSpatial,
+                                  MPerBlock / ClusterLengthNPerBlock,
+                                  NPerBlock / ClusterLengthNPerBlock>{};
+
     template <typename ALay>
     static auto
     MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
     {
+        namespace ctc = tensor_layout::convolution;
+        using Layout =
+            std::conditional_t<is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
+                               ctc::NHWGC,
+                               std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
+                                                  ctc::NDHWGC,
+                                                  ALay>>;
+
         const auto in_gemmmraw_gemmkraw_desc =
-            conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
+            conv_to_gemm_transformer.template MakeADescriptor_M_K<Layout>();

         const auto in_gemmm_gemmk_desc =
             matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
...
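The compile-time remapping idiom used here can be read in isolation: when the instance is NGCHW (or NGCDHW), descriptor construction proceeds as if the tensor were NHWGC (or NDHWGC), because the data will have been physically transposed into that order first. A stripped-down, self-contained sketch (tag types are placeholders, not the library's):

    #include <type_traits>

    struct NGCHW {}; struct NHWGC {}; struct NDHWGC {};

    template <typename ALayout, bool Is2DTransposed, bool Is3DTransposed>
    using EffectiveLayout =
        std::conditional_t<Is2DTransposed, NHWGC,
                           std::conditional_t<Is3DTransposed, NDHWGC, ALayout>>;

    // The GEMM-side descriptor is built for the transposed (workspace) layout,
    // not the user-visible one:
    static_assert(std::is_same_v<EffectiveLayout<NGCHW, true, false>, NHWGC>);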
@@ -351,8 +374,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
     static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
     {
+        namespace ctc = tensor_layout::convolution;
+        using Layout =
+            std::conditional_t<is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
+                               ctc::NHWGK,
+                               std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
+                                                  ctc::NDHWGK,
+                                                  ELay>>;
+
         const auto out_gemmmraw_gemmnraw_desc =
-            conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
+            conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();

         const auto out_gemmm_gemmn_desc =
             matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
...
@@ -385,6 +416,53 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
     // Use appropriate gridwise gemm
     using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<GridwiseGemmV3TemplateParams>;

+    using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, NPerBlock>;
+
+    using NGCHWTransposeDescType = remove_cvref_t<decltype(
+        conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
+    using NHWGCTransposeDescType = remove_cvref_t<decltype(
+        conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
+
+    static constexpr index_t ElementwiseBlocksize =
+        ClusterLengthNPerBlock * ClusterLengthNPerBlock;
+
+    using GridwiseElementwiseInputTranspose =
+        GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
+                            Tuple<NHWGCTransposeDescType>,
+                            Tuple<const ADataType*>,
+                            Tuple<ADataType*>,
+                            Block2TileMapElementwise,
+                            element_wise::PassThrough,
+                            ElementwiseBlocksize,
+                            NPerBlock,
+                            NPerBlock,
+                            NPerBlock / ClusterLengthNPerBlock,
+                            NPerBlock / ClusterLengthNPerBlock,
+                            Sequence<1, 0>,
+                            Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
+                            Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
+                            I1,
+                            I0>;
+
+    using GridwiseElementwiseOutputTranspose =
+        GridwiseElementwise<Tuple<NHWGCTransposeDescType>,
+                            Tuple<NGCHWTransposeDescType>,
+                            Tuple<const EDataType*>,
+                            Tuple<EDataType*>,
+                            Block2TileMapElementwise,
+                            element_wise::PassThrough,
+                            ElementwiseBlocksize,
+                            NPerBlock,
+                            NPerBlock,
+                            NPerBlock / ClusterLengthNPerBlock,
+                            NPerBlock / ClusterLengthNPerBlock,
+                            Sequence<1, 0>,
+                            Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
+                            Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
+                            I0,
+                            I1>;
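For a feel of the launch geometry these parameters produce, assume illustrative values NPerBlock = 128 and ClusterLengthNPerBlock = 32 (not taken from the diff): each transpose workgroup then covers a 128 x 128 tile with 1024 threads, each handling a 4 x 4 sub-tile.

    // Illustrative only: derived transpose-kernel geometry for one tile config.
    constexpr int NPerBlock              = 128; // assumed tile size
    constexpr int ClusterLengthNPerBlock = 32;  // assumed cluster length
    constexpr int ElementwiseBlocksize   = ClusterLengthNPerBlock * ClusterLengthNPerBlock; // 1024
    constexpr int MPerThread             = NPerBlock / ClusterLengthNPerBlock;              // 4
    constexpr int NPerThread             = NPerBlock / ClusterLengthNPerBlock;              // 4
    static_assert(ElementwiseBlocksize == 1024 && MPerThread == 4 && NPerThread == 4);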
     static auto
     MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
     {
...
@@ -428,17 +506,29 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
         : p_a_grid_{},
           p_b_grid_{},
           p_e_grid_{static_cast<EDataType*>(p_e)},
-          num_group_{a_g_n_c_wis_lengths[0]},
-          conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
-                                    a_g_n_c_wis_strides,
-                                    b_g_k_c_xs_lengths,
-                                    b_g_k_c_xs_strides,
-                                    e_g_n_k_wos_lengths,
-                                    e_g_n_k_wos_strides,
-                                    conv_filter_strides,
-                                    conv_filter_dilations,
-                                    input_left_pads,
-                                    input_right_pads},
+          a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
+          a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
+              a_g_n_c_wis_lengths, a_g_n_c_wis_strides)},
+          b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
+          b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
+          e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
+          e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
+              e_g_n_k_wos_lengths, e_g_n_k_wos_strides)},
+          conv_filter_strides_{conv_filter_strides},
+          conv_filter_dilations_{conv_filter_dilations},
+          input_left_pads_{input_left_pads},
+          input_right_pads_{input_right_pads},
+          num_group_{a_g_n_c_wis_lengths_[0]},
+          conv_to_gemm_transformer_{a_g_n_c_wis_lengths_,
+                                    a_g_n_c_wis_strides_,
+                                    b_g_k_c_xs_lengths_,
+                                    b_g_k_c_xs_strides_,
+                                    e_g_n_k_wos_lengths_,
+                                    e_g_n_k_wos_strides_,
+                                    conv_filter_strides_,
+                                    conv_filter_dilations_,
+                                    input_left_pads_,
+                                    input_right_pads_},
           conv_N_per_block_{conv_to_gemm_transformer_.N_},
           a_grid_desc_ak0_m_ak1_{
               MakeAGridDescriptor_AK0_M_AK1<ALayout>(conv_to_gemm_transformer_)},
...
@@ -451,32 +541,70 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
           compute_ptr_offset_of_n_{},
           a_element_op_{a_element_op},
           b_element_op_{b_element_op},
-          cde_element_op_{cde_element_op},
-          a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
-          a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
-          b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
-          b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
-          e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
-          e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
-          conv_filter_strides_{conv_filter_strides},
-          conv_filter_dilations_{conv_filter_dilations},
-          input_left_pads_{input_left_pads},
-          input_right_pads_{input_right_pads}
+          cde_element_op_{cde_element_op}
     {
         // A/B/E Batch/N Stride
-        compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
-        compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
-        compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
+        compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides_[0];
+        compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides_[0];
+        compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides_[1] * conv_N_per_block_;

         // p_as and p_bs are pointers
         p_a_grid_ = static_cast<const ADataType*>(p_as);
         p_b_grid_ = static_cast<const BDataType*>(p_bs);

-        compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
-        compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
+        compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides_[0];
+        compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_;

         e_grid_desc_mblock_mperblock_nblock_nperblock_ =
             MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
+        if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
+                     is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
+        {
+            // Use not modified base strides
+            a_in_transpose_desc_ =
+                conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
+                    a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
+            a_out_transpose_desc_ =
+                conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
+                    a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
+
+            e_in_transpose_desc_ =
+                conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
+                    e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
+            e_out_transpose_desc_ =
+                conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
+                    e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
+
+            elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
+                a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
+            elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{
+                e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
+        }
     }
+    std::size_t GetWorkspaceATensorSizeBytes() const
+    {
+        return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
+    }
+
+    std::size_t GetWorkspaceETensorSizeBytes() const
+    {
+        return sizeof(EDataType) * e_out_transpose_desc_.GetElementSpaceSize();
+    }
+
+    std::size_t GetWorkspaceSizeBytes() const
+    {
+        // The transpose path requires workspace for the A and E tensors
+        if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
+                     is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
+        {
+            return GetWorkspaceATensorSizeBytes() + GetWorkspaceETensorSizeBytes();
+        }
+        else
+        {
+            return 0;
+        }
+    }
     void Print() const
...
@@ -492,6 +620,18 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
     const BDataType* p_b_grid_;
     EDataType* p_e_grid_;

+    // for checking IsSupportedArgument()
+    std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
+    std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
+    std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
+    std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
+    std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
+    std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
+    std::array<index_t, NDimSpatial> conv_filter_strides_;
+    std::array<index_t, NDimSpatial> conv_filter_dilations_;
+    std::array<index_t, NDimSpatial> input_left_pads_;
+    std::array<index_t, NDimSpatial> input_right_pads_;

     // tensor descriptors for problem definition
     index_t num_group_;
...
@@ -514,17 +654,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
     BElementwiseOperation b_element_op_;
     CDEElementwiseOperation cde_element_op_;

-    // for checking IsSupportedArgument()
-    std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
-    std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
-    std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
-    std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
-    std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
-    std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
-    std::array<index_t, NDimSpatial> conv_filter_strides_;
-    std::array<index_t, NDimSpatial> conv_filter_dilations_;
-    std::array<index_t, NDimSpatial> input_left_pads_;
-    std::array<index_t, NDimSpatial> input_right_pads_;
+    // block-to-e-tile map
+    Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
+        elementwise_block_2_ctile_map_transpose_e_;
+
+    NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_;
+    NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_;
 };

 // Invoker
...
@@ -532,7 +667,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
 {
     using Argument = DeviceOp::Argument;

-    float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+    float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
     {
         if(stream_config.log_level_ > 0)
         {
@@ -561,8 +696,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
...
@@ -561,8 +696,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
index_t
K_split
=
(
GemmK
+
KPerBlock
-
1
)
/
KPerBlock
*
KPerBlock
;
index_t
K_split
=
(
GemmK
+
KPerBlock
-
1
)
/
KPerBlock
*
KPerBlock
;
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K_split
);
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K_split
);
const
ADataType
*
p_a_grid
=
arg
.
p_a_grid_
;
EDataType
*
p_e_grid
=
arg
.
p_e_grid_
;
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
p_a_grid
=
type_convert
<
const
ADataType
*>
(
arg
.
p_workspace_
);
p_e_grid
=
type_convert
<
EDataType
*>
(
arg
.
p_workspace_
)
+
arg
.
GetWorkspaceATensorSizeBytes
()
/
sizeof
(
EDataType
);
}
typename
GridwiseGemm
::
Argument
gemm_arg
{
typename
GridwiseGemm
::
Argument
gemm_arg
{
arg
.
p_a_grid
_
,
arg
.
p_b_grid_
,
arg
.
p_e_grid
_
,
GemmM
,
GemmN
,
GemmK
,
I0
,
I0
,
I0
,
I1
};
p_a_grid
,
arg
.
p_b_grid_
,
p_e_grid
,
GemmM
,
GemmN
,
GemmK
,
I0
,
I0
,
I0
,
I1
};
const
auto
Run
=
[
&
](
const
auto
&
kernel
)
{
const
auto
Run
=
[
&
](
const
auto
&
kernel
)
{
if
(
stream_config
.
flush_cache
)
if
(
stream_config
.
flush_cache
)
...
@@ -857,6 +1003,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
             return ave_time;
         }

+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        {
+            float avg_time = 0.f;
+            if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
+                         is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
+            {
+                const index_t grid_size =
+                    arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
+                        arg.a_in_transpose_desc_);
+
+                ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
+
+                auto kernel_transpose = kernel_elementwise<GridwiseElementwiseInputTranspose,
+                                                           ck::Tuple<NGCHWTransposeDescType>,
+                                                           ck::Tuple<NHWGCTransposeDescType>,
+                                                           ck::Tuple<const ADataType*>,
+                                                           ck::Tuple<ADataType*>,
+                                                           Block2TileMapElementwise,
+                                                           element_wise::PassThrough>;
+
+                avg_time += launch_and_time_kernel(stream_config,
+                                                   kernel_transpose,
+                                                   dim3(grid_size),
+                                                   dim3(ElementwiseBlocksize),
+                                                   0,
+                                                   make_tuple(arg.a_in_transpose_desc_),
+                                                   make_tuple(arg.a_out_transpose_desc_),
+                                                   make_tuple(arg.p_a_grid_),
+                                                   make_tuple(p_a_out_grid),
+                                                   arg.elementwise_block_2_ctile_map_transpose_a_,
+                                                   element_wise::PassThrough{});
+            }
+
+            avg_time += RunGemm(arg, stream_config);
+
+            if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
+                         is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
+            {
+                const index_t grid_size =
+                    arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
+                        arg.e_in_transpose_desc_);
+
+                const EDataType* p_e_out_grid =
+                    type_convert<EDataType*>(arg.p_workspace_) +
+                    arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
+
+                EDataType* p_e_in_grid = arg.p_e_grid_;
+
+                auto kernel_transpose = kernel_elementwise<GridwiseElementwiseOutputTranspose,
+                                                           ck::Tuple<NHWGCTransposeDescType>,
+                                                           ck::Tuple<NGCHWTransposeDescType>,
+                                                           ck::Tuple<const EDataType*>,
+                                                           ck::Tuple<EDataType*>,
+                                                           Block2TileMapElementwise,
+                                                           element_wise::PassThrough>;
+
+                avg_time += launch_and_time_kernel(stream_config,
+                                                   kernel_transpose,
+                                                   dim3(grid_size),
+                                                   dim3(ElementwiseBlocksize),
+                                                   0,
+                                                   make_tuple(arg.e_in_transpose_desc_),
+                                                   make_tuple(arg.e_out_transpose_desc_),
+                                                   make_tuple(p_e_out_grid),
+                                                   make_tuple(p_e_in_grid),
+                                                   arg.elementwise_block_2_ctile_map_transpose_e_,
+                                                   element_wise::PassThrough{});
+            }
+
+            return avg_time;
+        }
+
         float Run(const BaseArgument* p_arg,
                   const StreamConfig& stream_config = StreamConfig{}) override
         {
...
@@ -868,6 +1087,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
     {
         namespace ctc = tensor_layout::convolution;

+        const index_t G = arg.b_g_k_c_xs_lengths_[I0];
+        const index_t K = arg.b_g_k_c_xs_lengths_[I1];
+        const index_t C = arg.b_g_k_c_xs_lengths_[I2];
+
         // check device
         if(get_device_name() == "gfx908")
         {
...
@@ -924,10 +1147,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
                      is_same_v<ALayout, ctc::G_NDHW_C> || is_same_v<ALayout, ctc::GNWC> ||
                      is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
                      is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
-                     is_same_v<ALayout, ctc::NDHWGC>)
+                     is_same_v<ALayout, ctc::NDHWGC> || is_same_v<ALayout, ctc::NGCW> ||
+                     is_same_v<ALayout, ctc::NGCHW> || is_same_v<ALayout, ctc::NGCDHW>)
         {
-            const index_t C = arg.a_g_n_c_wis_lengths_[2];
             if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
             {
                 return false;
...
@@ -947,8 +1169,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
                      is_same_v<BLayout, ctc::KZYXGC>)
         {
-            const index_t C = arg.b_g_k_c_xs_lengths_[2];
             if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
             {
                 return false;
...
@@ -959,15 +1179,43 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
             return false;
         }

+        if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
+                     is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
+        {
+            if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
+            {
+                return false;
+            }
+
+            if((G * K) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
+            {
+                return false;
+            }
+
+            const index_t input_spatial_acum = ck::accumulate_n<index_t>(
+                arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
+            const index_t output_spatial_acum = ck::accumulate_n<index_t>(
+                arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
+
+            if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
+            {
+                return false;
+            }
+
+            if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
+            {
+                return false;
+            }
+        }
+
         // check vector access of E
         if constexpr(is_same_v<ELayout, ctc::G_NW_K> || is_same_v<ELayout, ctc::G_NHW_K> ||
                      is_same_v<ELayout, ctc::G_NDHW_K> || is_same_v<ELayout, ctc::GNWK> ||
                      is_same_v<ELayout, ctc::GNHWK> || is_same_v<ELayout, ctc::GNDHWK> ||
                      is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
-                     is_same_v<ELayout, ctc::NDHWGK>)
+                     is_same_v<ELayout, ctc::NDHWGK> || is_same_v<ELayout, ctc::NGKW> ||
+                     is_same_v<ELayout, ctc::NGKHW> || is_same_v<ELayout, ctc::NGKDHW>)
         {
-            const index_t K = arg.e_g_n_k_wos_lengths_[2];
             if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
             {
                 return false;
...
@@ -1279,6 +1527,34 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
         return str.str();
     }

+    size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
+    {
+        auto arg = dynamic_cast<const Argument*>(p_arg);
+        if(arg)
+        {
+            return arg->GetWorkspaceSizeBytes();
+        }
+        else
+            throw std::runtime_error(
+                "The argument pointer is not an object of "
+                "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
+    }
+
+    void SetWorkSpacePointer(BaseArgument* p_arg,
+                             void* p_workspace,
+                             const StreamConfig& = StreamConfig{}) const override
+    {
+        auto p_arg_ = dynamic_cast<Argument*>(p_arg);
+        if(p_arg_)
+        {
+            p_arg_->p_workspace_ = p_workspace;
+        }
+        else
+            throw std::runtime_error(
+                "The argument pointer is not an object of "
+                "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
+    }
 };

 } // namespace device
...

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
View file @ a4fe62ed
...
@@ -26,6 +26,15 @@ constexpr bool is_GNWC_GKXC_GNWK()
            is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
            is_same_v<OutLayout, tensor_layout::convolution::GNWK>;
 }

+template <typename InLayout, typename WeiLayout, typename OutLayout>
+constexpr bool is_NGCW_GKXC_NGKW()
+{
+    return is_same_v<InLayout, tensor_layout::convolution::NGCW> &&
+           is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
+           is_same_v<OutLayout, tensor_layout::convolution::NGKW>;
+}
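These predicates are pure compile-time layout matchers; a hedged usage sketch (the exact namespace qualification of the helpers is an assumption here):

    // Sketch: the new 1-D predicate in use.
    namespace ctc = ck::tensor_layout::convolution;
    static_assert(ck::tensor_operation::device::
                      is_NGCW_GKXC_NGKW<ctc::NGCW, ctc::GKXC, ctc::NGKW>());
    static_assert(!ck::tensor_operation::device::
                      is_NGCW_GKXC_NGKW<ctc::GNWC, ctc::GKXC, ctc::GNWK>());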

 // 2d
 template <typename InLayout, typename WeiLayout, typename OutLayout>
 constexpr bool is_NHWGC_GKYXC_NHWGK()
...
@@ -91,6 +100,14 @@ constexpr bool is_GNSpatialC_GKSpatial_GNSpatialK()
            is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>();
 }

+template <typename InLayout, typename WeiLayout, typename OutLayout>
+constexpr bool is_NGCSpatial_GKSpatial_NGKSpatial()
+{
+    return is_NGCW_GKXC_NGKW<InLayout, WeiLayout, OutLayout>() ||
+           is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
+           is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>();
+}
+
 template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void>
 struct ComputePtrOffsetOfStridedBatch
 {
...
include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp
0 → 100644
View file @ a4fe62ed

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"

namespace ck {
namespace tensor_operation {

template <typename ALayout,
          typename BLayout,
          typename ELayout,
          index_t NDimSpatial,
          index_t MPerThread,
          index_t NPerThread>
struct TransformConvNGCHWToNHWGC
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};

    template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
    static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
                                       std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
    {
        const index_t& G  = g_n_c_wis_lengths[I0];
        const index_t& N  = g_n_c_wis_lengths[I1];
        const index_t& C  = g_n_c_wis_lengths[I2];
        const index_t& Wi = g_n_c_wis_lengths[I3];

        const index_t& GStride  = g_n_c_wis_strides[I0];
        const index_t& NStride  = g_n_c_wis_strides[I1];
        const index_t& CStride  = g_n_c_wis_strides[I2];
        const index_t& WiStride = g_n_c_wis_strides[I3];

        const auto desc = make_naive_tensor_descriptor(
            make_tuple(N, G, C, Wi), make_tuple(NStride, GStride, CStride, WiStride));
        const auto merged_desc =
            transform_tensor_descriptor(desc,
                                        make_tuple(make_merge_transform(make_tuple(N, G, C)),
                                                   make_merge_transform(make_tuple(Wi))),
                                        make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
        return device::PadTensorDescriptor(
            merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
    }

    template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
    static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
                                       std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
    {
        const index_t& G  = g_n_c_wis_lengths[I0];
        const index_t& N  = g_n_c_wis_lengths[I1];
        const index_t& C  = g_n_c_wis_lengths[I2];
        const index_t& Wi = g_n_c_wis_lengths[I3];

        const index_t& NStride = g_n_c_wis_strides[I1];
        const index_t WiStride = G * C;
        const index_t GStride  = C;
        const index_t CStride  = 1;

        const auto desc = make_naive_tensor_descriptor(
            make_tuple(N, G, C, Wi), make_tuple(NStride, GStride, CStride, WiStride));
        const auto merged_desc =
            transform_tensor_descriptor(desc,
                                        make_tuple(make_merge_transform(make_tuple(N, G, C)),
                                                   make_merge_transform(make_tuple(Wi))),
                                        make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
        return device::PadTensorDescriptor(
            merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
    }

    template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
    static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
                                       std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
    {
        const index_t& G  = g_n_c_wis_lengths[I0];
        const index_t& N  = g_n_c_wis_lengths[I1];
        const index_t& C  = g_n_c_wis_lengths[I2];
        const index_t& Hi = g_n_c_wis_lengths[I3];
        const index_t& Wi = g_n_c_wis_lengths[I4];

        const index_t& GStride  = g_n_c_wis_strides[I0];
        const index_t& NStride  = g_n_c_wis_strides[I1];
        const index_t& CStride  = g_n_c_wis_strides[I2];
        const index_t& HiStride = g_n_c_wis_strides[I3];
        const index_t& WiStride = g_n_c_wis_strides[I4];

        const auto desc =
            make_naive_tensor_descriptor(make_tuple(N, G, C, Hi, Wi),
                                         make_tuple(NStride, GStride, CStride, HiStride, WiStride));
        const auto merged_desc =
            transform_tensor_descriptor(desc,
                                        make_tuple(make_merge_transform(make_tuple(N, G, C)),
                                                   make_merge_transform(make_tuple(Hi, Wi))),
                                        make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
        return device::PadTensorDescriptor(
            merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
    }

    template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
    static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
                                       std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
    {
        const index_t& G  = g_n_c_wis_lengths[I0];
        const index_t& N  = g_n_c_wis_lengths[I1];
        const index_t& C  = g_n_c_wis_lengths[I2];
        const index_t& Hi = g_n_c_wis_lengths[I3];
        const index_t& Wi = g_n_c_wis_lengths[I4];

        const index_t& NStride = g_n_c_wis_strides[I1];
        const index_t HiStride = Wi * G * C;
        const index_t WiStride = G * C;
        const index_t GStride  = C;
        const index_t CStride  = 1;

        const auto desc =
            make_naive_tensor_descriptor(make_tuple(N, G, C, Hi, Wi),
                                         make_tuple(NStride, GStride, CStride, HiStride, WiStride));
        const auto merged_desc =
            transform_tensor_descriptor(desc,
                                        make_tuple(make_merge_transform(make_tuple(N, G, C)),
                                                   make_merge_transform(make_tuple(Hi, Wi))),
                                        make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
        return device::PadTensorDescriptor(
            merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
    }

    template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
    static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
                                       std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
    {
        const index_t& G  = g_n_c_wis_lengths[I0];
        const index_t& N  = g_n_c_wis_lengths[I1];
        const index_t& C  = g_n_c_wis_lengths[I2];
        const index_t& Di = g_n_c_wis_lengths[I3];
        const index_t& Hi = g_n_c_wis_lengths[I4];
        const index_t& Wi = g_n_c_wis_lengths[I5];

        const index_t& GStride  = g_n_c_wis_strides[I0];
        const index_t& NStride  = g_n_c_wis_strides[I1];
        const index_t& CStride  = g_n_c_wis_strides[I2];
        const index_t& DiStride = g_n_c_wis_strides[I3];
        const index_t& HiStride = g_n_c_wis_strides[I4];
        const index_t& WiStride = g_n_c_wis_strides[I5];

        const auto desc = make_naive_tensor_descriptor(
            make_tuple(N, G, C, Di, Hi, Wi),
            make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
        const auto merged_desc =
            transform_tensor_descriptor(desc,
                                        make_tuple(make_merge_transform(make_tuple(N, G, C)),
                                                   make_merge_transform(make_tuple(Di, Hi, Wi))),
                                        make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
        return device::PadTensorDescriptor(
            merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
    }

    template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
    static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
                                       std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
    {
        const index_t& G  = g_n_c_wis_lengths[I0];
        const index_t& N  = g_n_c_wis_lengths[I1];
        const index_t& C  = g_n_c_wis_lengths[I2];
        const index_t& Di = g_n_c_wis_lengths[I3];
        const index_t& Hi = g_n_c_wis_lengths[I4];
        const index_t& Wi = g_n_c_wis_lengths[I5];

        const index_t& NStride = g_n_c_wis_strides[I1];
        const index_t DiStride = Hi * Wi * G * C;
        const index_t HiStride = Wi * G * C;
        const index_t WiStride = G * C;
        const index_t GStride  = C;
        const index_t CStride  = 1;

        const auto desc = make_naive_tensor_descriptor(
            make_tuple(N, G, C, Di, Hi, Wi),
            make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
        const auto merged_desc =
            transform_tensor_descriptor(desc,
                                        make_tuple(make_merge_transform(make_tuple(N, G, C)),
                                                   make_merge_transform(make_tuple(Di, Hi, Wi))),
                                        make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
        return device::PadTensorDescriptor(
            merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
    }

    static auto TransposeStrides(const std::array<index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
                                 const std::array<index_t, NDimSpatial + 3>& g_n_c_wis_strides)
    {
        if constexpr(device::is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
                     device::is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
        {
            std::array<index_t, NDimSpatial + 3> g_n_c_wis_strides_transposed;
            const auto G = g_n_c_wis_lengths[I0];
            const auto C = g_n_c_wis_lengths[I2];

            g_n_c_wis_strides_transposed[I0] = C;
            g_n_c_wis_strides_transposed[I1] = g_n_c_wis_strides[I1];
            g_n_c_wis_strides_transposed[I2] = I1;
            if constexpr(NDimSpatial == 2)
            {
                g_n_c_wis_strides_transposed[I3] = g_n_c_wis_lengths[I4] * G * C;
                g_n_c_wis_strides_transposed[I4] = G * C;
            }
            else if constexpr(NDimSpatial == 3)
            {
                g_n_c_wis_strides_transposed[I3] =
                    g_n_c_wis_lengths[I4] * g_n_c_wis_lengths[I5] * G * C;
                g_n_c_wis_strides_transposed[I4] = g_n_c_wis_lengths[I5] * G * C;
                g_n_c_wis_strides_transposed[I5] = G * C;
            }
            return g_n_c_wis_strides_transposed;
        }
        else
        {
            // transpose not needed
            return g_n_c_wis_strides;
        }
    }
};

} // namespace tensor_operation
} // namespace ck
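To make TransposeStrides concrete, take a hypothetical 2-D problem with {G, N, C, Hi, Wi} = {2, 1, 8, 4, 4}: packed NGCHW strides are {128, 256, 16, 4, 1}, and the routine returns the NHWGC-packed equivalents {8, 256, 1, 64, 16} (GStride = C, NStride unchanged, CStride = 1, HiStride = Wi*G*C, WiStride = G*C). A small verification sketch, with index_t as a stand-in type:

    #include <array>
    #include <cassert>
    using index_t = int; // stand-in for ck::index_t

    int main()
    {
        // Hypothetical shape: {G, N, C, Hi, Wi} = {2, 1, 8, 4, 4}
        const index_t G = 2, C = 8, Hi = 4, Wi = 4;
        const std::array<index_t, 5> ngchw = {C * Hi * Wi, G * C * Hi * Wi, Hi * Wi, Wi, 1};
        // What TransposeStrides computes as the NHWGC-packed equivalent:
        const std::array<index_t, 5> nhwgc = {C, ngchw[1], 1, Wi * G * C, G * C};
        assert(nhwgc[0] == 8 && nhwgc[1] == 256 && nhwgc[2] == 1);
        assert(nhwgc[3] == 64 && nhwgc[4] == 16);
    }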
include/ck/utility/reduction_operator.hpp
View file @ a4fe62ed
...
@@ -516,7 +516,7 @@ struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Add,
     static constexpr bool value =
         is_same<DataType, float>::value || is_same<DataType, double>::value ||
         is_same<DataType, half_t>::value || is_same<DataType, int8_t>::value ||
-        is_same<DataType, int32_t>::value;
+        is_same<DataType, int32_t>::value || is_same<DataType, f8_t>::value;
 };

 } // namespace reduce
...
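With this change, an f8 destination now satisfies the trait gate for atomic-add style in-memory operations. A one-line sketch of the check (the exact namespace spelling is assumed):

    static_assert(ck::reduce::InMemoryDataOperationSupportedOnDataType<
                      ck::InMemoryDataOperationEnum::Add, ck::f8_t>::value,
                  "f8_t is now accepted for InMemoryDataOperationEnum::Add");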
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
View file @ a4fe62ed
...
@@ -215,8 +215,8 @@ struct BlockFmhaPipelineQRKSVS
         const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);

-        // check early exit if masked and no work to do.
-        if constexpr(FmhaMask::IsMasking)
+        // check early exit if no work to do
+        if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
         {
             if(num_total_loop <= 0)
             {
...
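The guard now also covers padded sequence lengths, not just masking. The arithmetic behind the early exit, with an illustrative tile size:

    // Illustrative: the early-exit condition for an empty key window.
    constexpr int kN0 = 128;                                 // assumed tile size along seqlen_k
    constexpr int seqlen_k_start = 256, seqlen_k_end = 256;  // empty window after mask/pad
    constexpr int num_total_loop = (seqlen_k_end - seqlen_k_start + kN0 - 1) / kN0; // 0
    static_assert(num_total_loop <= 0, "pipeline must return before touching K/V tiles");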
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @ a4fe62ed
...
@@ -268,7 +268,7 @@ struct BlockFmhaPipelineQRKSVSAsync
         const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);

-        // check early exit
+        // check early exit if no work to do
         if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
         {
             if(num_total_loop <= 0)
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @ a4fe62ed
...
@@ -123,14 +123,26 @@ struct GemmKernel
         }();

-        auto ABlockWindow = make_tile_window(
-            a_tensor_view,
-            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
-            {i_m, 0});
+        auto a_pad_view = pad_tensor_view(
+            a_tensor_view,
+            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
+            sequence<0, GemmPipeline::kPadA ? 1 : 0>{});
+        auto ABlockWindow = make_tile_window(
+            a_pad_view,
+            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
+            {i_m, 0});

-        auto BBlockWindow = make_tile_window(
-            b_tensor_view,
-            make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
-            {i_n, 0});
+        auto b_pad_view = pad_tensor_view(
+            b_tensor_view,
+            make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
+            sequence<0, GemmPipeline::kPadB ? 1 : 0>{});
+        auto BBlockWindow = make_tile_window(
+            b_pad_view,
+            make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
+            {i_n, 0});

         // allocate LDS
...
@@ -163,12 +175,16 @@ struct GemmKernel
         }();

-        auto CBlockWindow = make_tile_window(
-            c_tensor_view,
-            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
-            {i_m, i_n});
+        auto c_pad_view = pad_tensor_view(
+            c_tensor_view,
+            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
+            sequence<0, GemmPipeline::kPadC ? 1 : 0>{});
+        auto CBlockWindow_pad = make_tile_window(
+            c_pad_view,
+            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
+            {i_m, i_n});

         // epilogue.
-        EpiloguePipeline{}(CBlockWindow, acc);
+        EpiloguePipeline{}(CBlockWindow_pad, acc);
     }
 };
...
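The restructuring wraps each raw tensor view in a padded view before windowing, so tiles straddling a non-divisible edge go through a bounds-guarded view; padding is enabled per dimension only when the pipeline's kPadA/kPadB/kPadC flag is set. A stripped-down sketch of the per-dimension decision (conceptual only, not the library's types):

    // Conceptual: sequence<0, kPadA ? 1 : 0>{} means dim 0 (M) is never padded
    // here, while dim 1 (K) is padded iff the pipeline requests it.
    constexpr bool kPadA = true;                  // assumed pipeline flag
    constexpr int pad_flags[2] = {0, kPadA ? 1 : 0};
    static_assert(pad_flags[1] == 1, "K-edge tiles will use the guarded view");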
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @ a4fe62ed
...
@@ -29,6 +29,10 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
     static constexpr index_t AlignmentB = Problem::AlignmentB;
     static constexpr index_t AlignmentC = Problem::AlignmentC;

+    static constexpr bool kPadA = Problem::kPadA;
+    static constexpr bool kPadB = Problem::kPadB;
+    static constexpr bool kPadC = Problem::kPadC;
+
     CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
     {
         return ck_tile::integer_divide_ceil(
...
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
View file @ a4fe62ed
...
@@ -28,9 +28,9 @@ struct BlockGemmPipelineProblem
     static constexpr bool kPadB = kPadB_;
     static constexpr bool kPadC = kPadC_;

-    static constexpr index_t AlignmentA = kPadA ? VectorLoadSize / sizeof(ADataType) : 1;
-    static constexpr index_t AlignmentB = kPadB ? VectorLoadSize / sizeof(BDataType) : 1;
-    static constexpr index_t AlignmentC = kPadC ? VectorLoadSize / sizeof(CDataType) : 1;
+    static constexpr index_t AlignmentA = kPadA ? 1 : VectorLoadSize / sizeof(ADataType);
+    static constexpr index_t AlignmentB = kPadB ? 1 : VectorLoadSize / sizeof(BDataType);
+    static constexpr index_t AlignmentC = kPadC ? 1 : VectorLoadSize / sizeof(CDataType);
 };

 } // namespace ck_tile
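The flip corrects the alignment derivation: a padded dimension can place a tile edge anywhere, so its guaranteed alignment drops to 1, while an unpadded dimension keeps the full vector width. With illustrative values VectorLoadSize = 16 bytes and a 2-byte fp16 element:

    #include <cstdint>
    using half_t = std::uint16_t;      // stand-in for a 2-byte fp16 type
    constexpr int VectorLoadSize = 16; // bytes per vector load (assumed)
    constexpr bool kPadA = false;
    constexpr int AlignmentA = kPadA ? 1 : VectorLoadSize / sizeof(half_t);
    static_assert(AlignmentA == 8, "unpadded A keeps 8-element vector alignment");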
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
View file @ a4fe62ed
...
@@ -249,6 +249,40 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
             }
 #endif
         }

+        if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NGCHW> &&
+                     is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NGKHW>)
+        {
+#ifdef CK_ENABLE_FP32
+            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
+                         is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
+                         is_same_v<BComputeType, float>)
+            {
+                add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instances(
+                    op_ptrs);
+                add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(op_ptrs);
+                add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(op_ptrs);
+                add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances(
+                    op_ptrs);
+                add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances(
+                    op_ptrs);
+            }
+#endif
+#ifdef CK_ENABLE_FP16
+            if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
+                         is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
+                         is_same_v<BComputeType, half_t>)
+            {
+                add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instances(
+                    op_ptrs);
+                add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances(op_ptrs);
+                add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(op_ptrs);
+                add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances(
+                    op_ptrs);
+                add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances(
+                    op_ptrs);
+            }
+#endif
+        }
+
         if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
                      is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
...
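A hedged sketch of how the new instances surface to callers through the instance factory (the surrounding boilerplate and exact namespace spelling are assumptions):

    // Sketch: enumerating the NGCHW/GKYXC/NGKHW f32 forward-conv instances added above.
    using namespace ck::tensor_operation::device;
    using DeviceOpT = DeviceGroupedConvFwdMultipleABD<2, NGCHW, GKYXC, ck::Tuple<>, NGKHW,
                                                      float, float, ck::Tuple<>, float,
                                                      element_wise::PassThrough,
                                                      element_wise::PassThrough,
                                                      element_wise::PassThrough>;
    // GetInstances() now also returns the merged-groups, default, comp, mem-intra
    // and mem-inter NGCHW instances registered in this hunk.
    const auto op_ptrs =
        instance::DeviceOperationInstanceFactory<DeviceOpT>::GetInstances();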
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc
View file @
a4fe62ed
...
@@ -57,6 +57,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
...
@@ -57,6 +57,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
+// grouped conv2d forward, NGCHW/GKYXC/NGKHW
+#ifdef CK_ENABLE_FP16
+void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<
+        2, NGCHW, GKYXC, Empty_Tuple, NGKHW, F16, F16, Empty_Tuple, F16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+#endif
+#ifdef CK_ENABLE_FP32
+void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<
+        2, NGCHW, GKYXC, Empty_Tuple, NGKHW, F32, F32, Empty_Tuple, F32,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+#endif
 #ifdef CK_ENABLE_BF16
 // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
 void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
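Each of these declarations appends pre-built kernel instances to a caller-owned vector; the comp suffix marks compute-bound tunings, while the mem_intra and mem_inter suffixes in the next two files mark memory-bound tunings with intra- and inter-wave scheduling. The same calling pattern applies to all of the variants below. A hedged sketch of invoking one registration function directly; the enclosing instance namespace and the linked instance library are assumptions based on CK's conventions:

#include <memory>
#include <vector>

void collect_fp16_comp_instances()
{
    using namespace ck::tensor_operation::device;

    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<
        2, NGCHW, GKYXC, Empty_Tuple, NGKHW, F16, F16, Empty_Tuple, F16,
        PassThrough, PassThrough, PassThrough>>>
        instances;

    // Appends one object per "compute-optimized" fp16 NGCHW tuning; requires
    // the corresponding instance library to be compiled and linked in.
    instance::add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(instances);
}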
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc  View file @ a4fe62ed
...
@@ -57,6 +57,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances
         PassThrough>>>& instances);
 #endif
+// grouped conv2d forward, NGCHW/GKYXC/NGKHW
+#ifdef CK_ENABLE_FP16
+void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<
+        2, NGCHW, GKYXC, Empty_Tuple, NGKHW, F16, F16, Empty_Tuple, F16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+#endif
+#ifdef CK_ENABLE_FP32
+void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<
+        2, NGCHW, GKYXC, Empty_Tuple, NGKHW, F32, F32, Empty_Tuple, F32,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+#endif
 #ifdef CK_ENABLE_BF16
 // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
 void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc  View file @ a4fe62ed
...
@@ -57,6 +57,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances
         PassThrough>>>& instances);
 #endif
+// grouped conv2d forward, NGCHW/GKYXC/NGKHW
+#ifdef CK_ENABLE_FP16
+void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<
+        2, NGCHW, GKYXC, Empty_Tuple, NGKHW, F16, F16, Empty_Tuple, F16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+#endif
+#ifdef CK_ENABLE_FP32
+void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<
+        2, NGCHW, GKYXC, Empty_Tuple, NGKHW, F32, F32, Empty_Tuple, F32,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+#endif
 #ifdef CK_ENABLE_BF16
 // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
 void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc  View file @ a4fe62ed
...
@@ -171,6 +171,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
         PassThrough>>>& instances);
 #endif
+// grouped conv2d forward, NGCHW/GKYXC/NGKHW
+#ifdef CK_ENABLE_FP16
+void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<
+        2, NGCHW, GKYXC, Empty_Tuple, NGKHW, F16, F16, Empty_Tuple, F16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+#endif
+#ifdef CK_ENABLE_FP32
+void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<
+        2, NGCHW, GKYXC, Empty_Tuple, NGKHW, F32, F32, Empty_Tuple, F32,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+#endif
 #ifdef CK_ENABLE_BF16
 // grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
 void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
...
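All four .inc files instantiate the same DeviceGroupedConvFwdMultipleABD interface with identical arguments, differing only in data type and tuning-family suffix. A hedged positional reading of those arguments; the parameter-name comments are inferred from this diff's argument order, not copied from the header:

using NewNgchwFp16Op = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
    2,           // NumDimSpatial: 2-D convolution (H, W)
    NGCHW,       // input (A) layout
    GKYXC,       // weight (B) layout
    Empty_Tuple, // layouts of extra fused D tensors (none)
    NGKHW,       // output (E) layout
    F16,         // input data type
    F16,         // weight data type
    Empty_Tuple, // D data types (none)
    F16,         // output data type
    PassThrough, // input elementwise op
    PassThrough, // weight elementwise op
    PassThrough  // output elementwise op
    >;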