Commit 9c0811f3 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents ded0d83d 3528a523
......@@ -23,6 +23,11 @@ RUN if [ "$ROCMVERSION" != "6.3" ]; then \
wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \
sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \
sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list'; \
elif [ "$ROCMVERSION" = "6.3" ] && [ "$compiler_version" = "rc1" ]; then \
sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.3.0.1-20.04-1_all.deb --no-check-certificate" && \
apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog libpopt0 rsync && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.3.0.1-20.04-1_all.deb && \
sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.3.0.1 rel-5 > /etc/apt/sources.list.d/rocm-build.list' && \
amdgpu-repo --amdgpu-build=2033700; \
fi
RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list"
......@@ -130,6 +135,8 @@ ENV compiler_commit=$compiler_commit
RUN sh -c "echo compiler version = '$compiler_version'"
RUN sh -c "echo compiler commit = '$compiler_commit'"
ARG DISABLE_CACHE=0
RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" = "" ]; then \
git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \
cd llvm-project && mkdir build && cd build && \
......
......@@ -94,13 +94,21 @@ def getDockerImage(Map conf=[:]){
env.DOCKER_BUILDKIT=1
def prefixpath = conf.get("prefixpath", "/opt/rocm")
def no_cache = conf.get("no_cache", false)
def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' --build-arg DISABLE_CACHE='git rev-parse ${params.COMPILER_VERSION}' "
if(no_cache)
{
dockerArgs = dockerArgs + " --no-cache "
}
echo "Docker Args: ${dockerArgs}"
def image = getDockerImageName()
def image
if ( params.BUILD_LEGACY_OS && conf.get("docker_name", "") != "" ){
image = conf.get("docker_name", "")
echo "Using legacy docker: ${image}"
}
else{
image = getDockerImageName()
echo "Using default docker: ${image}"
}
//Check if image exists
def retimage
try
......@@ -124,8 +132,10 @@ def buildDocker(install_prefix){
checkout scm
def image_name = getDockerImageName()
echo "Building Docker for ${image_name}"
def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' --build-arg DISABLE_CACHE='git rev-parse ${params.COMPILER_VERSION}' "
if(params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
dockerArgs = dockerArgs + " --no-cache "
}
echo "Build Args: ${dockerArgs}"
try{
if(params.BUILD_DOCKER){
......@@ -259,6 +269,7 @@ def cmake_build(Map conf=[:]){
""")
sh cmd3
}
// reduce parallelism when compiling, clang uses too much memory
def nt = nthreads()
def cmd
......@@ -273,7 +284,7 @@ def cmake_build(Map conf=[:]){
}
else{
setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ")
build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j${nt} ${config_targets}")
build_cmd = conf.get("build_cmd", "${build_envs} make -j${nt} ${config_targets}")
}
cmd = conf.get("cmd", """
${setup_cmd}
......@@ -292,8 +303,8 @@ def cmake_build(Map conf=[:]){
dir("build"){
//build CK
sh cmd
//run tests
if(!setup_args.contains("NO_CK_BUILD")){
//run tests except when NO_CK_BUILD or BUILD_LEGACY_OS are set
if(!setup_args.contains("NO_CK_BUILD") && !params.BUILD_LEGACY_OS){
if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){
sh "/ninjatracing/ninjatracing .ninja_log > ck_build_trace.json"
archiveArtifacts "ck_build_trace.json"
......@@ -330,7 +341,15 @@ def buildHipClangJob(Map conf=[:]){
env.HSA_ENABLE_SDMA=0
checkout scm
def image = getDockerImageName()
def image
if ( params.BUILD_LEGACY_OS && conf.get("docker_name", "") != "" ){
image = conf.get("docker_name", "")
echo "Using legacy docker: ${image}"
}
else{
image = getDockerImageName()
echo "Using default docker: ${image}"
}
def prefixpath = conf.get("prefixpath", "/opt/rocm")
// Jenkins is complaining about the render group
......@@ -512,7 +531,16 @@ def Build_CK(Map conf=[:]){
env.DOCKER_BUILDKIT=1
checkout scm
def image = getDockerImageName()
def image
if ( params.BUILD_LEGACY_OS && conf.get("docker_name", "") != "" ){
image = conf.get("docker_name", "")
echo "Using legacy docker: ${image}"
}
else{
image = getDockerImageName()
echo "Using default docker: ${image}"
}
def prefixpath = conf.get("prefixpath", "/opt/rocm")
// Jenkins is complaining about the render group
......@@ -524,6 +552,9 @@ def Build_CK(Map conf=[:]){
if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
}
if(params.BUILD_LEGACY_OS){
dockerOpts = dockerOpts + " --env LD_LIBRARY_PATH='/opt/Python-3.8.13/lib' "
}
def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3')
def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3')
dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} "
......@@ -703,11 +734,12 @@ def process_results(Map conf=[:]){
}
//launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.2; RUN_CK_TILE_TESTS=true
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.2;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true
0 21 * * * % ROCMVERSION=6.2;hipTensor_test=true
0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_CODEGEN_TESTS=false;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false''' : ""
0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_CODEGEN_TESTS=false;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false
0 13 * * * % BUILD_LEGACY_OS=true ''' : ""
pipeline {
agent none
......@@ -775,9 +807,13 @@ pipeline {
defaultValue: false,
description: "Run the grouped conv large cases tests (default: OFF)")
booleanParam(
name: "RUN_CK_TILE_TESTS",
name: "RUN_CK_TILE_FMHA_TESTS",
defaultValue: false,
description: "Run the ck_tile tests (default: OFF)")
description: "Run the ck_tile FMHA tests (default: OFF)")
booleanParam(
name: "RUN_CK_TILE_GEMM_TESTS",
defaultValue: false,
description: "Run the ck_tile GEMM tests (default: OFF)")
booleanParam(
name: "BUILD_INSTANCES_ONLY",
defaultValue: false,
......@@ -790,6 +826,10 @@ pipeline {
name: "NINJA_BUILD_TRACE",
defaultValue: false,
description: "Generate a ninja build trace (default: OFF)")
booleanParam(
name: "BUILD_LEGACY_OS",
defaultValue: false,
description: "Try building CK with legacy OS dockers: RHEL8 and SLES15 (default: OFF)")
}
environment{
dbuser = "${dbuser}"
......@@ -894,15 +934,15 @@ pipeline {
}
}
}
stage("Run CK_TILE Tests")
stage("Run CK_TILE_FMHA Tests")
{
parallel
{
stage("Run CK_TILE Tests on gfx90a")
stage("Run CK_TILE_FMHA Tests on gfx90a")
{
when {
beforeAgent true
expression { params.RUN_CK_TILE_TESTS.toBoolean() }
expression { params.RUN_CK_TILE_FMHA_TESTS.toBoolean() }
}
agent{ label rocmnode("gfx90a") }
environment{
......@@ -917,11 +957,11 @@ pipeline {
cleanWs()
}
}
stage("Run CK_TILE Tests on gfx942")
stage("Run CK_TILE_FMHA Tests on gfx942")
{
when {
beforeAgent true
expression { params.RUN_CK_TILE_TESTS.toBoolean() }
expression { params.RUN_CK_TILE_FMHA_TESTS.toBoolean() }
}
agent{ label rocmnode("gfx942") }
environment{
......@@ -938,21 +978,103 @@ pipeline {
}
}
}
stage("Run CK_TILE_GEMM Tests")
{
parallel
{
stage("Run CK_TILE_GEMM Tests on gfx90a")
{
when {
beforeAgent true
expression { params.RUN_CK_TILE_GEMM_TESTS.toBoolean() }
}
agent{ label rocmnode("gfx90a") }
environment{
setup_args = "NO_CK_BUILD"
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \
make -j64 tile_example_gemm_basic && \
cd ../ &&
example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
cleanWs()
}
}
stage("Run CK_TILE_GEMM Tests on gfx942")
{
when {
beforeAgent true
expression { params.RUN_CK_TILE_GEMM_TESTS.toBoolean() }
}
agent{ label rocmnode("gfx942") }
environment{
setup_args = "NO_CK_BUILD"
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \
make -j64 tile_example_gemm_basic && \
cd ../ &&
example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
cleanWs()
}
}
}
}
stage("Build CK and run Tests")
{
parallel
{
stage("Build CK with RHEL8")
{
when {
beforeAgent true
expression { params.BUILD_LEGACY_OS.toBoolean() }
}
agent{ label rocmnode("gfx90a") }
environment{
def docker_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_rhel8_rocm6.3"
setup_args = """ -DGPU_TARGETS="gfx942" \
-DCMAKE_CXX_FLAGS=" -O3 " \
-DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """
execute_args = " "
}
steps{
Build_CK_and_Reboot(setup_args: setup_args, config_targets: " ", no_reboot:true, build_type: 'Release', docker_name: docker_name)
cleanWs()
}
}
stage("Build CK with SLES15")
{
when {
beforeAgent true
expression { params.BUILD_LEGACY_OS.toBoolean() }
}
agent{ label rocmnode("gfx90a") }
environment{
def docker_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_sles15_rocm6.3"
setup_args = """ -DGPU_TARGETS="gfx942" \
-DCMAKE_CXX_FLAGS=" -O3 " \
-DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """
execute_args = " "
}
steps{
Build_CK_and_Reboot(setup_args: setup_args, config_targets: " ", no_reboot:true, build_type: 'Release', docker_name: docker_name)
cleanWs()
}
}
stage("Build CK for all gfx9 targets")
{
when {
beforeAgent true
expression { params.RUN_FULL_QA.toBoolean() }
expression { params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
}
agent{ label rocmnode("gfx90a") }
environment{
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \
-DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \
-DCMAKE_EXE_LINKER_FLAGS=" -L ${env.WORKSPACE}/script -T hip_fatbin_insert " \
-DCMAKE_CXX_FLAGS=" -O3 " """
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
......@@ -969,7 +1091,7 @@ pipeline {
{
when {
beforeAgent true
expression { params.RUN_FULL_QA.toBoolean() }
expression { params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
}
agent{ label rocmnode("gfx942") }
environment{
......@@ -989,7 +1111,7 @@ pipeline {
{
when {
beforeAgent true
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
}
agent{ label rocmnode("gfx90a") }
environment{
......@@ -1009,7 +1131,7 @@ pipeline {
{
when {
beforeAgent true
expression { params.BUILD_INSTANCES_ONLY.toBoolean() && !params.RUN_FULL_QA.toBoolean() }
expression { params.BUILD_INSTANCES_ONLY.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
}
agent{ label rocmnode("gfx90a") }
environment{
......@@ -1028,7 +1150,7 @@ pipeline {
{
when {
beforeAgent true
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
}
agent{ label rocmnode("gfx1030") }
environment{
......@@ -1048,7 +1170,7 @@ pipeline {
{
when {
beforeAgent true
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
}
agent{ label rocmnode("gfx1101") }
environment{
......@@ -1068,7 +1190,7 @@ pipeline {
{
when {
beforeAgent true
expression { params.BUILD_GFX12.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
expression { params.BUILD_GFX12.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
}
agent{ label rocmnode("gfx1201") }
environment{
......@@ -1095,7 +1217,7 @@ pipeline {
{
when {
beforeAgent true
expression { params.RUN_PERFORMANCE_TESTS.toBoolean() }
expression { params.RUN_PERFORMANCE_TESTS.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
}
options { retry(1) }
agent{ label rocmnode("gfx90a")}
......@@ -1116,7 +1238,7 @@ pipeline {
stage("Process results"){
when {
beforeAgent true
expression { params.RUN_PERFORMANCE_TESTS.toBoolean() }
expression { params.RUN_PERFORMANCE_TESTS.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
}
agent { label 'mici' }
steps{
......
rocm-docs-core==1.7.2
sphinxcontrib-bibtex==2.6.2
rocm-docs-core==1.8.1
sphinxcontrib-bibtex==2.6.3
......@@ -103,7 +103,7 @@ requests==2.32.3
# via
# pygithub
# sphinx
rocm-docs-core==1.7.2
rocm-docs-core==1.8.1
# via -r requirements.in
six==1.16.0
# via pybtex
......@@ -137,7 +137,7 @@ sphinx-notfound-page==1.0.3
# via rocm-docs-core
sphinxcontrib-applehelp==2.0.0
# via sphinx
sphinxcontrib-bibtex==2.6.2
sphinxcontrib-bibtex==2.6.3
# via -r requirements.in
sphinxcontrib-devhelp==2.0.0
# via sphinx
......
......@@ -305,6 +305,14 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}
#endif
}
else
{
// When the Problem Type and Problem Size does not fit.
std::cerr << gemm.GetTypeString() << ": the instance does not support the problem config."
<< std::endl;
return true;
}
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
......
......@@ -161,18 +161,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
}
#if 0
printf("B matrix:\n");
for (int in = 0; in < N; in++)
{
for (int ik = 0; ik < K; ik++)
{
printf("%02x ", *(reinterpret_cast<uint8_t*>(&b_k_n(ik,in))));
if(ik%8==7) printf("|");
}
printf("\n");
}
#endif
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
......@@ -272,7 +260,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
if(config.time_kernel)
{
ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
ave_time =
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 5, 10, true, 4});
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
......
set(CMAKE_BUILD_TYPE Debug)
add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
\ No newline at end of file
# GEMM Matrix Multiplication
This folder contains example for GEMM using ck_tile tile-programming implementation. Currently, it only supports the basic feature of the CK Tile GEMM, but creates the placeholders for the future support on different GEMM pipeline and different GEMM modules. In the near future, we will gradually migrate all the GEMM features from old CK to CK Tile.
## build
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_gemm_basic -j
```
This will result in an executable `build/bin/tile_example_gemm_basic`
## example
```
args:
-m m dimension (default:3328)
-n m dimension (default:4096)
-k k dimension (default:64)
-e epsilon (default:1e-5)
-v cpu validation or not (default:1)
-prec precision (default:fp16)
```
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_basic.hpp"
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("b", "1", "batch size")
.insert("m", "1024", "m dimension")
.insert("n", "2048", "n dimension")
.insert("k", "64", "k dimension")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("e", "1e-5", "Absolute error tolerance")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "10", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename LayoutA,
typename LayoutB,
typename LayoutC,
typename PipelineProblem,
typename GemmPipeline,
typename GemmShape>
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
{
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadA = true;
constexpr bool kPadB = true;
constexpr int kBlockPerCu = 1;
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel =
ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, LayoutA, LayoutB, LayoutC>;
auto kargs = Kernel::MakeKargs(args.p_a,
args.p_b,
args.p_c,
args.epsilon,
args.M,
args.N,
args.K,
args.stride_A,
args.stride_B,
args.stride_C);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
template <typename DataType,
typename LayoutA,
typename LayoutB,
typename LayoutC,
typename PipelineProblem,
typename GemmPipeline,
typename GemmShape>
float invoke_gemm(ck_tile::DeviceMem& a_buf,
ck_tile::DeviceMem& b_buf,
ck_tile::DeviceMem& c_buf,
const ck_tile::ArgParser& arg_parser)
{
std::string data_type = arg_parser.get_str("prec");
if(data_type != DataTypeTraits<DataType>::name)
{
std::cerr << "Data type mismatch: expected " << DataTypeTraits<DataType>::name << ", got "
<< data_type << std::endl;
return -1; // Or handle the error appropriately
}
float epsilon = arg_parser.get_float("e");
ck_tile::index_t batch_size = arg_parser.get_int("b");
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
ck_tile::index_t stride_a = arg_parser.get_int("stride_a");
ck_tile::index_t stride_b = arg_parser.get_int("stride_b");
ck_tile::index_t stride_c = arg_parser.get_int("stride_c");
gemm_basic_args args;
args.p_a = a_buf.GetDeviceBuffer();
args.p_b = b_buf.GetDeviceBuffer();
args.p_c = c_buf.GetDeviceBuffer();
args.epsilon = epsilon;
args.kbatch = batch_size;
args.M = M;
args.N = N;
args.K = K;
// Only set stride_M and stride_N if they are non-zero and not equal to K.
if(stride_a != 0)
{
args.stride_A = stride_a;
}
else
{
args.stride_A = [&]() {
if constexpr(std::is_same_v<LayoutA, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
return M;
}
else
{
return K;
}
}();
}
if(stride_b != 0)
{
args.stride_B = stride_b;
}
else
{
args.stride_B = [&]() {
if constexpr(std::is_same_v<LayoutB, ck_tile::tensor_layout::gemm::RowMajor>)
{
return N;
}
else
{
return K;
}
}();
}
if(stride_c != 0)
{
args.stride_C = stride_c;
}
else
{
args.stride_C = [&]() {
if constexpr(std::is_same_v<LayoutC, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
return M;
}
else
{
return N;
}
}();
}
float ave_time = gemm_calc<LayoutA, LayoutB, LayoutC, PipelineProblem, GemmPipeline, GemmShape>(
args, ck_tile::stream_config{nullptr, true});
std::size_t num_byte =
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "The overall perfomance of the GEMM with "
<< "[" << data_type << "]"
<< "batch size: " << batch_size << ". m:" << M << ", n:" << N << ", k:" << K
<< " is: \n";
std::cout << "Running time: " << ave_time << "ms, Throughput " << gb_per_sec << "GB/s \n"
<< std::flush;
return ave_time;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
// The Matrix Multiplication goes with Matrix A (M, K), Matrix B (N, K) = Matrix C (M, N).
using matrix_a_layout = ck_tile::tensor_layout::gemm::RowMajor;
using matrix_b_layout = ck_tile::tensor_layout::gemm::ColumnMajor;
using matrix_c_layout = ck_tile::tensor_layout::gemm::RowMajor;
// host verify
std::vector<int> a_dimensions =
(std::is_same_v<matrix_a_layout, ck_tile::tensor_layout::gemm::RowMajor>)
? std::vector<int>{M, K}
: std::vector<int>{K, M};
std::vector<int> b_dimensions =
(std::is_same_v<matrix_b_layout, ck_tile::tensor_layout::gemm::ColumnMajor>)
? std::vector<int>{N, K}
: std::vector<int>{K, N};
std::vector<int> c_dimensions =
(std::is_same_v<matrix_c_layout, ck_tile::tensor_layout::gemm::RowMajor>)
? std::vector<int>{M, N}
: std::vector<int>{N, M};
ck_tile::HostTensor<ADataType> a_host(a_dimensions);
ck_tile::HostTensor<BDataType> b_host(b_dimensions);
ck_tile::HostTensor<CDataType> c_host_ref(c_dimensions);
ck_tile::HostTensor<CDataType> c_host_dev(c_dimensions);
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_host);
ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes());
a_buf.ToDevice(a_host.data());
b_buf.ToDevice(b_host.data());
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadA = true;
constexpr bool kPadB = true;
constexpr bool kPadC = true;
// This part comes from the Codegen
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using CodegenPipelineProblem = ck_tile::BlockGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
kPadA,
kPadB,
kPadC>;
using CodegenGemmPipeline = ck_tile::BlockGemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
invoke_gemm<ck_tile::half_t,
matrix_a_layout,
matrix_b_layout,
matrix_c_layout,
CodegenPipelineProblem,
CodegenGemmPipeline,
CodegenGemmShape>(a_buf, b_buf, c_buf, arg_parser);
c_buf.FromDevice(c_host_dev.data());
bool pass_cpu = true;
if(arg_parser.get_int("v") == 1)
{
// ToDo: Will Add the Element Op (bias) verification in the future.
ck_tile::reference_gemm<ADataType,
BDataType,
AccDataType,
CDataType,
matrix_a_layout,
matrix_b_layout,
matrix_c_layout>(a_host, b_host, c_host_ref);
pass_cpu = ck_tile::check_err(c_host_dev, c_host_ref);
std::cout << "The CPU veification result is:" << (pass_cpu ? "correct" : "fail")
<< std::flush;
}
bool pass_gpu = true;
if(arg_parser.get_int("v") == 2)
{
ck_tile::index_t stride_a = arg_parser.get_int("stride_a");
ck_tile::index_t stride_b = arg_parser.get_int("stride_b");
ck_tile::index_t stride_c = arg_parser.get_int("stride_c");
if(stride_a == 0)
{
if constexpr(std::is_same_v<matrix_a_layout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
stride_a = M;
}
else
{
stride_a = K;
}
}
if(stride_b == 0)
{
if constexpr(std::is_same_v<matrix_b_layout, ck_tile::tensor_layout::gemm::RowMajor>)
{
stride_b = N;
}
else
{
stride_b = K;
}
}
if(stride_c == 0)
{
if constexpr(std::is_same_v<matrix_c_layout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
stride_c = M;
}
else
{
stride_c = N;
}
}
ck_tile::HostTensor<CDataType> c_host_gpu_ref(c_dimensions);
ck_tile::DeviceMem c_gpu_buf(c_host_gpu_ref.get_element_space_size_in_bytes());
ck_tile::reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType>(
a_buf, b_buf, c_gpu_buf, M, N, K, stride_a, stride_b, stride_c);
c_buf.FromDevice(c_host_gpu_ref.data());
pass_gpu = ck_tile::check_err(c_host_dev, c_host_gpu_ref);
std::cout << "The GPU veification result is: " << (pass_gpu ? "correct" : "fail")
<< std::flush;
}
std::cout << std::endl << std::flush;
return !pass_gpu;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include <string>
template <typename DataType>
struct GemmBasicTypeConfig;
template <>
struct GemmBasicTypeConfig<ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using CDataType = ck_tile::half_t; // type convert
// ToDo: Add more bias config to support different categories of GEMM.
};
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};
template <>
struct DataTypeTraits<double>
{
static constexpr const char* name = "fp64";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
using Types = GemmBasicTypeConfig<ck_tile::half_t>;
// Specific type aliases for easy access
using ADataType = Types::ADataType;
using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;
struct gemm_basic_args
{
const void* p_a;
const void* p_b;
void* p_c;
float epsilon;
ck_tile::index_t kbatch;
ck_tile::index_t M;
ck_tile::index_t N;
ck_tile::index_t K;
ck_tile::index_t stride_A;
ck_tile::index_t stride_B;
ck_tile::index_t stride_C;
};
// host API
float gemm_calc(gemm_basic_args args, const ck_tile::stream_config& s);
#!/bin/bash
#
# in order to run this script you'd first need to build the tile_example_gemm executables in ../build/bin/
#
# run the script as "./run_full_test.sh <tag for your test environment> <branch name> <host name> <gpu_arch>
# input arguments:
# environment tag : a string describing the specifics of your test environment
# branch name : name of the branch in git repo (git status | grep -e 'On branch')
# host name : $hostname
# gpu architecture: e.g., gfx90a, or gfx942, etc.
# get the command line arguments:
export env_type=$1
echo 'Environment type: ' $env_type
export branch=$2
echo 'Branch name: ' $branch
export host_name=$3
echo 'Host name: ' $host_name
export GPU_arch=$4
echo 'GPU_arch: ' $GPU_arch
# run verification tests
example/ck_tile/03_gemm/script/smoke_test.sh
# We do not have a performance benchmark for gemm yet. Will add it in the future.
\ No newline at end of file
#!/bin/bash
EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)"
KNAME=1
export CK_WARMUP=0
export CK_REPEAT=1
COMMON_ARGS='-v=2 -warmup=0 -repeat=1'
run_fp16_tests() {
for batch in 1 2; do
for m in 128 1024; do
for n in 128 2048; do
for k in 32 64; do
$EXE -b=$batch -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -e=1e-5 -prec=fp16 $COMMON_ARGS
if [ $? -eq 0 ]; then
echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully."
else
echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly."
# Optionally, exit or break if you need to halt further execution
# exit 1
fi
done
done
done
done
}
set -x
run_fp16_tests
set +x
\ No newline at end of file
......@@ -4,3 +4,4 @@ include_directories(AFTER
add_subdirectory(01_fmha)
add_subdirectory(02_layernorm2d)
add_subdirectory(03_gemm)
......@@ -446,7 +446,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
});
});
});
__builtin_amdgcn_sched_barrier(0);
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
// latency
// __builtin_amdgcn_sched_barrier(0);
}
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck {
template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
__host__ __device__ static constexpr auto
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
{
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
return transform_tensor_descriptor(
TileDesc_K0_MN_K1{},
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
make_unmerge_transform(
make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatAcc,
typename AK0MK1BlockDesc,
typename BK0NK1BlockDesc,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
typename ComputeTypeA = FloatA,
typename ComputeTypeB = FloatB>
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
static constexpr index_t KPerBlock =
BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm =
SparseXdlopsGemm<ComputeTypeA, MPerXDL, NPerXDL, KPack, ComputeTypeB>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
xdlops_gemm.GetRegSizePerXdlops(),
true>
c_thread_buf_;
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
__device__ static auto GetWaveIdx()
{
const index_t thread_id = ThisThreadBlock::GetThreadId();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto CalculateAThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
return make_tuple(Number<m0>{},
Number<n0>{},
waveId_m,
waveId_n,
blk_idx[I0],
blk_idx[I1],
blk_idx[I2],
blk_idx[I3]);
}
__host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
{
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0NK1BlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0,
"MPerBlock must be divisible by MPerXDL * MRepeat");
static_assert(NPerBlock % (NPerXDL * NRepeat) == 0,
"NPerBlock must be divisible by NPerXDL * NRepeat");
static_assert(
KPack % (16 * sizeof(ComputeTypeA)) == 0,
"KPack must be divisbile by number of elements processed in single smfmac instruction");
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_block_desc_g_m0_n0_m1_n1_m2_n2);
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
}
template <typename CGridDesc_G_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
{
const auto G = c_grid_desc_g_m_n.GetLength(I0);
const auto M = c_grid_desc_g_m_n.GetLength(I1);
const auto N = c_grid_desc_g_m_n.GetLength(I2);
const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_g_m_n,
make_tuple(make_pass_through_transform(G),
make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_grid_desc_g_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
{
return transform_tensor_descriptor(
AK0MK1BlockDesc{},
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
__host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K()
{
return transform_tensor_descriptor(
BK0NK1BlockDesc{},
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();
// Prepares data in a_thread_buf by squeezing values by ommiting zeros to adjust it to 2:4
// structural sparsity. The indexes of non-zero elements are stored in idx_buf and used later in
// smfmac instruction
template <typename AThreadBuf, typename IdxBuf, int32_t num_elems>
__device__ void SetIdxSqueezeA(AThreadBuf& a_thread_buf, IdxBuf& idx_buf)
{
static constexpr int32_t bit_clear_masks[4] = {0b11, 0b1100, 0b110000, 0b11000000};
static constexpr int32_t processed_elems = 16 / sizeof(ComputeTypeA);
static_for<0, num_elems, processed_elems>{}([&](auto i) {
constexpr int idx_reg_num = i / (16 * sizeof(ComputeTypeA));
constexpr int idx_reg_part = (i % 32) / processed_elems;
vector_type<ComputeTypeA, processed_elems> a_thread_vec;
static_for<0, processed_elems, 1>{}([&](auto j) {
a_thread_vec.template AsType<ComputeTypeA>()(j) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i + j))>{}];
});
uint8_t idx = 0b11101110; // set to last 2 elems for both 4-elems subgroups by default
for(int j = 0; j < processed_elems; j += 4)
{
int32_t a_pos = idx_reg_part * processed_elems + j;
int32_t nonzero_pos = 0;
ComputeTypeA nonzero_elems[2] = {a_thread_vec[j + 2], a_thread_vec[j + 3]};
for(int k = 0; k < 3; k += 1)
{
if(a_thread_vec[j + k] != 0.0f)
{
nonzero_elems[nonzero_pos] = a_thread_vec[j + k];
idx &= ~bit_clear_masks[j / 2 + nonzero_pos];
idx |= k << 2 * (j / 2 + nonzero_pos);
++nonzero_pos;
}
}
a_thread_vec[j / 2] = nonzero_elems[0];
a_thread_vec[j / 2 + 1] = nonzero_elems[1];
}
IdxBuf[idx_reg_num].AsType<int8x4_t>()[Number<idx_reg_part>{}] = idx;
static_for<0, processed_elems / 2, 1>{}([&](auto j) {
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(0, 0, 0, i / 2 + j))>{}] = a_thread_vec[j];
});
});
}
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_thread_desc_.GetElementSpaceSize());
static constexpr int32_t elems_per_idx = 16 * sizeof(ComputeTypeA);
auto idx_buf = make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
(a_thread_desc_.GetElementSpaceSize() + elems_per_idx - 1) / elems_per_idx);
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
SetIdxSqueezeA(a_thread_buf, idx_buf, a_thread_desc_.GetElementSpaceSize());
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
static_for<0, KPerThread, KPack>{}([&](auto k) {
// a_thread_vec is smaller because it's structurally sparse 2:4
vector_type<ComputeTypeA, KPack / 2> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
vector_type<int32_t, KPack / elems_per_idx> idx_vec;
static_for<0, KPack / 2, 1>{}([&](auto i) {
a_thread_vec.template AsType<ComputeTypeA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(0, 0, 0, k / 2 + i))>{}];
});
static_for<0, KPack, 1>{}([&](auto i) {
b_thread_vec.template AsType<ComputeTypeB>()(2 * i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
});
static_for<0, KPack / elems_per_idx, 1>{}([&](auto i) {
idx_vec.template AsType<int32_t>()(i) = idx_buf[k / elems_per_idx + i];
});
// A is smaller because it's structurally sparse 2:4
using mfma_input_type_a =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops / 2>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
using mfma_input_type_idx = typename vector_type<int32_t, 1>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
idx_vec.template AsType<mfma_input_type_idx>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
protected:
// A[M0, M1, M2, KPerThread]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
// B[N0, N1, N2, KPerThread]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
// C[M, N, NumRegXdlops]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
ComputeTypeA,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPerThread>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
ComputeTypeB,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPerThread>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
B_K1>;
AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// In and Din = [N, C, Hi, Wi]
// Out and Dout = [N, C, Ho, Wo]
// Out = AvgPool2dFwd(In)
// Din = AvgPool2dBwd(Dout)
// Pooling dimension = H, W
template <typename DOutDataType,
typename DInDataType,
typename ComputeDataType,
ck::index_t BlockSize,
ck::index_t MThreadClusterSize,
ck::index_t KThreadClusterSize,
ck::index_t MThreadSliceSize,
ck::index_t KThreadSliceSize,
ck::index_t InSrcOutDstVectorSize>
struct DeviceAvgPool2dBwd_NHWC_NHWC : public DeviceAvgPoolBwd<2,
DOutDataType,
DInDataType,
tensor_layout::convolution::NHWC,
tensor_layout::convolution::NHWC>
{
static constexpr ck::index_t NDimSpatial = 2;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto
Make2DGridDescriptor_Out_M_K_In_M(const std::vector<ck::index_t>& dout_n_c_wos_lengths,
const std::vector<ck::index_t>& din_n_c_wos_length,
const std::vector<ck::index_t>& dout_n_c_wos_strides,
const std::vector<ck::index_t>& din_n_c_wos_strides,
const std::vector<ck::index_t>& window_lengths,
const std::vector<ck::index_t>& window_strides,
const std::vector<ck::index_t>& window_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads,
const std::vector<ck::index_t>& tildes)
{
index_t i_ytilde = tildes[0];
index_t i_xtilde = tildes[1];
const index_t N = dout_n_c_wos_lengths[0];
const index_t C = dout_n_c_wos_lengths[1];
const index_t Ho = dout_n_c_wos_lengths[2];
const index_t Wo = dout_n_c_wos_lengths[3];
const index_t Hi = din_n_c_wos_length[2];
const index_t Wi = din_n_c_wos_length[3];
const index_t Y = window_lengths[0];
const index_t X = window_lengths[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
const index_t ConvStrideH = window_strides[0];
const index_t ConvStrideW = window_strides[1];
const index_t ConvDilationH = window_dilations[0];
const index_t ConvDilationW = window_dilations[1];
const index_t Ni_stride = dout_n_c_wos_strides[0];
const index_t Ci_stride = dout_n_c_wos_strides[1];
const index_t Ho_stride = dout_n_c_wos_strides[2];
const index_t Wo_stride = dout_n_c_wos_strides[3];
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on Tildes that contribute to non-padding area of input tensor
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
const auto IHTildeSliceEnd =
math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildeSliceEnd =
math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// ReduceK is different for each Reduce
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// Problem size of reduction kernel
const index_t MRaw = N * HTildeSlice * WTildeSlice * C;
const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw;
const index_t KRaw = YDotSlice * XDotSlice;
const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw;
const auto out_n_ho_wo_c_grid_desc = make_naive_tensor_descriptor(
make_tuple(N, Ho, Wo, C), make_tuple(Ni_stride, Ho_stride, Wo_stride, Ci_stride));
// Out[ReduceM, ReduceK]
const auto out_n_hop_wop_c_grid_desc = transform_tensor_descriptor(
out_n_ho_wo_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_ydot_htilde_xdot_wtilde_c_grid_desc = transform_tensor_descriptor(
out_n_hop_wop_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc =
transform_tensor_descriptor(
out_n_ydot_htilde_xdot_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}));
const auto out_grid_desc_reducemraw_reducekraw = transform_tensor_descriptor(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice, C)),
make_merge_transform(make_tuple(YDotSlice, XDotSlice))),
make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_grid_desc_reducem_reducek = transform_tensor_descriptor(
out_grid_desc_reducemraw_reducekraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// In[ReduceM]
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C),
make_tuple(din_n_c_wos_strides[0],
din_n_c_wos_strides[2],
din_n_c_wos_strides[3],
din_n_c_wos_strides[1]));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YTilde, HTilde),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(i_ytilde),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<>{},
Sequence<1>{},
Sequence<>{},
Sequence<2>{},
Sequence<3>{}));
const auto in_grid_desc_reducemraw = transform_tensor_descriptor(
in_n_htildeslice_wtildeslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice, C))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto in_grid_desc_reducem =
transform_tensor_descriptor(in_grid_desc_reducemraw,
make_tuple(make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return make_tuple(out_grid_desc_reducem_reducek, in_grid_desc_reducem);
}
using DoutDinGridDesc = decltype(Make2DGridDescriptor_Out_M_K_In_M({0, 0, 0, 0},
{0, 0, 0, 0},
{0, 0, 0, 0},
{0, 0, 0, 0},
{0, 0},
{0, 0},
{0, 0},
{0, 0},
{0, 0},
{0, 0}));
using DoutGridDesc_M_K = remove_cvref_t<tuple_element_t<0, DoutDinGridDesc>>;
using DinGridDesc_M = remove_cvref_t<tuple_element_t<1, DoutDinGridDesc>>;
// FIXME
// for NHWC, the dim C is the fastest dimension, and is not reduced.
// Hence, it is in M dimension for reduction kernel.
static constexpr index_t OutSrcInDstVectorDim = 0; // 0: M, 1: K
using PassThrough = tensor_operation::element_wise::PassThrough;
using Div = tensor_operation::element_wise::UnaryDivide;
using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise<DOutDataType,
DInDataType,
ComputeDataType,
int,
DoutGridDesc_M_K,
DinGridDesc_M,
reduce::Add,
PassThrough,
Div,
InMemoryDataOperationEnum::Set,
false, // propagate_nan
BlockSize,
MThreadSliceSize,
KThreadSliceSize,
OutSrcInDstVectorDim,
InSrcOutDstVectorSize,
InSrcOutDstVectorSize>;
struct Argument : public BaseArgument
{
Argument(const DOutDataType* p_dout,
DInDataType* p_din,
std::vector<ck::index_t> dout_n_c_wos_lengths,
std::vector<ck::index_t> din_n_c_wos_length,
std::vector<ck::index_t> dout_n_c_wos_strides,
std::vector<ck::index_t> din_n_c_wos_strides,
std::vector<ck::index_t> window_lengths,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
: p_dout_grid_{p_dout},
p_din_grid_{p_din},
dout_n_c_wos_lengths_{dout_n_c_wos_lengths},
din_n_c_wos_length_{din_n_c_wos_length},
dout_n_c_wos_strides_{dout_n_c_wos_strides},
din_n_c_wos_strides_{din_n_c_wos_strides},
num_reduce_{1},
div_element_op_{window_lengths[0] * window_lengths[1]}
{
std::vector<ck::index_t> Tildes(NDimSpatial);
for(int i = 0; i < NDimSpatial; ++i)
{
int GcdStrideDilation = math::gcd(window_strides[i], window_dilations[i]);
Tildes[i] = window_strides[i] / GcdStrideDilation;
num_reduce_ *= Tildes[i];
}
for(index_t i_ytilde = 0; i_ytilde < Tildes[0]; ++i_ytilde)
{
for(index_t i_xtilde = 0; i_xtilde < Tildes[1]; ++i_xtilde)
{
const auto YDotSlice =
math::integer_divide_ceil(window_lengths[0] - i_ytilde, Tildes[0]);
const auto XDotSlice =
math::integer_divide_ceil(window_lengths[1] - i_xtilde, Tildes[1]);
if(YDotSlice * XDotSlice <= 0)
{
continue;
}
const auto dout_din_grid_desc =
Make2DGridDescriptor_Out_M_K_In_M(dout_n_c_wos_lengths,
din_n_c_wos_length,
dout_n_c_wos_strides,
din_n_c_wos_strides,
window_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads,
{i_ytilde, i_xtilde});
dout_grid_desc_m_k_container_.push_back(dout_din_grid_desc[I0]);
din_grid_desc_m_container_.push_back(dout_din_grid_desc[I1]);
}
}
}
const DOutDataType* p_dout_grid_;
DInDataType* p_din_grid_;
std::vector<ck::index_t> dout_n_c_wos_lengths_;
std::vector<ck::index_t> din_n_c_wos_length_;
std::vector<ck::index_t> dout_n_c_wos_strides_;
std::vector<ck::index_t> din_n_c_wos_strides_;
int num_reduce_;
std::vector<DoutGridDesc_M_K> dout_grid_desc_m_k_container_;
std::vector<DinGridDesc_M> din_grid_desc_m_container_;
Div div_element_op_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float ave_time = 0;
for(index_t i = 0; i < arg.num_reduce_; i++)
{
const auto kernel = kernel_reduce_threadwise<gridwise_reduce,
false,
false,
false, // don't have index input
DOutDataType,
DInDataType,
ComputeDataType,
int,
DoutGridDesc_M_K,
DinGridDesc_M,
PassThrough,
Div>;
ck::index_t M = arg.dout_grid_desc_m_k_container_[i].GetLength(I0);
const index_t grid_size = (M / M_BlockTileSize);
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.dout_grid_desc_m_k_container_[i],
arg.din_grid_desc_m_container_[i],
PassThrough{},
arg.div_element_op_,
float(1),
arg.p_dout_grid_,
nullptr,
float(0),
arg.p_din_grid_,
nullptr);
}
return ave_time;
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
constexpr index_t Rank = NDimSpatial + 2;
int doutFastestDim = -1;
int dinFastestDim = -1;
for(int i = 0; i < Rank; ++i)
{
if(arg.dout_n_c_wos_strides_[i] == 1)
doutFastestDim = i;
if(arg.din_n_c_wos_strides_[i] == 1)
dinFastestDim = i;
}
if(InSrcOutDstVectorSize != 1 && (dinFastestDim != 1 || doutFastestDim != 1))
{
return false;
}
if(doutFastestDim == -1 || dinFastestDim == -1)
{
if constexpr(InSrcOutDstVectorSize != 1)
return false;
}
else
{
if(arg.dout_n_c_wos_lengths_[doutFastestDim] % InSrcOutDstVectorSize != 0)
return false;
if(arg.din_n_c_wos_length_[dinFastestDim] % InSrcOutDstVectorSize != 0)
return false;
}
return true;
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_dout,
void* p_din,
std::vector<ck::index_t> dout_n_c_wos_lengths,
std::vector<ck::index_t> din_n_c_wos_length,
std::vector<ck::index_t> dout_n_c_wos_strides,
std::vector<ck::index_t> din_n_c_wos_strides,
std::vector<ck::index_t> window_lengths,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) override
{
constexpr index_t Rank = NDimSpatial + 2;
if(dout_n_c_wos_strides.size() != Rank || din_n_c_wos_strides.size() != Rank ||
dout_n_c_wos_lengths.size() != Rank || din_n_c_wos_length.size() != Rank)
{
throw std::runtime_error("dimension of [dout|din]_n_c_wos_strides or "
"[dout|din]_n_c_wos_lengths is not equal to Rank");
}
if(window_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial ||
window_dilations.size() != NDimSpatial || input_left_pads.size() != NDimSpatial ||
input_right_pads.size() != NDimSpatial)
{
throw std::runtime_error(
"dimension of [window_lengths, window_strides, window_dilations, input_left_pads, "
"input_right_pads] is not equal to Rank");
}
return std::make_unique<Argument>(static_cast<const DOutDataType*>(p_dout),
static_cast<DInDataType*>(p_din),
dout_n_c_wos_lengths,
din_n_c_wos_length,
dout_n_c_wos_strides,
din_n_c_wos_strides,
window_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceAvgPool2dBwd<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -171,6 +171,16 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayo
Argument arg_ = arg;
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
auto size_a_buffer =
a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
auto size_b_buffer =
b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
......@@ -179,11 +189,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayo
DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
});
ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
arg_,
stream_config.rotating_count,
arg_.M * arg_.K * sizeof(ADataType),
arg_.K * arg_.N * sizeof(BDataType),
DsSize);
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize);
rotating_mem.Print();
auto run_flush_cache = [&]() {
......
......@@ -155,11 +155,19 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
if(stream_config.flush_cache)
{
Argument arg_ = arg;
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
auto size_a_buffer =
a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
auto size_b_buffer =
b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
arg_,
stream_config.rotating_count,
arg_.M * arg_.K * sizeof(ADataType),
arg_.K * arg_.N * sizeof(BDataType));
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
......
......@@ -15,6 +15,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
......@@ -22,7 +23,6 @@
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -257,6 +257,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
KPerBlock / K1Number,
ConvBackwardWeightSpecialization>{};
static constexpr index_t ClusterLengthMPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
static constexpr index_t ClusterLengthNPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
static constexpr auto conv_ngchw_to_nhwgc_transformer =
TransformConvNGCHWToNHWGC<InLayout,
WeiLayout,
OutLayout,
NDimSpatial,
MPerBlock / ClusterLengthMPerBlock,
NPerBlock / ClusterLengthNPerBlock>{};
static constexpr GemmSpecialization GemmSpec = GemmSpecialization::Default;
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
......@@ -359,141 +372,12 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
batch)[I2];
}
static constexpr index_t ClusterLengthMPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
static constexpr index_t ClusterLengthNPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[0];
const index_t& N = g_n_c_wis_lengths[1];
const index_t& C = g_n_c_wis_lengths[2];
const index_t& Hi = g_n_c_wis_lengths[3];
const index_t& Wi = g_n_c_wis_lengths[4];
const index_t& GStride = g_n_c_wis_strides[0];
const index_t& NStride = g_n_c_wis_strides[1];
const index_t& CStride = g_n_c_wis_strides[2];
const index_t& HiStride = g_n_c_wis_strides[3];
const index_t& WiStride = g_n_c_wis_strides[4];
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadTensorDescriptor(
merged_desc,
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[0];
const index_t& N = g_n_c_wis_lengths[1];
const index_t& C = g_n_c_wis_lengths[2];
const index_t& Hi = g_n_c_wis_lengths[3];
const index_t& Wi = g_n_c_wis_lengths[4];
const index_t& NStride = g_n_c_wis_strides[1];
const index_t HiStride = Wi * G * C;
const index_t WiStride = G * C;
const index_t GStride = C;
const index_t CStride = 1;
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadTensorDescriptor(
merged_desc,
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[0];
const index_t& N = g_n_c_wis_lengths[1];
const index_t& C = g_n_c_wis_lengths[2];
const index_t& Di = g_n_c_wis_lengths[3];
const index_t& Hi = g_n_c_wis_lengths[4];
const index_t& Wi = g_n_c_wis_lengths[5];
const index_t& GStride = g_n_c_wis_strides[0];
const index_t& NStride = g_n_c_wis_strides[1];
const index_t& CStride = g_n_c_wis_strides[2];
const index_t& DiStride = g_n_c_wis_strides[3];
const index_t& HiStride = g_n_c_wis_strides[4];
const index_t& WiStride = g_n_c_wis_strides[5];
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Di, Hi, Wi),
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Di, Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadTensorDescriptor(
merged_desc,
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[0];
const index_t& N = g_n_c_wis_lengths[1];
const index_t& C = g_n_c_wis_lengths[2];
const index_t& Di = g_n_c_wis_lengths[3];
const index_t& Hi = g_n_c_wis_lengths[4];
const index_t& Wi = g_n_c_wis_lengths[5];
const index_t& NStride = g_n_c_wis_strides[1];
const index_t DiStride = Hi * Wi * G * C;
const index_t HiStride = Wi * G * C;
const index_t WiStride = G * C;
const index_t GStride = C;
const index_t CStride = 1;
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Di, Hi, Wi),
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Di, Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadTensorDescriptor(
merged_desc,
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
Sequence<true, true>{});
}
using InputTransposeDescType =
remove_cvref_t<decltype(MakeInputTransposeDesc<NDimSpatial>({}, {}))>;
using OutputTransposeDescType =
remove_cvref_t<decltype(MakeOutputTransposeDesc<NDimSpatial>({}, {}))>;
using NGCHWTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
using NHWGCTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
......@@ -572,8 +456,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
I1>;
using GridwiseElementwiseTranspose =
GridwiseElementwise<Tuple<InputTransposeDescType>,
Tuple<OutputTransposeDescType>,
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
Tuple<NHWGCTransposeDescType>,
Tuple<const ADataType*>,
Tuple<ADataType*>,
Block2TileMapElementwise,
......@@ -652,43 +536,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
begin(output_spatial_lengths_));
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
b_g_n_c_wis_strides;
conv_ngchw_to_nhwgc_transformer.TransposeStrides(b_g_n_c_wis_lengths,
b_g_n_c_wis_strides);
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
a_g_n_k_wos_strides;
// NGKHW - transpose needed
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
b_g_n_c_wis_strides_transposed[I0] = Conv_C_;
b_g_n_c_wis_strides_transposed[I2] = I1;
a_g_n_k_wos_strides_transposed[I0] = Conv_K_;
a_g_n_k_wos_strides_transposed[I2] = I1;
if constexpr(NDimSpatial == 2)
{
b_g_n_c_wis_strides_transposed[I3] =
input_spatial_lengths_[I1] * Conv_G_ * Conv_C_;
b_g_n_c_wis_strides_transposed[I4] = Conv_G_ * Conv_C_;
a_g_n_k_wos_strides_transposed[I3] =
output_spatial_lengths_[I1] * Conv_G_ * Conv_K_;
a_g_n_k_wos_strides_transposed[I4] = Conv_G_ * Conv_K_;
}
else if constexpr(NDimSpatial == 3)
{
b_g_n_c_wis_strides_transposed[I3] =
input_spatial_lengths_[I1] * input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
b_g_n_c_wis_strides_transposed[I4] =
input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
b_g_n_c_wis_strides_transposed[I5] = Conv_G_ * Conv_C_;
a_g_n_k_wos_strides_transposed[I3] = output_spatial_lengths_[I1] *
input_spatial_lengths_[I2] * Conv_G_ *
Conv_K_;
a_g_n_k_wos_strides_transposed[I4] =
input_spatial_lengths_[I2] * Conv_G_ * Conv_K_;
a_g_n_k_wos_strides_transposed[I5] = Conv_G_ * Conv_K_;
}
}
conv_ngchw_to_nhwgc_transformer.TransposeStrides(a_g_n_k_wos_lengths,
a_g_n_k_wos_strides);
const auto descs =
conv_to_gemm_transformer_v2
......@@ -755,14 +607,18 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
a_in_transpose_desc_ =
MakeInputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
a_out_transpose_desc_ =
MakeOutputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
b_in_transpose_desc_ =
MakeInputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
b_out_transpose_desc_ =
MakeOutputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
......@@ -816,8 +672,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
elementwise_block_2_ctile_map_transpose_b_;
InputTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
OutputTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;
NGCHWTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
NHWGCTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_;
......@@ -1569,13 +1425,14 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
(arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) /
sizeof(BDataType);
// Different data type for A and B is not supported
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseTranspose,
ck::Tuple<InputTransposeDescType>,
ck::Tuple<InputTransposeDescType>,
ck::Tuple<OutputTransposeDescType>,
ck::Tuple<OutputTransposeDescType>,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<const ADataType*>,
ck::Tuple<BDataType*>,
ck::Tuple<ADataType*>,
Block2TileMapElementwise,
Block2TileMapElementwise,
element_wise::PassThrough>;
......
......@@ -15,9 +15,11 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
......@@ -307,6 +309,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
// NGCHW is not supported for multiAB
static_assert(!(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>()) ||
!(isMultiA || isMultiB));
static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
static constexpr index_t NumDTensor = DsDataType::Size();
......@@ -315,6 +322,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
ConvForwardSpecialization,
......@@ -323,14 +332,33 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
EDataType,
NumGroupsToMerge>;
static constexpr index_t ClusterLengthNPerBlock =
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
static constexpr auto conv_ngchw_to_nhwgc_transformer =
TransformConvNGCHWToNHWGC<ALayout,
BLayout,
ELayout,
NDimSpatial,
NPerBlock / ClusterLengthNPerBlock,
NPerBlock / ClusterLengthNPerBlock>{};
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
ctc::NHWGC,
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
ctc::NDHWGC,
ALay>>;
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
conv_to_gemm_transformer.template MakeADescriptor_M_K<Layout>();
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -353,8 +381,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template <typename ELay>
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
ctc::NHWGK,
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
ctc::NDHWGK,
ELay>>;
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......@@ -442,6 +478,52 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, NPerBlock>;
using NGCHWTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
using NHWGCTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock;
using GridwiseElementwiseInputTranspose =
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
Tuple<NHWGCTransposeDescType>,
Tuple<const ADataType*>,
Tuple<ADataType*>,
Block2TileMapElementwise,
element_wise::PassThrough,
ElementwiseBlocksize,
NPerBlock,
NPerBlock,
NPerBlock / ClusterLengthNPerBlock,
NPerBlock / ClusterLengthNPerBlock,
Sequence<1, 0>,
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
I1,
I0>;
using GridwiseElementwiseOutputTranspose =
GridwiseElementwise<Tuple<NHWGCTransposeDescType>,
Tuple<NGCHWTransposeDescType>,
Tuple<const EDataType*>,
Tuple<EDataType*>,
Block2TileMapElementwise,
element_wise::PassThrough,
ElementwiseBlocksize,
NPerBlock,
NPerBlock,
NPerBlock / ClusterLengthNPerBlock,
NPerBlock / ClusterLengthNPerBlock,
Sequence<1, 0>,
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
I0,
I1>;
// Argument
struct Argument : public BaseArgument
......@@ -471,17 +553,31 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
p_bs_grid_{},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]},
conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads},
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides)},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides)},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads},
num_group_{a_g_n_c_wis_lengths_[0]},
conv_to_gemm_transformer_{a_g_n_c_wis_lengths_,
a_g_n_c_wis_strides_,
b_g_k_c_xs_lengths_,
b_g_k_c_xs_strides_,
e_g_n_k_wos_lengths_,
e_g_n_k_wos_strides_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_},
conv_N_per_block_{conv_to_gemm_transformer_.N_},
a_grid_desc_m_k_{
DeviceOp::MakeAGridDescriptor_M_K<ALayout>(conv_to_gemm_transformer_)},
......@@ -501,19 +597,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
compute_ptr_offset_of_n_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
cde_element_op_{cde_element_op}
{
// A/B/E Batch Stride
if constexpr(isMultiA || isMultiB)
......@@ -521,7 +605,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static_for<0, NumATensor, 1>{}([&](auto i) {
// Init compute_ptr_offset_of_groups_ for multiple AB
compute_ptr_offset_of_groups_.BatchStrideA_(i) =
a_g_n_c_wis_strides[0] * NumGroupsToMerge;
a_g_n_c_wis_strides_[0] * NumGroupsToMerge;
// Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
// type is not tuple)
......@@ -537,20 +621,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// in case of MultiA is false but isMultiB is true
// BatchStrideA_ is not tuple.
compute_ptr_offset_of_n_.BatchStrideA_(i) =
a_g_n_c_wis_strides[1] * conv_N_per_block_;
a_g_n_c_wis_strides_[1] * conv_N_per_block_;
}
else
{
// if MultiB and not MultiA then p_as is single pointer
p_as_grid_(i) = static_cast<const DataType*>(p_as);
compute_ptr_offset_of_n_.BatchStrideA_ =
a_g_n_c_wis_strides[1] * conv_N_per_block_;
a_g_n_c_wis_strides_[1] * conv_N_per_block_;
}
});
static_for<0, NumBTensor, 1>{}([&](auto i) {
// Init compute_ptr_offset_of_groups_ for multiple AB
compute_ptr_offset_of_groups_.BatchStrideB_(i) =
b_g_k_c_xs_strides[0] * NumGroupsToMerge;
b_g_k_c_xs_strides_[0] * NumGroupsToMerge;
using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>;
// It is possible that one of the AB is a pointer and one is a tuple.
......@@ -571,10 +655,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
else
{
compute_ptr_offset_of_groups_.BatchStrideA_ =
a_g_n_c_wis_strides[0] * NumGroupsToMerge;
a_g_n_c_wis_strides_[0] * NumGroupsToMerge;
compute_ptr_offset_of_groups_.BatchStrideB_ =
b_g_k_c_xs_strides[0] * NumGroupsToMerge;
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
b_g_k_c_xs_strides_[0] * NumGroupsToMerge;
compute_ptr_offset_of_n_.BatchStrideA_ =
a_g_n_c_wis_strides_[1] * conv_N_per_block_;
// p_as and p_bs are pointers
p_as_grid_(I0) = static_cast<const ADataType*>(p_as);
......@@ -591,27 +676,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// D batch stride
compute_ptr_offset_of_groups_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides[i][0] * NumGroupsToMerge;
ds_g_n_k_wos_strides_[i][0] * NumGroupsToMerge;
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides[i][1] * conv_N_per_block_;
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
ds_g_n_k_wos_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_;
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths_,
a_g_n_c_wis_strides_,
b_g_k_c_xs_lengths_,
b_g_k_c_xs_strides_,
e_g_n_k_wos_lengths_,
ds_g_n_k_wos_strides_[i],
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_};
// D desc
ds_grid_desc_m_n_(i) =
DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
});
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0] * NumGroupsToMerge;
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
compute_ptr_offset_of_groups_.BatchStrideE_ =
e_g_n_k_wos_strides_[0] * NumGroupsToMerge;
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_;
// populate desc for Ds/E
if constexpr(isMultiA || isMultiB)
......@@ -653,6 +739,54 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
ds_grid_desc_m_n_);
}
}
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
// Use not modified base strides
a_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
a_out_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
e_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
e_out_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{
e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
}
}
std::size_t GetWorkspaceATensorSizeBytes() const
{
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
}
std::size_t GetWorkspaceETensorSizeBytes() const
{
return sizeof(EDataType) * e_out_transpose_desc_.GetElementSpaceSize();
}
std::size_t GetWorkspaceSizeBytes() const
{
// Transpose require workspace for A and B
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
return GetWorkspaceATensorSizeBytes() + GetWorkspaceETensorSizeBytes();
}
else
{
return 0;
}
}
void Print() const
......@@ -671,6 +805,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
// for checking IsSupportedArgument()
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> conv_filter_dilations_;
std::array<index_t, NDimSpatial> input_left_pads_;
std::array<index_t, NDimSpatial> input_right_pads_;
// tensor descriptors for problem definiton
index_t num_group_;
......@@ -692,6 +840,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
elementwise_block_2_ctile_map_transpose_e_;
NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_;
NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>
......@@ -702,20 +855,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// for checking IsSupportedArgument()
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> conv_filter_dilations_;
std::array<index_t, NDimSpatial> input_left_pads_;
std::array<index_t, NDimSpatial> input_right_pads_;
};
// Invoker
......@@ -723,7 +862,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
......@@ -794,6 +933,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
else
{
const ADataType* p_a_grid = arg.p_as_grid_.At(I0);
EDataType* p_e_grid = arg.p_e_grid_;
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
}
const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
GridwiseGemm,
const ADataType*,
......@@ -820,10 +970,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.p_as_grid_.At(I0), // Pass just A descriptor instead of tuple
p_a_grid, // Pass just A descriptor instead of tuple
arg.p_bs_grid_.At(I0), // Pass just B descriptor instead of tuple
arg.p_ds_grid_,
arg.p_e_grid_,
p_e_grid,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
......@@ -847,6 +997,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
}
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float avg_time = 0.f;
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
const index_t grid_size =
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
arg.a_in_transpose_desc_);
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseInputTranspose,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<const ADataType*>,
ck::Tuple<ADataType*>,
Block2TileMapElementwise,
element_wise::PassThrough>;
avg_time += launch_and_time_kernel(stream_config,
kernel_transpose,
dim3(grid_size),
dim3(ElementwiseBlocksize),
0,
make_tuple(arg.a_in_transpose_desc_),
make_tuple(arg.a_out_transpose_desc_),
make_tuple(arg.p_as_grid_.At(I0)),
make_tuple(p_a_out_grid),
arg.elementwise_block_2_ctile_map_transpose_a_,
element_wise::PassThrough{});
}
avg_time += RunGemm(arg, stream_config);
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
const index_t grid_size =
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
arg.e_in_transpose_desc_);
const EDataType* p_e_out_grid =
type_convert<EDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
EDataType* p_e_in_grid = arg.p_e_grid_;
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseOutputTranspose,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<const EDataType*>,
ck::Tuple<EDataType*>,
Block2TileMapElementwise,
element_wise::PassThrough>;
avg_time += launch_and_time_kernel(stream_config,
kernel_transpose,
dim3(grid_size),
dim3(ElementwiseBlocksize),
0,
make_tuple(arg.e_in_transpose_desc_),
make_tuple(arg.e_out_transpose_desc_),
make_tuple(p_e_out_grid),
make_tuple(p_e_in_grid),
arg.elementwise_block_2_ctile_map_transpose_e_,
element_wise::PassThrough{});
}
return avg_time;
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
......@@ -941,7 +1164,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
return false;
}
if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
if constexpr(!(is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() ||
is_NGCSpatial_GKSpatial_NGKSpatial<ALayout, BLayout, ELayout>()))
{
return false;
}
......@@ -953,14 +1177,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v<ALayout, ctc::G_NDHW_C> || is_same_v<ALayout, ctc::GNWC> ||
is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
is_same_v<ALayout, ctc::NDHWGC>)
is_same_v<ALayout, ctc::NDHWGC> || is_same_v<ALayout, ctc::NGCW> ||
is_same_v<ALayout, ctc::NGCHW> || is_same_v<ALayout, ctc::NGCDHW>)
{
// Check access per C
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
{
// If not possible, check access per G
if(!(ABlockTransferSrcVectorDim == 1 && C == 1 &&
is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() &&
if(!(ABlockTransferSrcVectorDim == 1 && (C == 1 || NumGroupsToMerge == 1) &&
(is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() ||
is_NGCSpatial_GKSpatial_NGKSpatial<ALayout, BLayout, ELayout>()) &&
G % ABlockTransferSrcScalarPerVector == 0))
{
return false;
......@@ -1036,6 +1262,35 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
});
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
if((G * K) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
if(!valid)
{
return false;
......@@ -1046,7 +1301,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v<ELayout, ctc::G_NDHW_K> || is_same_v<ELayout, ctc::GNWK> ||
is_same_v<ELayout, ctc::GNHWK> || is_same_v<ELayout, ctc::GNDHWK> ||
is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
is_same_v<ELayout, ctc::NDHWGK>)
is_same_v<ELayout, ctc::NDHWGK> || is_same_v<ELayout, ctc::NGKW> ||
is_same_v<ELayout, ctc::NGKHW> || is_same_v<ELayout, ctc::NGKDHW>)
{
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{
......@@ -1352,6 +1608,34 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
return str.str();
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
auto arg = dynamic_cast<const Argument*>(p_arg);
if(arg)
{
return arg->GetWorkspaceSizeBytes();
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
}
void SetWorkSpacePointer(BaseArgument* p_arg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const override
{
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
if(p_arg_)
{
p_arg_->p_workspace_ = p_workspace;
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
}
};
} // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment