Unverified Commit a73ab0d8 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge pull request #10 from ROCmSoftwarePlatform/lwpck-702

Merge down from public repo.
parents aaf4defa 55a6b4e3
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
version: 2
updates:
- package-ecosystem: "pip" # See documentation for possible values
directory: "/docs/.sphinx" # Location of package manifests
open-pull-requests-limit: 10
schedule:
interval: "daily"
FROM ubuntu:20.04
ARG ROCMVERSION=5.3
ARG compiler_version="release"
ARG ROCMVERSION=5.6
ARG compiler_version=""
ARG compiler_commit=""
RUN set -xe
ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/
RUN useradd -rm -d /home/jenkins -s /bin/bash -u 1004 jenkins
RUN useradd -rm -d /home/manitera -s /bin/bash -u 1002 manitera
# Add rocm repository
RUN apt-get update
RUN apt-get install -y wget gnupg
RUN wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -
RUN sh -c "echo deb [arch=amd64] $DEB_ROCM_REPO ubuntu main > /etc/apt/sources.list.d/rocm.list"
RUN apt-get install -y wget gnupg curl
RUN --mount=type=ssh if [ "$ROCMVERSION" != "5.6"]; then \
wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \
sh -c "echo deb [arch=amd64] $DEB_ROCM_REPO ubuntu main > /etc/apt/sources.list.d/rocm.list"; \
else sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amd-nonfree-radeon_20.04-1_all.deb" && \
apt update && apt-get install -y ./amd-nonfree-radeon_20.04-1_all.deb && \
amdgpu-repo --amdgpu-build=1567752 --rocm-build=compute-rocm-dkms-no-npi-hipclang/11914 && \
DEBIAN_FRONTEND=noninteractive amdgpu-install -y --usecase=rocm ; \
fi
RUN wget --no-check-certificate -qO - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | apt-key add -
RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list"
RUN curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg
# Install dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
apt-utils \
build-essential \
ccache \
cmake-data \
cmake \
curl \
git \
hip-rocclr \
jq \
......@@ -45,6 +49,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
rocm-device-libs \
rocm-cmake \
vim \
nano \
zlib1g-dev \
openssh-server \
clang-format-10 \
......@@ -52,6 +57,17 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
apt-get clean && \
rm -rf /var/lib/apt/lists/*
#Install latest version of cmake
RUN apt purge --auto-remove -y cmake
RUN apt update
RUN apt install -y software-properties-common lsb-release
RUN apt clean all
RUN wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | gpg --dearmor - | tee /etc/apt/trusted.gpg.d/kitware.gpg >/dev/null
RUN apt-add-repository "deb https://apt.kitware.com/ubuntu/ $(lsb_release -cs) main"
RUN apt install -y kitware-archive-keyring
RUN rm /etc/apt/trusted.gpg.d/kitware.gpg
RUN apt install -y cmake
# Setup ubsan environment to printstacktrace
RUN ln -s /usr/bin/llvm-symbolizer-3.8 /usr/local/bin/llvm-symbolizer
ENV UBSAN_OPTIONS=print_stacktrace=1
......@@ -87,12 +103,7 @@ ENV compiler_commit=$compiler_commit
RUN sh -c "echo compiler version = '$compiler_version'"
RUN sh -c "echo compiler commit = '$compiler_commit'"
RUN --mount=type=ssh if [ "$compiler_version" = "amd-stg-open" ]; then \
sed -i '/$HIP_CLANG_TARGET = chomp($HIP_CLANG_TARGET);/c\ chomp($HIP_CLANG_TARGET);' /opt/rocm/hip/bin/hipcc.pl && \
sed -i '/$HIP_CLANG_TARGET = chomp($HIP_CLANG_TARGET);/c\ chomp($HIP_CLANG_TARGET);' /opt/rocm/bin/hipcc.pl; \
fi
RUN --mount=type=ssh if [ "$compiler_version" != "release" ] && [ "$compiler_commit" = "" ]; then \
RUN --mount=type=ssh if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" = "" ]; then \
git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \
cd llvm-project && mkdir build && cd build && \
cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld;compiler-rt" ../llvm && \
......@@ -100,7 +111,7 @@ RUN --mount=type=ssh if [ "$compiler_version" != "release" ] && [ "$compiler_com
else echo "using the release compiler"; \
fi
RUN --mount=type=ssh if [ "$compiler_version" != "release" ] && [ "$compiler_commit" != "" ]; then \
RUN --mount=type=ssh if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" != "" ]; then \
git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \
cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \
cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld;compiler-rt" ../llvm && \
......
......@@ -19,12 +19,33 @@ def runShell(String command){
def getDockerImageName(){
def img
if (params.COMPILER_COMMIT == ""){
img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}"
if (params.ROCMVERSION != "5.5" && params.ROCMVERSION != "5.6"){
if (params.COMPILER_VERSION == "") {
img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}"
}
else{
if (params.COMPILER_COMMIT == ""){
img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}"
}
else{
def commit = "${params.COMPILER_COMMIT}"[0..6]
img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}_${commit}"
}
}
}
else{
def commit = "${params.COMPILER_COMMIT}"[0..6]
img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}_${commit}"
if (params.COMPILER_VERSION == "") {
img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub20.04_rocm${params.ROCMVERSION}"
}
else{
if (params.COMPILER_COMMIT == ""){
img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}"
}
else{
def commit = "${params.COMPILER_COMMIT}"[0..6]
img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}_${commit}"
}
}
}
return img
}
......@@ -49,11 +70,11 @@ def build_compiler(){
compiler = '/opt/rocm/bin/hipcc'
}
else{
if (params.COMPILER_VERSION == "release"){
compiler = "/opt/rocm/llvm/bin/clang++"
if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){
compiler = "/llvm-project/build/bin/clang++"
}
else{
compiler = "/llvm-project/build/bin/clang++"
compiler = "/opt/rocm/llvm/bin/clang++"
}
}
return compiler
......@@ -232,7 +253,7 @@ def buildHipClangJob(Map conf=[:]){
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
}
def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
if (params.COMPILER_VERSION != "release"){
if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
}
......@@ -287,7 +308,7 @@ def runCKProfiler(Map conf=[:]){
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
}
def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
if (params.COMPILER_VERSION != "release"){
if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
}
......@@ -420,7 +441,7 @@ def Build_CK(Map conf=[:]){
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
}
def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
if (params.COMPILER_VERSION != "release"){
if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
}
......@@ -586,16 +607,16 @@ pipeline {
description: "Force building docker image (default: false), set to true if docker image needs to be updated.")
string(
name: 'ROCMVERSION',
defaultValue: '5.4.3',
description: 'Specify which ROCM version to use: 5.4.3 (default).')
defaultValue: '5.6',
description: 'Specify which ROCM version to use: 5.6 (default).')
string(
name: 'COMPILER_VERSION',
defaultValue: 'amd-stg-open',
description: 'Specify which version of compiler to use: ck-9110, release, or amd-stg-open (default).')
defaultValue: '',
description: 'Specify which version of compiler to use: release, amd-stg-open, or leave blank (default).')
string(
name: 'COMPILER_COMMIT',
defaultValue: '5541927df00eabd6a110180170eca7785d436ee3',
description: 'Specify which commit of compiler branch to use: leave empty to use the latest commit, or use 5541927df00eabd6a110180170eca7785d436ee3 (default) commit of amd-stg-open branch.')
defaultValue: '',
description: 'Specify which commit of compiler branch to use: leave blank to use the latest commit, or use 5541927df00eabd6a110180170eca7785d436ee3 (default) commit of amd-stg-open branch.')
string(
name: 'BUILD_COMPILER',
defaultValue: 'hipcc',
......
......@@ -73,7 +73,7 @@ int main(int argc, char* argv[])
SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C);
SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C);
SimpleDeviceMem bias(sizeof(BiasDataType) * K * Y * X * C);
SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K * Y * X * C);
SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K);
SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K);
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
......@@ -203,4 +203,4 @@ int main(int argc, char* argv[])
}
return 0;
}
\ No newline at end of file
}
......@@ -76,7 +76,7 @@ int main(int argc, char* argv[])
SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C);
SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C);
SimpleDeviceMem bias(sizeof(BiasDataType) * K * Y * X * C);
SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K * Y * X * C);
SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K);
SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K);
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
......@@ -206,4 +206,4 @@ int main(int argc, char* argv[])
}
return 0;
}
\ No newline at end of file
}
......@@ -69,7 +69,7 @@ int main(int argc, char* argv[])
SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C);
SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C);
SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K * Y * X * C);
SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K);
SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K);
using DeviceOp =
......@@ -196,4 +196,4 @@ int main(int argc, char* argv[])
}
return 0;
}
\ No newline at end of file
}
add_executable(client_groupnorm_swish groupnorm_swish.cpp)
target_link_libraries(client_groupnorm_swish PRIVATE composable_kernel::device_operations)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <vector>
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/normalization_swish.hpp"
using XDataType = ck::half_t;
using GammaDataType = float;
using BetaDataType = float;
using YDataType = ck::half_t;
using ComputeDataType = float;
using Swish = ck::tensor_operation::element_wise::Swish;
constexpr int Rank = 5;
constexpr int NumReduceDim = 3;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
int main(int argc, char* argv[])
{
ck::index_t N = 32;
ck::index_t H = 16;
ck::index_t W = 16;
ck::index_t G = 64;
ck::index_t C = 128;
std::size_t xy_size = N * H * W * G * C;
std::size_t gamma_beta_size = G * C;
std::vector<ck::index_t> xy_strides = {H * W * G * C, W * G * C, G * C, C, 1};
std::vector<ck::index_t> gamma_beta_strides = {0, 0, 0, C, 1};
SimpleDeviceMem x_device_buf(sizeof(XDataType) * xy_size);
SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_beta_size);
SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * gamma_beta_size);
SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size);
using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
Swish,
Rank,
NumReduceDim>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
bool found = false;
int best_op_id = -1;
float best_ave_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths
xy_strides, // xStrides
gamma_beta_strides, // gammaStrides
gamma_beta_strides, // betaStrides
xy_strides, // yStrides
{1, 2, 4}, // reduceDims
1e-6,
x_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
y_device_buf.GetDeviceBuffer(),
nullptr,
nullptr,
Swish{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t num_byte =
sizeof(XDataType) * xy_size + sizeof(GammaDataType) * gamma_beta_size +
sizeof(BetaDataType) * gamma_beta_size + sizeof(YDataType) * xy_size;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
<< op_name << std::endl;
if(ave_time < best_ave_time)
{
found = true;
best_op_id = i;
best_op_name = op_name;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
<< best_op_name << std::endl;
// run the best intance
{
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths
xy_strides, // xStrides
gamma_beta_strides, // gammaStrides
gamma_beta_strides, // betaStrides
xy_strides, // yStrides
{1, 2, 4}, // reduceDims
1e-6,
x_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
y_device_buf.GetDeviceBuffer(),
nullptr,
nullptr,
Swish{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return 0;
}
......@@ -21,6 +21,7 @@ list(APPEND GTEST_CMAKE_CXX_FLAGS
-Wno-comma
-Wno-old-style-cast
-Wno-deprecated
-Wno-unsafe-buffer-usage
)
message(STATUS "Suppressing googltest warnings with flags: ${GTEST_CMAKE_CXX_FLAGS}")
......
git+https://github.com/RadeonOpenCompute/rocm-docs-core.git
rocm-docs-core==0.2.0
sphinxcontrib-bibtex==2.5.0
......@@ -2,9 +2,9 @@
# This file is autogenerated by pip-compile with Python 3.10
# by the following command:
#
# pip-compile requirements.in
# pip-compile .sphinx/requirements.in
#
accessible-pygments==0.0.4
accessible-pygments==0.0.3
# via pydata-sphinx-theme
alabaster==0.7.13
# via sphinx
......@@ -20,7 +20,7 @@ babel==2.12.1
# sphinx
backcall==0.2.0
# via ipython
beautifulsoup4==4.12.0
beautifulsoup4==4.11.2
# via pydata-sphinx-theme
breathe==4.34.0
# via rocm-docs-core
......@@ -34,7 +34,7 @@ click==8.1.3
# via
# jupyter-cache
# sphinx-external-toc
comm==0.1.3
comm==0.1.2
# via ipykernel
debugpy==1.6.6
# via ipykernel
......@@ -65,13 +65,11 @@ idna==3.4
# via requests
imagesize==1.4.1
# via sphinx
importlib-metadata==6.1.0
importlib-metadata==6.0.0
# via
# jupyter-cache
# myst-nb
importlib-resources==5.10.4
# via rocm-docs-core
ipykernel==6.22.0
ipykernel==6.21.3
# via myst-nb
ipython==8.11.0
# via
......@@ -87,7 +85,7 @@ jsonschema==4.17.3
# via nbformat
jupyter-cache==0.5.0
# via myst-nb
jupyter-client==8.1.0
jupyter-client==8.0.3
# via
# ipykernel
# nbclient
......@@ -124,7 +122,7 @@ nbclient==0.5.13
# via
# jupyter-cache
# myst-nb
nbformat==5.8.0
nbformat==5.7.3
# via
# jupyter-cache
# myst-nb
......@@ -187,7 +185,7 @@ pyyaml==6.0
# myst-parser
# pybtex
# sphinx-external-toc
pyzmq==25.0.2
pyzmq==25.0.1
# via
# ipykernel
# jupyter-client
......@@ -195,8 +193,8 @@ requests==2.28.2
# via
# pygithub
# sphinx
rocm-docs-core @ git+https://github.com/RadeonOpenCompute/rocm-docs-core.git
# via -r requirements.in
rocm-docs-core==0.2.0
# via -r .sphinx/requirements.in
six==1.16.0
# via
# asttokens
......@@ -235,9 +233,7 @@ sphinx-notfound-page==0.8.3
sphinxcontrib-applehelp==1.0.4
# via sphinx
sphinxcontrib-bibtex==2.5.0
# via
# -r requirements.in
# rocm-docs-core
# via -r .sphinx/requirements.in
sphinxcontrib-devhelp==1.0.2
# via sphinx
sphinxcontrib-htmlhelp==2.0.1
......@@ -248,7 +244,7 @@ sphinxcontrib-qthelp==1.0.3
# via sphinx
sphinxcontrib-serializinghtml==1.1.5
# via sphinx
sqlalchemy==1.4.47
sqlalchemy==1.4.46
# via jupyter-cache
stack-data==0.6.2
# via ipython
......
add_example_executable(example_groupnorm_sigmoid_fp16 groupnorm_sigmoid_fp16.cpp)
add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp)
add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <getopt.h>
#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
constexpr int Rank = 5;
constexpr int NumReduceDim = 3;
using XDataType = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t;
using YDataType = ck::half_t;
using ComputeDataType = float;
struct YElementOp
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(ck::is_same<T, float>::value || ck::is_same<T, double>::value ||
ck::is_same<T, ck::half_t>::value,
"Data type is not supported by this operation!");
T a;
ck::tensor_operation::element_wise::Sigmoid{}(a, x);
y = x * a;
};
};
using DeviceInstance =
ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
YElementOp,
Rank,
NumReduceDim,
1024, // BlockSize
1, // ClusterM
1024, // ClusterK
1, // SliceM
32, // SliceK
1, // SrcVecDim (0=M, 1=K)
2, // SrcScalarPerVector
1, // GammaVecDim (0=M, 1=K)
2, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K)
2, // BetaScalarPerVector
2>; // OutScalarPerVector
#include "run_groupnorm_example.inc"
int main(int argc, char* argv[]) { run_groupnorm_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
constexpr int Rank = 5;
constexpr int NumReduceDim = 3;
using XDataType = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t;
using YDataType = ck::half_t;
using ComputeDataType = float;
using YElementOp = ck::tensor_operation::element_wise::Swish;
using DeviceInstance =
ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
YElementOp,
Rank,
NumReduceDim,
1024, // BlockSize
1, // ClusterM
1024, // ClusterK
1, // SliceM
32, // SliceK
1, // SrcVecDim (0=M, 1=K)
2, // SrcScalarPerVector
1, // GammaVecDim (0=M, 1=K)
2, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K)
2, // BetaScalarPerVector
2>; // OutScalarPerVector
#include "run_groupnorm_example.inc"
int main(int argc, char* argv[]) { run_groupnorm_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <getopt.h>
#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp"
constexpr int Rank = 5;
constexpr int NumReduceDim = 3;
using XDataType = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t;
using YDataType = ck::half_t;
using ComputeDataType = float;
struct YElementOp
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(ck::is_same<T, float>::value || ck::is_same<T, double>::value ||
ck::is_same<T, ck::half_t>::value,
"Data type is not supported by this operation!");
T a;
#pragma once
ck::tensor_operation::element_wise::Sigmoid{}(a, x);
y = x * a;
};
};
using DeviceInstance =
ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
YElementOp,
Rank,
NumReduceDim,
1024, // BlockSize
1, // ClusterM
1024, // ClusterK
1, // SliceM
32, // SliceK
1, // SrcVecDim (0=M, 1=K)
2, // SrcScalarPerVector
1, // GammaVecDim (0=M, 1=K)
2, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K)
2, // BetaScalarPerVector
2>; // OutScalarPerVector
int main(int argc, char* argv[])
int run_groupnorm_example(int argc, char* argv[])
{
ck::index_t N = 2;
ck::index_t H = 32;
ck::index_t W = 32;
ck::index_t G = 32;
ck::index_t C = 30;
ck::index_t N = 32;
ck::index_t H = 16;
ck::index_t W = 16;
ck::index_t G = 64;
ck::index_t C = 128;
if(argc == 1)
{
......
......@@ -172,6 +172,11 @@
// flag to enable (1) or disable (0) the debugging output in some kernels
#define DEBUG_LOG 0
// denorm test fix, required to work around dissue
#ifndef CK_WORKAROUND_DENORM_FIX
#define CK_WORKAROUND_DENORM_FIX 0
#endif
namespace ck {
enum struct InMemoryDataOperationEnum
......
......@@ -73,18 +73,157 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto K1Number = Number<K1>{};
static auto
MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, index_t K, index_t StrideA, int KBatch, int KPad)
{
assert(KPad % (K1 * KBatch) == 0);
const index_t K0 = KPad / (K1 * KBatch);
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_right_pad_transform(M, PadM)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
}
static auto
MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, index_t N, index_t StrideB, int KBatch, int KPad)
{
assert(KPad % (K1 * KBatch) == 0);
const index_t K0 = KPad / (K1 * KBatch);
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
}
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
{
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
static auto GetKPad(index_t K, index_t KBatch)
{
const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
const index_t KPad = KBatch * K0 * K1;
return KPad;
}
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1(1, 1, 1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
MPerBlock,
NPerBlock,
K0PerBlock,
......@@ -114,64 +253,236 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
CBlockTransferScalarPerVector_NWaveNPerXDL,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
using Argument = typename GridwiseGemm::Argument;
// GridwiseGemm
using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
MPerBlock,
NPerBlock,
K0PerBlock,
MPerXDL,
NPerXDL,
K1,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CBlockTransferScalarPerVector_NWaveNPerXDL,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
using Block2CTileMap = typename GridwiseGemm::CBlockClusterAdaptor;
// Argument
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t M01,
index_t N01,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t k_batch)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
a_grid_desc_kbatch_k0_m_k1_{},
b_grid_desc_kbatch_k0_n_k1_{},
c_grid_desc_m_n_{},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
k_batch_{k_batch}
{
int KPad = DeviceGemmXdlSplitKCShuffle::GetKPad(K, k_batch_);
a_grid_desc_kbatch_k0_m_k1_ =
DeviceGemmXdlSplitKCShuffle::MakeAGridDescriptor_KBatch_K0_M_K1(
M, K, StrideA, k_batch_, KPad);
b_grid_desc_kbatch_k0_n_k1_ =
DeviceGemmXdlSplitKCShuffle::MakeBGridDescriptor_KBatch_K0_N_K1(
K, N, StrideB, k_batch_, KPad);
c_grid_desc_m_n_ = DeviceGemmXdlSplitKCShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
block_2_ctile_map_ =
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
b_grid_desc_kbatch_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_);
}
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
Block2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
index_t k_batch_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceGemmXdlSplitKCShuffle::Argument;
void Print(const Argument& karg) { karg.Print(); }
void Print(const Argument& arg)
{
std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
Print(karg);
Print(arg);
}
const auto kbatch = karg.k_batch;
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
if(!GridwiseGemm::CheckValidity(karg))
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid "
"setting");
"wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg);
const auto K0 = karg.K0;
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
float ave_time = 0;
const auto Run = [&](const auto& kernel) {
if(kbatch > 1)
hipGetErrorString(
hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
hipGetErrorString(hipMemset(
arg.p_c_grid_,
0,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(CDataType)));
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
};
if(has_main_k0_block_loop)
{
if(kbatch == 1)
{
const auto kernel =
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set>;
const auto kernel = kernel_gemm_xdlops_v2r4r2<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::Block2CTileMap>,
true>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd>;
const auto kernel = kernel_gemm_xdlops_v2r4r2<
GridwiseGemmAtomicAdd,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::Block2CTileMap>,
true>;
Run(kernel);
}
......@@ -180,19 +491,37 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
{
if(kbatch == 1)
{
const auto kernel =
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set>;
const auto kernel = kernel_gemm_xdlops_v2r4r2<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::Block2CTileMap>,
false>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd>;
const auto kernel = kernel_gemm_xdlops_v2r4r2<
GridwiseGemmAtomicAdd,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdlSplitKCShuffle::Block2CTileMap>,
false>;
Run(kernel);
}
......@@ -215,9 +544,12 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
return true;
}
static bool IsSupportedArgument(const Argument& karg)
static bool IsSupportedArgument(const Argument& arg)
{
return GridwiseGemm::CheckValidity(karg);
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
}
// polymorphic
......@@ -235,9 +567,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t KBatch)
{
return Argument{p_a,
......@@ -249,10 +581,11 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
StrideA,
StrideB,
StrideC,
GridwiseGemm::CalculateMPadded(M),
GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K),
GridwiseGemm::CalculateK0(K, KBatch),
1,
1,
a_element_op,
b_element_op,
c_element_op,
KBatch};
}
......@@ -268,9 +601,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ck::index_t KBatch = 1) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
......@@ -282,10 +615,11 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
StrideA,
StrideB,
StrideC,
GridwiseGemm::CalculateMPadded(M),
GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K),
GridwiseGemm::CalculateK0(K, KBatch),
1,
1,
a_element_op,
b_element_op,
c_element_op,
KBatch);
}
......@@ -296,7 +630,31 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
}
// polymorphic
std::string GetTypeString() const override { return GridwiseGemm::GetTypeString(); }
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmXdlSplitKCShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< K1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
......
......@@ -316,8 +316,6 @@ struct Sigmoid
y = 1 / (ck::type_convert<T>(1) + exp(-x));
};
int32_t divider_ = 1;
};
struct TanH
......@@ -333,6 +331,23 @@ struct TanH
};
};
struct Swish
{
Swish(float beta = 1.0f) : beta_(beta) {}
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value,
"Data type is not supported by this operation!");
y = x / (ck::type_convert<T>(1) + ck::math::exp(-beta_ * x));
};
float beta_ = 1.0f;
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
......@@ -505,6 +505,15 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{
return false;
}
return true;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment