"vscode:/vscode.git/clone" did not exist on "886575ee43c3e7060d74e2feb2018111e0998013"
Commit a4fe62ed authored by Mirza Halilcevic's avatar Mirza Halilcevic
Browse files

Merge remote-tracking branch 'upstream/develop' into ck_migraphx_integration

parents 08255e1b 3528a523
...@@ -100,7 +100,15 @@ def getDockerImage(Map conf=[:]){ ...@@ -100,7 +100,15 @@ def getDockerImage(Map conf=[:]){
dockerArgs = dockerArgs + " --no-cache " dockerArgs = dockerArgs + " --no-cache "
} }
echo "Docker Args: ${dockerArgs}" echo "Docker Args: ${dockerArgs}"
def image = getDockerImageName() def image
if ( params.BUILD_LEGACY_OS && conf.get("docker_name", "") != "" ){
image = conf.get("docker_name", "")
echo "Using legacy docker: ${image}"
}
else{
image = getDockerImageName()
echo "Using default docker: ${image}"
}
//Check if image exists //Check if image exists
def retimage def retimage
try try
...@@ -125,7 +133,9 @@ def buildDocker(install_prefix){ ...@@ -125,7 +133,9 @@ def buildDocker(install_prefix){
def image_name = getDockerImageName() def image_name = getDockerImageName()
echo "Building Docker for ${image_name}" echo "Building Docker for ${image_name}"
def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' --build-arg DISABLE_CACHE='git rev-parse ${params.COMPILER_VERSION}' " def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' --build-arg DISABLE_CACHE='git rev-parse ${params.COMPILER_VERSION}' "
if(params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
dockerArgs = dockerArgs + " --no-cache "
}
echo "Build Args: ${dockerArgs}" echo "Build Args: ${dockerArgs}"
try{ try{
if(params.BUILD_DOCKER){ if(params.BUILD_DOCKER){
...@@ -259,6 +269,7 @@ def cmake_build(Map conf=[:]){ ...@@ -259,6 +269,7 @@ def cmake_build(Map conf=[:]){
""") """)
sh cmd3 sh cmd3
} }
// reduce parallelism when compiling, clang uses too much memory // reduce parallelism when compiling, clang uses too much memory
def nt = nthreads() def nt = nthreads()
def cmd def cmd
...@@ -273,7 +284,7 @@ def cmake_build(Map conf=[:]){ ...@@ -273,7 +284,7 @@ def cmake_build(Map conf=[:]){
} }
else{ else{
setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ") setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ")
build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j${nt} ${config_targets}") build_cmd = conf.get("build_cmd", "${build_envs} make -j${nt} ${config_targets}")
} }
cmd = conf.get("cmd", """ cmd = conf.get("cmd", """
${setup_cmd} ${setup_cmd}
...@@ -292,8 +303,8 @@ def cmake_build(Map conf=[:]){ ...@@ -292,8 +303,8 @@ def cmake_build(Map conf=[:]){
dir("build"){ dir("build"){
//build CK //build CK
sh cmd sh cmd
//run tests //run tests except when NO_CK_BUILD or BUILD_LEGACY_OS are set
if(!setup_args.contains("NO_CK_BUILD")){ if(!setup_args.contains("NO_CK_BUILD") && !params.BUILD_LEGACY_OS){
if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){ if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){
sh "/ninjatracing/ninjatracing .ninja_log > ck_build_trace.json" sh "/ninjatracing/ninjatracing .ninja_log > ck_build_trace.json"
archiveArtifacts "ck_build_trace.json" archiveArtifacts "ck_build_trace.json"
...@@ -330,7 +341,15 @@ def buildHipClangJob(Map conf=[:]){ ...@@ -330,7 +341,15 @@ def buildHipClangJob(Map conf=[:]){
env.HSA_ENABLE_SDMA=0 env.HSA_ENABLE_SDMA=0
checkout scm checkout scm
def image = getDockerImageName() def image
if ( params.BUILD_LEGACY_OS && conf.get("docker_name", "") != "" ){
image = conf.get("docker_name", "")
echo "Using legacy docker: ${image}"
}
else{
image = getDockerImageName()
echo "Using default docker: ${image}"
}
def prefixpath = conf.get("prefixpath", "/opt/rocm") def prefixpath = conf.get("prefixpath", "/opt/rocm")
// Jenkins is complaining about the render group // Jenkins is complaining about the render group
...@@ -512,7 +531,16 @@ def Build_CK(Map conf=[:]){ ...@@ -512,7 +531,16 @@ def Build_CK(Map conf=[:]){
env.DOCKER_BUILDKIT=1 env.DOCKER_BUILDKIT=1
checkout scm checkout scm
def image = getDockerImageName() def image
if ( params.BUILD_LEGACY_OS && conf.get("docker_name", "") != "" ){
image = conf.get("docker_name", "")
echo "Using legacy docker: ${image}"
}
else{
image = getDockerImageName()
echo "Using default docker: ${image}"
}
def prefixpath = conf.get("prefixpath", "/opt/rocm") def prefixpath = conf.get("prefixpath", "/opt/rocm")
// Jenkins is complaining about the render group // Jenkins is complaining about the render group
...@@ -524,6 +552,9 @@ def Build_CK(Map conf=[:]){ ...@@ -524,6 +552,9 @@ def Build_CK(Map conf=[:]){
if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
} }
if(params.BUILD_LEGACY_OS){
dockerOpts = dockerOpts + " --env LD_LIBRARY_PATH='/opt/Python-3.8.13/lib' "
}
def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3') def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3')
def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3') def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3')
dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} " dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} "
...@@ -707,7 +738,8 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCM ...@@ -707,7 +738,8 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCM
0 21 * * * % ROCMVERSION=6.2;hipTensor_test=true 0 21 * * * % ROCMVERSION=6.2;hipTensor_test=true
0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true 0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_CODEGEN_TESTS=false;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false''' : "" 0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_CODEGEN_TESTS=false;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false
0 13 * * * % BUILD_LEGACY_OS=true ''' : ""
pipeline { pipeline {
agent none agent none
...@@ -794,6 +826,10 @@ pipeline { ...@@ -794,6 +826,10 @@ pipeline {
name: "NINJA_BUILD_TRACE", name: "NINJA_BUILD_TRACE",
defaultValue: false, defaultValue: false,
description: "Generate a ninja build trace (default: OFF)") description: "Generate a ninja build trace (default: OFF)")
booleanParam(
name: "BUILD_LEGACY_OS",
defaultValue: false,
description: "Try building CK with legacy OS dockers: RHEL8 and SLES15 (default: OFF)")
} }
environment{ environment{
dbuser = "${dbuser}" dbuser = "${dbuser}"
...@@ -946,7 +982,6 @@ pipeline { ...@@ -946,7 +982,6 @@ pipeline {
{ {
parallel parallel
{ {
stage("Run CK_TILE_GEMM Tests on gfx90a") stage("Run CK_TILE_GEMM Tests on gfx90a")
{ {
when { when {
...@@ -965,7 +1000,6 @@ pipeline { ...@@ -965,7 +1000,6 @@ pipeline {
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
cleanWs() cleanWs()
} }
} }
stage("Run CK_TILE_GEMM Tests on gfx942") stage("Run CK_TILE_GEMM Tests on gfx942")
{ {
...@@ -988,15 +1022,54 @@ pipeline { ...@@ -988,15 +1022,54 @@ pipeline {
} }
} }
} }
stage("Build CK and run Tests") stage("Build CK and run Tests")
{ {
parallel parallel
{ {
stage("Build CK with RHEL8")
{
when {
beforeAgent true
expression { params.BUILD_LEGACY_OS.toBoolean() }
}
agent{ label rocmnode("gfx90a") }
environment{
def docker_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_rhel8_rocm6.3"
setup_args = """ -DGPU_TARGETS="gfx942" \
-DCMAKE_CXX_FLAGS=" -O3 " \
-DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """
execute_args = " "
}
steps{
Build_CK_and_Reboot(setup_args: setup_args, config_targets: " ", no_reboot:true, build_type: 'Release', docker_name: docker_name)
cleanWs()
}
}
stage("Build CK with SLES15")
{
when {
beforeAgent true
expression { params.BUILD_LEGACY_OS.toBoolean() }
}
agent{ label rocmnode("gfx90a") }
environment{
def docker_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_sles15_rocm6.3"
setup_args = """ -DGPU_TARGETS="gfx942" \
-DCMAKE_CXX_FLAGS=" -O3 " \
-DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """
execute_args = " "
}
steps{
Build_CK_and_Reboot(setup_args: setup_args, config_targets: " ", no_reboot:true, build_type: 'Release', docker_name: docker_name)
cleanWs()
}
}
stage("Build CK for all gfx9 targets") stage("Build CK for all gfx9 targets")
{ {
when { when {
beforeAgent true beforeAgent true
expression { params.RUN_FULL_QA.toBoolean() } expression { params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
} }
agent{ label rocmnode("gfx90a") } agent{ label rocmnode("gfx90a") }
environment{ environment{
...@@ -1018,7 +1091,7 @@ pipeline { ...@@ -1018,7 +1091,7 @@ pipeline {
{ {
when { when {
beforeAgent true beforeAgent true
expression { params.RUN_FULL_QA.toBoolean() } expression { params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
} }
agent{ label rocmnode("gfx942") } agent{ label rocmnode("gfx942") }
environment{ environment{
...@@ -1038,7 +1111,7 @@ pipeline { ...@@ -1038,7 +1111,7 @@ pipeline {
{ {
when { when {
beforeAgent true beforeAgent true
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() } expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
} }
agent{ label rocmnode("gfx90a") } agent{ label rocmnode("gfx90a") }
environment{ environment{
...@@ -1058,7 +1131,7 @@ pipeline { ...@@ -1058,7 +1131,7 @@ pipeline {
{ {
when { when {
beforeAgent true beforeAgent true
expression { params.BUILD_INSTANCES_ONLY.toBoolean() && !params.RUN_FULL_QA.toBoolean() } expression { params.BUILD_INSTANCES_ONLY.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
} }
agent{ label rocmnode("gfx90a") } agent{ label rocmnode("gfx90a") }
environment{ environment{
...@@ -1077,7 +1150,7 @@ pipeline { ...@@ -1077,7 +1150,7 @@ pipeline {
{ {
when { when {
beforeAgent true beforeAgent true
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() } expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
} }
agent{ label rocmnode("gfx1030") } agent{ label rocmnode("gfx1030") }
environment{ environment{
...@@ -1097,7 +1170,7 @@ pipeline { ...@@ -1097,7 +1170,7 @@ pipeline {
{ {
when { when {
beforeAgent true beforeAgent true
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() } expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
} }
agent{ label rocmnode("gfx1101") } agent{ label rocmnode("gfx1101") }
environment{ environment{
...@@ -1117,7 +1190,7 @@ pipeline { ...@@ -1117,7 +1190,7 @@ pipeline {
{ {
when { when {
beforeAgent true beforeAgent true
expression { params.BUILD_GFX12.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() } expression { params.BUILD_GFX12.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
} }
agent{ label rocmnode("gfx1201") } agent{ label rocmnode("gfx1201") }
environment{ environment{
...@@ -1144,7 +1217,7 @@ pipeline { ...@@ -1144,7 +1217,7 @@ pipeline {
{ {
when { when {
beforeAgent true beforeAgent true
expression { params.RUN_PERFORMANCE_TESTS.toBoolean() } expression { params.RUN_PERFORMANCE_TESTS.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
} }
options { retry(1) } options { retry(1) }
agent{ label rocmnode("gfx90a")} agent{ label rocmnode("gfx90a")}
...@@ -1165,7 +1238,7 @@ pipeline { ...@@ -1165,7 +1238,7 @@ pipeline {
stage("Process results"){ stage("Process results"){
when { when {
beforeAgent true beforeAgent true
expression { params.RUN_PERFORMANCE_TESTS.toBoolean() } expression { params.RUN_PERFORMANCE_TESTS.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
} }
agent { label 'mici' } agent { label 'mici' }
steps{ steps{
......
rocm-docs-core==1.8.0 rocm-docs-core==1.8.1
sphinxcontrib-bibtex==2.6.3 sphinxcontrib-bibtex==2.6.3
...@@ -103,7 +103,7 @@ requests==2.32.3 ...@@ -103,7 +103,7 @@ requests==2.32.3
# via # via
# pygithub # pygithub
# sphinx # sphinx
rocm-docs-core==1.8.0 rocm-docs-core==1.8.1
# via -r requirements.in # via -r requirements.in
six==1.16.0 six==1.16.0
# via pybtex # via pybtex
......
...@@ -179,9 +179,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_buf, ...@@ -179,9 +179,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_buf,
std::cout << "The overall perfomance of the GEMM with " std::cout << "The overall perfomance of the GEMM with "
<< "[" << data_type << "]" << "[" << data_type << "]"
<< "batch size: " << batch_size << ". m:" << M << ",n:" << N << ", k:" << K << "batch size: " << batch_size << ". m:" << M << ", n:" << N << ", k:" << K
<< "is: \n"; << " is: \n";
std::cout << "Running time :" << ave_time << "ms, Throughput" << gb_per_sec << "GB/s \n" std::cout << "Running time: " << ave_time << "ms, Throughput " << gb_per_sec << "GB/s \n"
<< std::flush; << std::flush;
return ave_time; return ave_time;
...@@ -235,7 +235,7 @@ int main(int argc, char* argv[]) ...@@ -235,7 +235,7 @@ int main(int argc, char* argv[])
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadA = true; constexpr bool kPadA = true;
constexpr bool kPadB = true; constexpr bool kPadB = true;
constexpr bool kPadC = false; constexpr bool kPadC = true;
// This part comes from the Codegen // This part comes from the Codegen
constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t M_Tile = 128;
...@@ -348,7 +348,7 @@ int main(int argc, char* argv[]) ...@@ -348,7 +348,7 @@ int main(int argc, char* argv[])
pass_gpu = ck_tile::check_err(c_host_dev, c_host_gpu_ref); pass_gpu = ck_tile::check_err(c_host_dev, c_host_gpu_ref);
std::cout << "The GPU veification result is:" << (pass_gpu ? "correct" : "fail") std::cout << "The GPU veification result is: " << (pass_gpu ? "correct" : "fail")
<< std::flush; << std::flush;
} }
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
...@@ -22,7 +23,6 @@ ...@@ -22,7 +23,6 @@
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp> #include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -257,6 +257,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle ...@@ -257,6 +257,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
KPerBlock / K1Number, KPerBlock / K1Number,
ConvBackwardWeightSpecialization>{}; ConvBackwardWeightSpecialization>{};
static constexpr index_t ClusterLengthMPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
static constexpr index_t ClusterLengthNPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
static constexpr auto conv_ngchw_to_nhwgc_transformer =
TransformConvNGCHWToNHWGC<InLayout,
WeiLayout,
OutLayout,
NDimSpatial,
MPerBlock / ClusterLengthMPerBlock,
NPerBlock / ClusterLengthNPerBlock>{};
static constexpr GemmSpecialization GemmSpec = GemmSpecialization::Default; static constexpr GemmSpecialization GemmSpec = GemmSpecialization::Default;
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
...@@ -359,141 +372,12 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle ...@@ -359,141 +372,12 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
batch)[I2]; batch)[I2];
} }
static constexpr index_t ClusterLengthMPerBlock = using NGCHWTransposeDescType =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
static constexpr index_t ClusterLengthNPerBlock = .template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); using NHWGCTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false> .template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
static auto MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[0];
const index_t& N = g_n_c_wis_lengths[1];
const index_t& C = g_n_c_wis_lengths[2];
const index_t& Hi = g_n_c_wis_lengths[3];
const index_t& Wi = g_n_c_wis_lengths[4];
const index_t& GStride = g_n_c_wis_strides[0];
const index_t& NStride = g_n_c_wis_strides[1];
const index_t& CStride = g_n_c_wis_strides[2];
const index_t& HiStride = g_n_c_wis_strides[3];
const index_t& WiStride = g_n_c_wis_strides[4];
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadTensorDescriptor(
merged_desc,
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[0];
const index_t& N = g_n_c_wis_lengths[1];
const index_t& C = g_n_c_wis_lengths[2];
const index_t& Hi = g_n_c_wis_lengths[3];
const index_t& Wi = g_n_c_wis_lengths[4];
const index_t& NStride = g_n_c_wis_strides[1];
const index_t HiStride = Wi * G * C;
const index_t WiStride = G * C;
const index_t GStride = C;
const index_t CStride = 1;
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadTensorDescriptor(
merged_desc,
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[0];
const index_t& N = g_n_c_wis_lengths[1];
const index_t& C = g_n_c_wis_lengths[2];
const index_t& Di = g_n_c_wis_lengths[3];
const index_t& Hi = g_n_c_wis_lengths[4];
const index_t& Wi = g_n_c_wis_lengths[5];
const index_t& GStride = g_n_c_wis_strides[0];
const index_t& NStride = g_n_c_wis_strides[1];
const index_t& CStride = g_n_c_wis_strides[2];
const index_t& DiStride = g_n_c_wis_strides[3];
const index_t& HiStride = g_n_c_wis_strides[4];
const index_t& WiStride = g_n_c_wis_strides[5];
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Di, Hi, Wi),
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Di, Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadTensorDescriptor(
merged_desc,
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[0];
const index_t& N = g_n_c_wis_lengths[1];
const index_t& C = g_n_c_wis_lengths[2];
const index_t& Di = g_n_c_wis_lengths[3];
const index_t& Hi = g_n_c_wis_lengths[4];
const index_t& Wi = g_n_c_wis_lengths[5];
const index_t& NStride = g_n_c_wis_strides[1];
const index_t DiStride = Hi * Wi * G * C;
const index_t HiStride = Wi * G * C;
const index_t WiStride = G * C;
const index_t GStride = C;
const index_t CStride = 1;
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Di, Hi, Wi),
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Di, Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadTensorDescriptor(
merged_desc,
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
Sequence<true, true>{});
}
using InputTransposeDescType =
remove_cvref_t<decltype(MakeInputTransposeDesc<NDimSpatial>({}, {}))>;
using OutputTransposeDescType =
remove_cvref_t<decltype(MakeOutputTransposeDesc<NDimSpatial>({}, {}))>;
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>()); using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
...@@ -572,8 +456,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle ...@@ -572,8 +456,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
I1>; I1>;
using GridwiseElementwiseTranspose = using GridwiseElementwiseTranspose =
GridwiseElementwise<Tuple<InputTransposeDescType>, GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
Tuple<OutputTransposeDescType>, Tuple<NHWGCTransposeDescType>,
Tuple<const ADataType*>, Tuple<const ADataType*>,
Tuple<ADataType*>, Tuple<ADataType*>,
Block2TileMapElementwise, Block2TileMapElementwise,
...@@ -652,43 +536,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle ...@@ -652,43 +536,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
begin(output_spatial_lengths_)); begin(output_spatial_lengths_));
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed = std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
b_g_n_c_wis_strides; conv_ngchw_to_nhwgc_transformer.TransposeStrides(b_g_n_c_wis_lengths,
b_g_n_c_wis_strides);
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed = std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
a_g_n_k_wos_strides; conv_ngchw_to_nhwgc_transformer.TransposeStrides(a_g_n_k_wos_lengths,
a_g_n_k_wos_strides);
// NGKHW - transpose needed
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
b_g_n_c_wis_strides_transposed[I0] = Conv_C_;
b_g_n_c_wis_strides_transposed[I2] = I1;
a_g_n_k_wos_strides_transposed[I0] = Conv_K_;
a_g_n_k_wos_strides_transposed[I2] = I1;
if constexpr(NDimSpatial == 2)
{
b_g_n_c_wis_strides_transposed[I3] =
input_spatial_lengths_[I1] * Conv_G_ * Conv_C_;
b_g_n_c_wis_strides_transposed[I4] = Conv_G_ * Conv_C_;
a_g_n_k_wos_strides_transposed[I3] =
output_spatial_lengths_[I1] * Conv_G_ * Conv_K_;
a_g_n_k_wos_strides_transposed[I4] = Conv_G_ * Conv_K_;
}
else if constexpr(NDimSpatial == 3)
{
b_g_n_c_wis_strides_transposed[I3] =
input_spatial_lengths_[I1] * input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
b_g_n_c_wis_strides_transposed[I4] =
input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
b_g_n_c_wis_strides_transposed[I5] = Conv_G_ * Conv_C_;
a_g_n_k_wos_strides_transposed[I3] = output_spatial_lengths_[I1] *
input_spatial_lengths_[I2] * Conv_G_ *
Conv_K_;
a_g_n_k_wos_strides_transposed[I4] =
input_spatial_lengths_[I2] * Conv_G_ * Conv_K_;
a_g_n_k_wos_strides_transposed[I5] = Conv_G_ * Conv_K_;
}
}
const auto descs = const auto descs =
conv_to_gemm_transformer_v2 conv_to_gemm_transformer_v2
...@@ -755,14 +607,18 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle ...@@ -755,14 +607,18 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
{ {
a_in_transpose_desc_ = a_in_transpose_desc_ =
MakeInputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides); conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
a_out_transpose_desc_ = a_out_transpose_desc_ =
MakeOutputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides); conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
b_in_transpose_desc_ = b_in_transpose_desc_ =
MakeInputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides); conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
b_out_transpose_desc_ = b_out_transpose_desc_ =
MakeOutputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides); conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{ elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)}; a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
...@@ -816,8 +672,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle ...@@ -816,8 +672,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_, Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
elementwise_block_2_ctile_map_transpose_b_; elementwise_block_2_ctile_map_transpose_b_;
InputTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_; NGCHWTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
OutputTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_; NHWGCTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;
// for computing batch offset // for computing batch offset
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_; ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_;
...@@ -1569,13 +1425,14 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle ...@@ -1569,13 +1425,14 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
(arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) / (arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) /
sizeof(BDataType); sizeof(BDataType);
// Different data type for A and B is not supported
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseTranspose, auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseTranspose,
ck::Tuple<InputTransposeDescType>, ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<InputTransposeDescType>, ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<OutputTransposeDescType>, ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<OutputTransposeDescType>, ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<const ADataType*>, ck::Tuple<const ADataType*>,
ck::Tuple<BDataType*>, ck::Tuple<ADataType*>,
Block2TileMapElementwise, Block2TileMapElementwise,
Block2TileMapElementwise, Block2TileMapElementwise,
element_wise::PassThrough>; element_wise::PassThrough>;
......
...@@ -15,9 +15,11 @@ ...@@ -15,9 +15,11 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
...@@ -307,6 +309,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -307,6 +309,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value; static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value; static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
// NGCHW is not supported for multiAB
static_assert(!(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>()) ||
!(isMultiA || isMultiB));
static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>(); static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>(); static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
...@@ -315,6 +322,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -315,6 +322,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
ConvForwardSpecialization, ConvForwardSpecialization,
...@@ -323,14 +332,33 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -323,14 +332,33 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
EDataType, EDataType,
NumGroupsToMerge>; NumGroupsToMerge>;
static constexpr index_t ClusterLengthNPerBlock =
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
static constexpr auto conv_ngchw_to_nhwgc_transformer =
TransformConvNGCHWToNHWGC<ALayout,
BLayout,
ELayout,
NDimSpatial,
NPerBlock / ClusterLengthNPerBlock,
NPerBlock / ClusterLengthNPerBlock>{};
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay> template <typename ALay>
static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
ctc::NHWGC,
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
ctc::NDHWGC,
ALay>>;
const auto in_gemmmraw_gemmkraw_desc = const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(); conv_to_gemm_transformer.template MakeADescriptor_M_K<Layout>();
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
...@@ -353,8 +381,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -353,8 +381,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template <typename ELay> template <typename ELay>
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
ctc::NHWGK,
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
ctc::NDHWGK,
ELay>>;
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(); conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();
const auto out_gemmm_gemmn_desc = const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
...@@ -442,6 +478,52 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -442,6 +478,52 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// block-to-e-tile map // block-to-e-tile map
using Block2ETileMap = using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>; remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, NPerBlock>;
using NGCHWTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
using NHWGCTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock;
using GridwiseElementwiseInputTranspose =
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
Tuple<NHWGCTransposeDescType>,
Tuple<const ADataType*>,
Tuple<ADataType*>,
Block2TileMapElementwise,
element_wise::PassThrough,
ElementwiseBlocksize,
NPerBlock,
NPerBlock,
NPerBlock / ClusterLengthNPerBlock,
NPerBlock / ClusterLengthNPerBlock,
Sequence<1, 0>,
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
I1,
I0>;
using GridwiseElementwiseOutputTranspose =
GridwiseElementwise<Tuple<NHWGCTransposeDescType>,
Tuple<NGCHWTransposeDescType>,
Tuple<const EDataType*>,
Tuple<EDataType*>,
Block2TileMapElementwise,
element_wise::PassThrough,
ElementwiseBlocksize,
NPerBlock,
NPerBlock,
NPerBlock / ClusterLengthNPerBlock,
NPerBlock / ClusterLengthNPerBlock,
Sequence<1, 0>,
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
I0,
I1>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -471,17 +553,31 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -471,17 +553,31 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
p_bs_grid_{}, p_bs_grid_{},
p_ds_grid_{}, p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)}, p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]}, a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
conv_to_gemm_transformer_{a_g_n_c_wis_lengths, a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
a_g_n_c_wis_strides, a_g_n_c_wis_lengths, a_g_n_c_wis_strides)},
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides, b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
e_g_n_k_wos_lengths, ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
e_g_n_k_wos_strides, ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
conv_filter_strides, e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
conv_filter_dilations, e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
input_left_pads, e_g_n_k_wos_lengths, e_g_n_k_wos_strides)},
input_right_pads}, conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads},
num_group_{a_g_n_c_wis_lengths_[0]},
conv_to_gemm_transformer_{a_g_n_c_wis_lengths_,
a_g_n_c_wis_strides_,
b_g_k_c_xs_lengths_,
b_g_k_c_xs_strides_,
e_g_n_k_wos_lengths_,
e_g_n_k_wos_strides_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_},
conv_N_per_block_{conv_to_gemm_transformer_.N_}, conv_N_per_block_{conv_to_gemm_transformer_.N_},
a_grid_desc_m_k_{ a_grid_desc_m_k_{
DeviceOp::MakeAGridDescriptor_M_K<ALayout>(conv_to_gemm_transformer_)}, DeviceOp::MakeAGridDescriptor_M_K<ALayout>(conv_to_gemm_transformer_)},
...@@ -501,19 +597,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -501,19 +597,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
compute_ptr_offset_of_n_{}, compute_ptr_offset_of_n_{},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
cde_element_op_{cde_element_op}, cde_element_op_{cde_element_op}
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{ {
// A/B/E Batch Stride // A/B/E Batch Stride
if constexpr(isMultiA || isMultiB) if constexpr(isMultiA || isMultiB)
...@@ -521,7 +605,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -521,7 +605,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static_for<0, NumATensor, 1>{}([&](auto i) { static_for<0, NumATensor, 1>{}([&](auto i) {
// Init compute_ptr_offset_of_groups_ for multiple AB // Init compute_ptr_offset_of_groups_ for multiple AB
compute_ptr_offset_of_groups_.BatchStrideA_(i) = compute_ptr_offset_of_groups_.BatchStrideA_(i) =
a_g_n_c_wis_strides[0] * NumGroupsToMerge; a_g_n_c_wis_strides_[0] * NumGroupsToMerge;
// Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
// type is not tuple) // type is not tuple)
...@@ -537,20 +621,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -537,20 +621,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// in case of MultiA is false but isMultiB is true // in case of MultiA is false but isMultiB is true
// BatchStrideA_ is not tuple. // BatchStrideA_ is not tuple.
compute_ptr_offset_of_n_.BatchStrideA_(i) = compute_ptr_offset_of_n_.BatchStrideA_(i) =
a_g_n_c_wis_strides[1] * conv_N_per_block_; a_g_n_c_wis_strides_[1] * conv_N_per_block_;
} }
else else
{ {
// if MultiB and not MultiA then p_as is single pointer // if MultiB and not MultiA then p_as is single pointer
p_as_grid_(i) = static_cast<const DataType*>(p_as); p_as_grid_(i) = static_cast<const DataType*>(p_as);
compute_ptr_offset_of_n_.BatchStrideA_ = compute_ptr_offset_of_n_.BatchStrideA_ =
a_g_n_c_wis_strides[1] * conv_N_per_block_; a_g_n_c_wis_strides_[1] * conv_N_per_block_;
} }
}); });
static_for<0, NumBTensor, 1>{}([&](auto i) { static_for<0, NumBTensor, 1>{}([&](auto i) {
// Init compute_ptr_offset_of_groups_ for multiple AB // Init compute_ptr_offset_of_groups_ for multiple AB
compute_ptr_offset_of_groups_.BatchStrideB_(i) = compute_ptr_offset_of_groups_.BatchStrideB_(i) =
b_g_k_c_xs_strides[0] * NumGroupsToMerge; b_g_k_c_xs_strides_[0] * NumGroupsToMerge;
using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>; using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>;
// It is possible that one of the AB is a pointer and one is a tuple. // It is possible that one of the AB is a pointer and one is a tuple.
...@@ -571,10 +655,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -571,10 +655,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
else else
{ {
compute_ptr_offset_of_groups_.BatchStrideA_ = compute_ptr_offset_of_groups_.BatchStrideA_ =
a_g_n_c_wis_strides[0] * NumGroupsToMerge; a_g_n_c_wis_strides_[0] * NumGroupsToMerge;
compute_ptr_offset_of_groups_.BatchStrideB_ = compute_ptr_offset_of_groups_.BatchStrideB_ =
b_g_k_c_xs_strides[0] * NumGroupsToMerge; b_g_k_c_xs_strides_[0] * NumGroupsToMerge;
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_; compute_ptr_offset_of_n_.BatchStrideA_ =
a_g_n_c_wis_strides_[1] * conv_N_per_block_;
// p_as and p_bs are pointers // p_as and p_bs are pointers
p_as_grid_(I0) = static_cast<const ADataType*>(p_as); p_as_grid_(I0) = static_cast<const ADataType*>(p_as);
...@@ -591,27 +676,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -591,27 +676,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// D batch stride // D batch stride
compute_ptr_offset_of_groups_.BatchStrideDs_(i) = compute_ptr_offset_of_groups_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides[i][0] * NumGroupsToMerge; ds_g_n_k_wos_strides_[i][0] * NumGroupsToMerge;
compute_ptr_offset_of_n_.BatchStrideDs_(i) = compute_ptr_offset_of_n_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides[i][1] * conv_N_per_block_; ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_;
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths_,
a_g_n_c_wis_strides, a_g_n_c_wis_strides_,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths_,
b_g_k_c_xs_strides, b_g_k_c_xs_strides_,
e_g_n_k_wos_lengths, e_g_n_k_wos_lengths_,
ds_g_n_k_wos_strides[i], ds_g_n_k_wos_strides_[i],
conv_filter_strides, conv_filter_strides_,
conv_filter_dilations, conv_filter_dilations_,
input_left_pads, input_left_pads_,
input_right_pads}; input_right_pads_};
// D desc // D desc
ds_grid_desc_m_n_(i) = ds_grid_desc_m_n_(i) =
DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d); DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
}); });
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0] * NumGroupsToMerge; compute_ptr_offset_of_groups_.BatchStrideE_ =
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_; e_g_n_k_wos_strides_[0] * NumGroupsToMerge;
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_;
// populate desc for Ds/E // populate desc for Ds/E
if constexpr(isMultiA || isMultiB) if constexpr(isMultiA || isMultiB)
...@@ -653,6 +739,54 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -653,6 +739,54 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
ds_grid_desc_m_n_); ds_grid_desc_m_n_);
} }
} }
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
// Use not modified base strides
a_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
a_out_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
e_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
e_out_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{
e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
}
}
std::size_t GetWorkspaceATensorSizeBytes() const
{
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
}
std::size_t GetWorkspaceETensorSizeBytes() const
{
return sizeof(EDataType) * e_out_transpose_desc_.GetElementSpaceSize();
}
std::size_t GetWorkspaceSizeBytes() const
{
// Transpose require workspace for A and B
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
return GetWorkspaceATensorSizeBytes() + GetWorkspaceETensorSizeBytes();
}
else
{
return 0;
}
} }
void Print() const void Print() const
...@@ -671,6 +805,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -671,6 +805,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
typename GridwiseGemm::DsGridPointer p_ds_grid_; typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_; EDataType* p_e_grid_;
// for checking IsSupportedArgument()
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> conv_filter_dilations_;
std::array<index_t, NDimSpatial> input_left_pads_;
std::array<index_t, NDimSpatial> input_right_pads_;
// tensor descriptors for problem definiton // tensor descriptors for problem definiton
index_t num_group_; index_t num_group_;
...@@ -692,6 +840,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -692,6 +840,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// block-to-e-tile map // block-to-e-tile map
Block2ETileMap block_2_etile_map_; Block2ETileMap block_2_etile_map_;
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
elementwise_block_2_ctile_map_transpose_e_;
NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_;
NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_;
// for computing batch offset // for computing batch offset
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor> ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>
...@@ -702,20 +855,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -702,20 +855,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_; CDEElementwiseOperation cde_element_op_;
// for checking IsSupportedArgument()
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> conv_filter_dilations_;
std::array<index_t, NDimSpatial> input_left_pads_;
std::array<index_t, NDimSpatial> input_right_pads_;
}; };
// Invoker // Invoker
...@@ -723,7 +862,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -723,7 +862,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
if(stream_config.log_level_ > 0) if(stream_config.log_level_ > 0)
{ {
...@@ -794,6 +933,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -794,6 +933,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
} }
else else
{ {
const ADataType* p_a_grid = arg.p_as_grid_.At(I0);
EDataType* p_e_grid = arg.p_e_grid_;
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
}
const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle< const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
GridwiseGemm, GridwiseGemm,
const ADataType*, const ADataType*,
...@@ -820,10 +970,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -820,10 +970,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
dim3(gdx, gdy, gdz), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
arg.p_as_grid_.At(I0), // Pass just A descriptor instead of tuple p_a_grid, // Pass just A descriptor instead of tuple
arg.p_bs_grid_.At(I0), // Pass just B descriptor instead of tuple arg.p_bs_grid_.At(I0), // Pass just B descriptor instead of tuple
arg.p_ds_grid_, arg.p_ds_grid_,
arg.p_e_grid_, p_e_grid,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.cde_element_op_, arg.cde_element_op_,
...@@ -847,6 +997,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -847,6 +997,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
} }
} }
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float avg_time = 0.f;
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
const index_t grid_size =
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
arg.a_in_transpose_desc_);
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseInputTranspose,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<const ADataType*>,
ck::Tuple<ADataType*>,
Block2TileMapElementwise,
element_wise::PassThrough>;
avg_time += launch_and_time_kernel(stream_config,
kernel_transpose,
dim3(grid_size),
dim3(ElementwiseBlocksize),
0,
make_tuple(arg.a_in_transpose_desc_),
make_tuple(arg.a_out_transpose_desc_),
make_tuple(arg.p_as_grid_.At(I0)),
make_tuple(p_a_out_grid),
arg.elementwise_block_2_ctile_map_transpose_a_,
element_wise::PassThrough{});
}
avg_time += RunGemm(arg, stream_config);
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
const index_t grid_size =
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
arg.e_in_transpose_desc_);
const EDataType* p_e_out_grid =
type_convert<EDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
EDataType* p_e_in_grid = arg.p_e_grid_;
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseOutputTranspose,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<const EDataType*>,
ck::Tuple<EDataType*>,
Block2TileMapElementwise,
element_wise::PassThrough>;
avg_time += launch_and_time_kernel(stream_config,
kernel_transpose,
dim3(grid_size),
dim3(ElementwiseBlocksize),
0,
make_tuple(arg.e_in_transpose_desc_),
make_tuple(arg.e_out_transpose_desc_),
make_tuple(p_e_out_grid),
make_tuple(p_e_in_grid),
arg.elementwise_block_2_ctile_map_transpose_e_,
element_wise::PassThrough{});
}
return avg_time;
}
float Run(const BaseArgument* p_arg, float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override const StreamConfig& stream_config = StreamConfig{}) override
{ {
...@@ -941,7 +1164,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -941,7 +1164,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{ {
return false; return false;
} }
if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>()) if constexpr(!(is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() ||
is_NGCSpatial_GKSpatial_NGKSpatial<ALayout, BLayout, ELayout>()))
{ {
return false; return false;
} }
...@@ -953,14 +1177,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -953,14 +1177,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v<ALayout, ctc::G_NDHW_C> || is_same_v<ALayout, ctc::GNWC> || is_same_v<ALayout, ctc::G_NDHW_C> || is_same_v<ALayout, ctc::GNWC> ||
is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> || is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> || is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
is_same_v<ALayout, ctc::NDHWGC>) is_same_v<ALayout, ctc::NDHWGC> || is_same_v<ALayout, ctc::NGCW> ||
is_same_v<ALayout, ctc::NGCHW> || is_same_v<ALayout, ctc::NGCDHW>)
{ {
// Check access per C // Check access per C
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
{ {
// If not possible, check access per G // If not possible, check access per G
if(!(ABlockTransferSrcVectorDim == 1 && C == 1 && if(!(ABlockTransferSrcVectorDim == 1 && (C == 1 || NumGroupsToMerge == 1) &&
is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() && (is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() ||
is_NGCSpatial_GKSpatial_NGKSpatial<ALayout, BLayout, ELayout>()) &&
G % ABlockTransferSrcScalarPerVector == 0)) G % ABlockTransferSrcScalarPerVector == 0))
{ {
return false; return false;
...@@ -1036,6 +1262,35 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -1036,6 +1262,35 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
} }
}); });
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
if((G * K) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
if(!valid) if(!valid)
{ {
return false; return false;
...@@ -1046,7 +1301,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -1046,7 +1301,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v<ELayout, ctc::G_NDHW_K> || is_same_v<ELayout, ctc::GNWK> || is_same_v<ELayout, ctc::G_NDHW_K> || is_same_v<ELayout, ctc::GNWK> ||
is_same_v<ELayout, ctc::GNHWK> || is_same_v<ELayout, ctc::GNDHWK> || is_same_v<ELayout, ctc::GNHWK> || is_same_v<ELayout, ctc::GNDHWK> ||
is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> || is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
is_same_v<ELayout, ctc::NDHWGK>) is_same_v<ELayout, ctc::NDHWGK> || is_same_v<ELayout, ctc::NGKW> ||
is_same_v<ELayout, ctc::NGKHW> || is_same_v<ELayout, ctc::NGKDHW>)
{ {
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{ {
...@@ -1352,6 +1608,34 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -1352,6 +1608,34 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
return str.str(); return str.str();
} }
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
auto arg = dynamic_cast<const Argument*>(p_arg);
if(arg)
{
return arg->GetWorkspaceSizeBytes();
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
}
void SetWorkSpacePointer(BaseArgument* p_arg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const override
{
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
if(p_arg_)
{
p_arg_->p_workspace_ = p_workspace;
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
}
}; };
} // namespace device } // namespace device
......
...@@ -15,10 +15,12 @@ ...@@ -15,10 +15,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -292,6 +294,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -292,6 +294,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
ConvForwardSpecialization, ConvForwardSpecialization,
...@@ -302,13 +306,32 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -302,13 +306,32 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
static constexpr index_t ClusterLengthNPerBlock =
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
static constexpr auto conv_ngchw_to_nhwgc_transformer =
TransformConvNGCHWToNHWGC<ALayout,
BLayout,
ELayout,
NDimSpatial,
MPerBlock / ClusterLengthNPerBlock,
NPerBlock / ClusterLengthNPerBlock>{};
template <typename ALay> template <typename ALay>
static auto static auto
MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
ctc::NHWGC,
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
ctc::NDHWGC,
ALay>>;
const auto in_gemmmraw_gemmkraw_desc = const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(); conv_to_gemm_transformer.template MakeADescriptor_M_K<Layout>();
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
...@@ -351,8 +374,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -351,8 +374,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
ctc::NHWGK,
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
ctc::NDHWGK,
ELay>>;
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(); conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();
const auto out_gemmm_gemmn_desc = const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
...@@ -385,6 +416,53 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -385,6 +416,53 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// Use appropriate gridwise gemm // Use appropriate gridwise gemm
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<GridwiseGemmV3TemplateParams>; using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<GridwiseGemmV3TemplateParams>;
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, NPerBlock>;
using NGCHWTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
using NHWGCTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock;
using GridwiseElementwiseInputTranspose =
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
Tuple<NHWGCTransposeDescType>,
Tuple<const ADataType*>,
Tuple<ADataType*>,
Block2TileMapElementwise,
element_wise::PassThrough,
ElementwiseBlocksize,
NPerBlock,
NPerBlock,
NPerBlock / ClusterLengthNPerBlock,
NPerBlock / ClusterLengthNPerBlock,
Sequence<1, 0>,
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
I1,
I0>;
using GridwiseElementwiseOutputTranspose =
GridwiseElementwise<Tuple<NHWGCTransposeDescType>,
Tuple<NGCHWTransposeDescType>,
Tuple<const EDataType*>,
Tuple<EDataType*>,
Block2TileMapElementwise,
element_wise::PassThrough,
ElementwiseBlocksize,
NPerBlock,
NPerBlock,
NPerBlock / ClusterLengthNPerBlock,
NPerBlock / ClusterLengthNPerBlock,
Sequence<1, 0>,
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
I0,
I1>;
static auto static auto
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
{ {
...@@ -428,17 +506,29 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -428,17 +506,29 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
: p_a_grid_{}, : p_a_grid_{},
p_b_grid_{}, p_b_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)}, p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]}, a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
conv_to_gemm_transformer_{a_g_n_c_wis_lengths, a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
a_g_n_c_wis_strides, a_g_n_c_wis_lengths, a_g_n_c_wis_strides)},
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides, b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
e_g_n_k_wos_lengths, e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
e_g_n_k_wos_strides, e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
conv_filter_strides, e_g_n_k_wos_lengths, e_g_n_k_wos_strides)},
conv_filter_dilations, conv_filter_strides_{conv_filter_strides},
input_left_pads, conv_filter_dilations_{conv_filter_dilations},
input_right_pads}, input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads},
num_group_{a_g_n_c_wis_lengths_[0]},
conv_to_gemm_transformer_{a_g_n_c_wis_lengths_,
a_g_n_c_wis_strides_,
b_g_k_c_xs_lengths_,
b_g_k_c_xs_strides_,
e_g_n_k_wos_lengths_,
e_g_n_k_wos_strides_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_},
conv_N_per_block_{conv_to_gemm_transformer_.N_}, conv_N_per_block_{conv_to_gemm_transformer_.N_},
a_grid_desc_ak0_m_ak1_{ a_grid_desc_ak0_m_ak1_{
MakeAGridDescriptor_AK0_M_AK1<ALayout>(conv_to_gemm_transformer_)}, MakeAGridDescriptor_AK0_M_AK1<ALayout>(conv_to_gemm_transformer_)},
...@@ -451,32 +541,70 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -451,32 +541,70 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
compute_ptr_offset_of_n_{}, compute_ptr_offset_of_n_{},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
cde_element_op_{cde_element_op}, cde_element_op_{cde_element_op}
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{ {
// A/B/E Batch/N Stride // A/B/E Batch/N Stride
compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0]; compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides_[0];
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0]; compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides_[0];
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_; compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides_[1] * conv_N_per_block_;
// p_as and p_bs are pointers // p_as and p_bs are pointers
p_a_grid_ = static_cast<const ADataType*>(p_as); p_a_grid_ = static_cast<const ADataType*>(p_as);
p_b_grid_ = static_cast<const BDataType*>(p_bs); p_b_grid_ = static_cast<const BDataType*>(p_bs);
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0]; compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides_[0];
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_; compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_;
e_grid_desc_mblock_mperblock_nblock_nperblock_ = e_grid_desc_mblock_mperblock_nblock_nperblock_ =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
// Use not modified base strides
a_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
a_out_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
e_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
e_out_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{
e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
}
}
std::size_t GetWorkspaceATensorSizeBytes() const
{
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
}
std::size_t GetWorkspaceETensorSizeBytes() const
{
return sizeof(EDataType) * e_out_transpose_desc_.GetElementSpaceSize();
}
std::size_t GetWorkspaceSizeBytes() const
{
// Transpose require workspace for A and B
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
return GetWorkspaceATensorSizeBytes() + GetWorkspaceETensorSizeBytes();
}
else
{
return 0;
}
} }
void Print() const void Print() const
...@@ -492,6 +620,18 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -492,6 +620,18 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
EDataType* p_e_grid_; EDataType* p_e_grid_;
// for checking IsSupportedArgument()
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> conv_filter_dilations_;
std::array<index_t, NDimSpatial> input_left_pads_;
std::array<index_t, NDimSpatial> input_right_pads_;
// tensor descriptors for problem definiton // tensor descriptors for problem definiton
index_t num_group_; index_t num_group_;
...@@ -514,17 +654,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -514,17 +654,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_; CDEElementwiseOperation cde_element_op_;
// for checking IsSupportedArgument() // block-to-e-tile map
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_; Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_; elementwise_block_2_ctile_map_transpose_e_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_; NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_; NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> conv_filter_dilations_;
std::array<index_t, NDimSpatial> input_left_pads_;
std::array<index_t, NDimSpatial> input_right_pads_;
}; };
// Invoker // Invoker
...@@ -532,7 +667,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -532,7 +667,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
if(stream_config.log_level_ > 0) if(stream_config.log_level_ > 0)
{ {
...@@ -561,8 +696,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -561,8 +696,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock; index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const ADataType* p_a_grid = arg.p_a_grid_;
EDataType* p_e_grid = arg.p_e_grid_;
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
}
typename GridwiseGemm::Argument gemm_arg{ typename GridwiseGemm::Argument gemm_arg{
arg.p_a_grid_, arg.p_b_grid_, arg.p_e_grid_, GemmM, GemmN, GemmK, I0, I0, I0, I1}; p_a_grid, arg.p_b_grid_, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, I1};
const auto Run = [&](const auto& kernel) { const auto Run = [&](const auto& kernel) {
if(stream_config.flush_cache) if(stream_config.flush_cache)
...@@ -857,6 +1003,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -857,6 +1003,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
return ave_time; return ave_time;
} }
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float avg_time = 0.f;
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
const index_t grid_size =
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
arg.a_in_transpose_desc_);
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseInputTranspose,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<const ADataType*>,
ck::Tuple<ADataType*>,
Block2TileMapElementwise,
element_wise::PassThrough>;
avg_time += launch_and_time_kernel(stream_config,
kernel_transpose,
dim3(grid_size),
dim3(ElementwiseBlocksize),
0,
make_tuple(arg.a_in_transpose_desc_),
make_tuple(arg.a_out_transpose_desc_),
make_tuple(arg.p_a_grid_),
make_tuple(p_a_out_grid),
arg.elementwise_block_2_ctile_map_transpose_a_,
element_wise::PassThrough{});
}
avg_time += RunGemm(arg, stream_config);
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
const index_t grid_size =
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
arg.e_in_transpose_desc_);
const EDataType* p_e_out_grid =
type_convert<EDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
EDataType* p_e_in_grid = arg.p_e_grid_;
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseOutputTranspose,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<const EDataType*>,
ck::Tuple<EDataType*>,
Block2TileMapElementwise,
element_wise::PassThrough>;
avg_time += launch_and_time_kernel(stream_config,
kernel_transpose,
dim3(grid_size),
dim3(ElementwiseBlocksize),
0,
make_tuple(arg.e_in_transpose_desc_),
make_tuple(arg.e_out_transpose_desc_),
make_tuple(p_e_out_grid),
make_tuple(p_e_in_grid),
arg.elementwise_block_2_ctile_map_transpose_e_,
element_wise::PassThrough{});
}
return avg_time;
}
float Run(const BaseArgument* p_arg, float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override const StreamConfig& stream_config = StreamConfig{}) override
{ {
...@@ -868,6 +1087,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -868,6 +1087,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
{ {
namespace ctc = tensor_layout::convolution; namespace ctc = tensor_layout::convolution;
const index_t G = arg.b_g_k_c_xs_lengths_[I0];
const index_t K = arg.b_g_k_c_xs_lengths_[I1];
const index_t C = arg.b_g_k_c_xs_lengths_[I2];
// check device // check device
if(get_device_name() == "gfx908") if(get_device_name() == "gfx908")
{ {
...@@ -924,10 +1147,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -924,10 +1147,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
is_same_v<ALayout, ctc::G_NDHW_C> || is_same_v<ALayout, ctc::GNWC> || is_same_v<ALayout, ctc::G_NDHW_C> || is_same_v<ALayout, ctc::GNWC> ||
is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> || is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> || is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
is_same_v<ALayout, ctc::NDHWGC>) is_same_v<ALayout, ctc::NDHWGC> || is_same_v<ALayout, ctc::NGCW> ||
is_same_v<ALayout, ctc::NGCHW> || is_same_v<ALayout, ctc::NGCDHW>)
{ {
const index_t C = arg.a_g_n_c_wis_lengths_[2];
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
{ {
return false; return false;
...@@ -947,8 +1169,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -947,8 +1169,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
is_same_v<BLayout, ctc::KZYXGC>) is_same_v<BLayout, ctc::KZYXGC>)
{ {
const index_t C = arg.b_g_k_c_xs_lengths_[2];
if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
{ {
return false; return false;
...@@ -959,15 +1179,43 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -959,15 +1179,43 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
return false; return false;
} }
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
if((G * K) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
// check vector access of E // check vector access of E
if constexpr(is_same_v<ELayout, ctc::G_NW_K> || is_same_v<ELayout, ctc::G_NHW_K> || if constexpr(is_same_v<ELayout, ctc::G_NW_K> || is_same_v<ELayout, ctc::G_NHW_K> ||
is_same_v<ELayout, ctc::G_NDHW_K> || is_same_v<ELayout, ctc::GNWK> || is_same_v<ELayout, ctc::G_NDHW_K> || is_same_v<ELayout, ctc::GNWK> ||
is_same_v<ELayout, ctc::GNHWK> || is_same_v<ELayout, ctc::GNDHWK> || is_same_v<ELayout, ctc::GNHWK> || is_same_v<ELayout, ctc::GNDHWK> ||
is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> || is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
is_same_v<ELayout, ctc::NDHWGK>) is_same_v<ELayout, ctc::NDHWGK> || is_same_v<ELayout, ctc::NGKW> ||
is_same_v<ELayout, ctc::NGKHW> || is_same_v<ELayout, ctc::NGKDHW>)
{ {
const index_t K = arg.e_g_n_k_wos_lengths_[2];
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{ {
return false; return false;
...@@ -1279,6 +1527,34 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -1279,6 +1527,34 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
return str.str(); return str.str();
} }
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
auto arg = dynamic_cast<const Argument*>(p_arg);
if(arg)
{
return arg->GetWorkspaceSizeBytes();
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
}
void SetWorkSpacePointer(BaseArgument* p_arg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const override
{
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
if(p_arg_)
{
p_arg_->p_workspace_ = p_workspace;
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
}
}; };
} // namespace device } // namespace device
......
...@@ -26,6 +26,15 @@ constexpr bool is_GNWC_GKXC_GNWK() ...@@ -26,6 +26,15 @@ constexpr bool is_GNWC_GKXC_GNWK()
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> && is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
is_same_v<OutLayout, tensor_layout::convolution::GNWK>; is_same_v<OutLayout, tensor_layout::convolution::GNWK>;
} }
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NGCW_GKXC_NGKW()
{
return is_same_v<InLayout, tensor_layout::convolution::NGCW> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NGKW>;
}
// 2d // 2d
template <typename InLayout, typename WeiLayout, typename OutLayout> template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NHWGC_GKYXC_NHWGK() constexpr bool is_NHWGC_GKYXC_NHWGK()
...@@ -91,6 +100,14 @@ constexpr bool is_GNSpatialC_GKSpatial_GNSpatialK() ...@@ -91,6 +100,14 @@ constexpr bool is_GNSpatialC_GKSpatial_GNSpatialK()
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>(); is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>();
} }
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NGCSpatial_GKSpatial_NGKSpatial()
{
return is_NGCW_GKXC_NGKW<InLayout, WeiLayout, OutLayout>() ||
is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>();
}
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void> template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void>
struct ComputePtrOffsetOfStridedBatch struct ComputePtrOffsetOfStridedBatch
{ {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
namespace ck {
namespace tensor_operation {
template <typename ALayout,
typename BLayout,
typename ELayout,
index_t NDimSpatial,
index_t MPerThread,
index_t NPerThread>
struct TransformConvNGCHWToNHWGC
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[I0];
const index_t& N = g_n_c_wis_lengths[I1];
const index_t& C = g_n_c_wis_lengths[I2];
const index_t& Wi = g_n_c_wis_lengths[I3];
const index_t& GStride = g_n_c_wis_strides[I0];
const index_t& NStride = g_n_c_wis_strides[I1];
const index_t& CStride = g_n_c_wis_strides[I2];
const index_t& WiStride = g_n_c_wis_strides[I3];
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Wi), make_tuple(NStride, GStride, CStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return device::PadTensorDescriptor(
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[I0];
const index_t& N = g_n_c_wis_lengths[I1];
const index_t& C = g_n_c_wis_lengths[I2];
const index_t& Wi = g_n_c_wis_lengths[I3];
const index_t& NStride = g_n_c_wis_strides[I1];
const index_t WiStride = G * C;
const index_t GStride = C;
const index_t CStride = 1;
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Wi), make_tuple(NStride, GStride, CStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return device::PadTensorDescriptor(
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[I0];
const index_t& N = g_n_c_wis_lengths[I1];
const index_t& C = g_n_c_wis_lengths[I2];
const index_t& Hi = g_n_c_wis_lengths[I3];
const index_t& Wi = g_n_c_wis_lengths[I4];
const index_t& GStride = g_n_c_wis_strides[I0];
const index_t& NStride = g_n_c_wis_strides[I1];
const index_t& CStride = g_n_c_wis_strides[I2];
const index_t& HiStride = g_n_c_wis_strides[I3];
const index_t& WiStride = g_n_c_wis_strides[I4];
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return device::PadTensorDescriptor(
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[I0];
const index_t& N = g_n_c_wis_lengths[I1];
const index_t& C = g_n_c_wis_lengths[I2];
const index_t& Hi = g_n_c_wis_lengths[I3];
const index_t& Wi = g_n_c_wis_lengths[I4];
const index_t& NStride = g_n_c_wis_strides[I1];
const index_t HiStride = Wi * G * C;
const index_t WiStride = G * C;
const index_t GStride = C;
const index_t CStride = 1;
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return device::PadTensorDescriptor(
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[I0];
const index_t& N = g_n_c_wis_lengths[I1];
const index_t& C = g_n_c_wis_lengths[I2];
const index_t& Di = g_n_c_wis_lengths[I3];
const index_t& Hi = g_n_c_wis_lengths[I4];
const index_t& Wi = g_n_c_wis_lengths[I5];
const index_t& GStride = g_n_c_wis_strides[I0];
const index_t& NStride = g_n_c_wis_strides[I1];
const index_t& CStride = g_n_c_wis_strides[I2];
const index_t& DiStride = g_n_c_wis_strides[I3];
const index_t& HiStride = g_n_c_wis_strides[I4];
const index_t& WiStride = g_n_c_wis_strides[I5];
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Di, Hi, Wi),
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Di, Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return device::PadTensorDescriptor(
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[I0];
const index_t& N = g_n_c_wis_lengths[I1];
const index_t& C = g_n_c_wis_lengths[I2];
const index_t& Di = g_n_c_wis_lengths[I3];
const index_t& Hi = g_n_c_wis_lengths[I4];
const index_t& Wi = g_n_c_wis_lengths[I5];
const index_t& NStride = g_n_c_wis_strides[I1];
const index_t DiStride = Hi * Wi * G * C;
const index_t HiStride = Wi * G * C;
const index_t WiStride = G * C;
const index_t GStride = C;
const index_t CStride = 1;
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Di, Hi, Wi),
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Di, Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return device::PadTensorDescriptor(
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
}
static auto TransposeStrides(const std::array<index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& g_n_c_wis_strides)
{
if constexpr(device::is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
device::is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
std::array<index_t, NDimSpatial + 3> g_n_c_wis_strides_transposed;
const auto G = g_n_c_wis_lengths[I0];
const auto C = g_n_c_wis_lengths[I2];
g_n_c_wis_strides_transposed[I0] = C;
g_n_c_wis_strides_transposed[I1] = g_n_c_wis_strides[I1];
g_n_c_wis_strides_transposed[I2] = I1;
if constexpr(NDimSpatial == 2)
{
g_n_c_wis_strides_transposed[I3] = g_n_c_wis_lengths[I4] * G * C;
g_n_c_wis_strides_transposed[I4] = G * C;
}
else if constexpr(NDimSpatial == 3)
{
g_n_c_wis_strides_transposed[I3] =
g_n_c_wis_lengths[I4] * g_n_c_wis_lengths[I5] * G * C;
g_n_c_wis_strides_transposed[I4] = g_n_c_wis_lengths[I5] * G * C;
g_n_c_wis_strides_transposed[I5] = G * C;
}
return g_n_c_wis_strides_transposed;
}
else
{
// transpose not needed
return g_n_c_wis_strides;
}
}
};
} // namespace tensor_operation
} // namespace ck
...@@ -516,7 +516,7 @@ struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Add, ...@@ -516,7 +516,7 @@ struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Add,
static constexpr bool value = static constexpr bool value =
is_same<DataType, float>::value || is_same<DataType, double>::value || is_same<DataType, float>::value || is_same<DataType, double>::value ||
is_same<DataType, half_t>::value || is_same<DataType, int8_t>::value || is_same<DataType, half_t>::value || is_same<DataType, int8_t>::value ||
is_same<DataType, int32_t>::value || is_same<DataType, f8_t>::value; is_same<DataType, int32_t>::value;
}; };
} // namespace reduce } // namespace reduce
......
...@@ -215,8 +215,8 @@ struct BlockFmhaPipelineQRKSVS ...@@ -215,8 +215,8 @@ struct BlockFmhaPipelineQRKSVS
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit if masked and no work to do. // check early exit if no work to do
if constexpr(FmhaMask::IsMasking) if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
{ {
if(num_total_loop <= 0) if(num_total_loop <= 0)
{ {
......
...@@ -268,7 +268,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -268,7 +268,7 @@ struct BlockFmhaPipelineQRKSVSAsync
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit // check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK) if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
{ {
if(num_total_loop <= 0) if(num_total_loop <= 0)
......
...@@ -123,14 +123,26 @@ struct GemmKernel ...@@ -123,14 +123,26 @@ struct GemmKernel
} }
}(); }();
auto ABlockWindow = make_tile_window( auto a_pad_view = pad_tensor_view(
a_tensor_view, a_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}), make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
sequence < 0,
GemmPipeline::kPadA ? 1 : 0 > {});
auto ABlockWindow = make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
{i_m, 0}); {i_m, 0});
auto BBlockWindow = make_tile_window( auto b_pad_view = pad_tensor_view(
b_tensor_view, b_tensor_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}), make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
sequence < 0,
GemmPipeline::kPadB ? 1 : 0 > {});
auto BBlockWindow = make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
{i_n, 0}); {i_n, 0});
// allocate LDS // allocate LDS
...@@ -163,12 +175,16 @@ struct GemmKernel ...@@ -163,12 +175,16 @@ struct GemmKernel
} }
}(); }();
auto CBlockWindow = make_tile_window( auto c_pad_view = pad_tensor_view(
c_tensor_view, c_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}), make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
sequence < 0,
GemmPipeline::kPadC ? 1 : 0 > {});
auto CBlockWindow_pad = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n}); {i_m, i_n});
// epilogue. EpiloguePipeline{}(CBlockWindow_pad, acc);
EpiloguePipeline{}(CBlockWindow, acc);
} }
}; };
......
...@@ -29,6 +29,10 @@ struct BlockGemmPipelineAGmemBGmemCRegV1 ...@@ -29,6 +29,10 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
static constexpr index_t AlignmentB = Problem::AlignmentB; static constexpr index_t AlignmentB = Problem::AlignmentB;
static constexpr index_t AlignmentC = Problem::AlignmentC; static constexpr index_t AlignmentC = Problem::AlignmentC;
static constexpr bool kPadA = Problem::kPadA;
static constexpr bool kPadB = Problem::kPadB;
static constexpr bool kPadC = Problem::kPadC;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
{ {
return ck_tile::integer_divide_ceil( return ck_tile::integer_divide_ceil(
......
...@@ -28,9 +28,9 @@ struct BlockGemmPipelineProblem ...@@ -28,9 +28,9 @@ struct BlockGemmPipelineProblem
static constexpr bool kPadB = kPadB_; static constexpr bool kPadB = kPadB_;
static constexpr bool kPadC = kPadC_; static constexpr bool kPadC = kPadC_;
static constexpr index_t AlignmentA = kPadA ? VectorLoadSize / sizeof(ADataType) : 1; static constexpr index_t AlignmentA = kPadA ? 1 : VectorLoadSize / sizeof(ADataType);
static constexpr index_t AlignmentB = kPadB ? VectorLoadSize / sizeof(BDataType) : 1; static constexpr index_t AlignmentB = kPadB ? 1 : VectorLoadSize / sizeof(BDataType);
static constexpr index_t AlignmentC = kPadC ? VectorLoadSize / sizeof(CDataType) : 1; static constexpr index_t AlignmentC = kPadC ? 1 : VectorLoadSize / sizeof(CDataType);
}; };
} // namespace ck_tile } // namespace ck_tile
...@@ -249,6 +249,40 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -249,6 +249,40 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
} }
#endif #endif
} }
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NGCHW> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NGKHW>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
is_same_v<BComputeType, float>)
{
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances(
op_ptrs);
}
#endif
}
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> && if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>) is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
......
...@@ -57,6 +57,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances( ...@@ -57,6 +57,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
......
...@@ -57,6 +57,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances ...@@ -57,6 +57,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances( void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
......
...@@ -57,6 +57,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances ...@@ -57,6 +57,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
......
...@@ -171,6 +171,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( ...@@ -171,6 +171,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK // grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances( void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment