Commit 7e493730 authored by Adam Osewski's avatar Adam Osewski

Merge branch 'develop' into wavelet_model

parents b89a88b5 40942b90
cff-version: 1.2.0
title: Composable Kernel
message: If you use this software, please cite using the following metadata.
type: software
authors:
- given-names: Chao
family-names: Liu
email: chao.liu2@amd.com
affiliation: AMD
- given-names: Jing
family-names: Zhang
email: jing.zhang3@amd.com
affiliation: AMD
- given-names: Letao
family-names: Qin
email: letao.qin@amd.com
affiliation: AMD
- given-names: Qianfeng
family-names: Zhang
email: qianfeng.zhang@amd.com
affiliation: AMD
- given-names: Liang
family-names: Huang
email: carlus.huang@amd.com
affiliation: AMD
- given-names: Shaojie
family-names: Wang
email: shaojie.wang@amd.com
affiliation: AMD
- given-names: Anthony
family-names: Chang
email: antc@amd.com
affiliation: AMD
- given-names: Chunyu
family-names: Lai
email: chunyu.lai@amd.com
affiliation: AMD
- given-names: Illia
family-names: Silin
email: illia.silin@amd.com
affiliation: AMD
- given-names: Adam
family-names: Osewski
email: adam.osewski@amd.com
affiliation: AMD
- given-names: Poyen
family-names: Chen
email: poyen.chen@amd.com
affiliation: AMD
- given-names: Rosty
family-names: Geyyer
email: rosty.geyyer@amd.com
affiliation: AMD
- given-names: Hanwen
family-names: Chen
- given-names: Tejash
family-names: Shah
- given-names: Xiaoyan
family-names: Zhou
- given-names: Jianfeng
family-names: Yan
repository-code: 'https://github.com/ROCmSoftwarePlatform/composable_kernel'
abstract: The Composable Kernel (CK) library aims to provide a programming model for writing performance-critical kernels for machine learning workloads across multiple architectures, including GPUs and CPUs, through general-purpose kernel programming languages such as HIP C++.
keywords:
- 'CK, Composable Kernel, Tensor Coordinate Transformation'
license: MIT
license-url: https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/7fc3ed761aa35709d87c8fbbe41dd368648b3541/LICENSE
# Composable Kernel Developers and Contributors
This is the list of developers and contributors to the Composable Kernel library.
## Developers
[Chao Liu](https://github.com/asroy), [Jing Zhang](https://github.com/zjing14), 2018-2022
[Letao Qin](https://github.com/ltqin), [Qianfeng Zhang](https://github.com/qianfengz), [Liang Huang](https://github.com/carlushuang), [Shaojie Wang](https://github.com/shaojiewang), 2019-2022
[Anthony Chang](https://github.com/rosenrodt), [Chunyu Lai](https://github.com/rocking5566), [Illia Silin](https://github.com/illsilin), [Adam Osewski](https://github.com/aosewski), [Poyen Chen](https://github.com/poyenc), [Rosty Geyyer](https://github.com/geyyer), 2022
Hanwen Chang, 2019-2021,
Tejash Shah, 2019-2020
Xiaoyan Zhou, 2020
[Jianfeng Yan](https://github.com/j4yan), 2021-2022
## Product Manager
[Jun Liu](https://github.com/junliume)
## Contributors
[Dan Yao](https://github.com/danyao12), [Guangzhao Lu](https://github.com/guangzlu), [Raman Jana](https://github.com/ramjana), [Jehandad Khan](https://github.com/JehandadKhan), [Wen-Heng (Jack) Chung](https://github.com/whchung)
## Acknowledgement
The CK team works closely with the Meta [AITemplate](https://github.com/facebookincubator/AITemplate) team ([Bing Xu](https://github.com/antinucleon), [Hao Lu](https://github.com/hlu1), [Ying Zhang](https://github.com/ipiszy), etc.). Most of the valuable graph-optimization opportunities in ML models were identified by the AITemplate team, and we also co-designed many high-performance fused kernels for AMD GPUs. Without this collaboration, CK would not have reached its current potential.
@PACKAGE_INIT@
set(_composable_kernel_supported_components device_operations utility)
foreach(_comp ${composable_kernel_FIND_COMPONENTS})
if(NOT _comp IN_LIST _composable_kernel_supported_components)
......
FROM ubuntu:20.04
ARG ROCMVERSION=5.3
ARG compiler_version="release"
ARG compiler_commit=""
RUN set -xe
...@@ -12,12 +13,13 @@ RUN apt-get install -y wget gnupg
RUN wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -
RUN sh -c "echo deb [arch=amd64] $DEB_ROCM_REPO ubuntu main > /etc/apt/sources.list.d/rocm.list"
RUN wget --no-check-certificate -qO - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | apt-key add -
RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list"
# Install dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
    apt-utils \
    build-essential \
    ccache \
    cmake-data \
    cmake \
    curl \
...@@ -68,7 +70,6 @@ ENV UBSAN_OPTIONS=print_stacktrace=1
ENV LC_ALL=C.UTF-8
ENV LANG=C.UTF-8
RUN groupadd -f render
# Install the new rocm-cmake version
...@@ -79,9 +80,16 @@ RUN git clone -b master https://github.com/RadeonOpenCompute/rocm-cmake.git &&
WORKDIR /
ENV compiler_version=$compiler_version
ENV compiler_commit=$compiler_commit
RUN sh -c "echo compiler version = '$compiler_version'"
RUN sh -c "echo compiler commit = '$compiler_commit'"
RUN --mount=type=ssh if [ "$compiler_version" = "amd-stg-open" ]; then \
    sed -i '/$HIP_CLANG_TARGET = chomp($HIP_CLANG_TARGET);/c\ chomp($HIP_CLANG_TARGET);' /opt/rocm/hip/bin/hipcc.pl && \
    sed -i '/$HIP_CLANG_TARGET = chomp($HIP_CLANG_TARGET);/c\ chomp($HIP_CLANG_TARGET);' /opt/rocm/bin/hipcc.pl; \
    fi
RUN --mount=type=ssh if [ "$compiler_version" != "release" ] && [ "$compiler_commit" = "" ]; then \
    git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \
    cd llvm-project && mkdir build && cd build && \
    cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld;compiler-rt" ../llvm && \
...@@ -89,5 +97,14 @@ RUN --mount=type=ssh if [ "$compiler_version" != "release" ]; then \
    else echo "using the release compiler"; \
    fi
RUN --mount=type=ssh if [ "$compiler_version" != "release" ] && [ "$compiler_commit" != "" ]; then \
    git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \
    cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \
    cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld;compiler-rt" ../llvm && \
    make -j 8 ; \
    else echo "using the release compiler"; \
    fi
#ENV HIP_CLANG_PATH='/llvm-project/build/bin'
#RUN sh -c "echo HIP_CLANG_PATH = '$HIP_CLANG_PATH'"
# Composable Kernel
## Methodology
The Composable Kernel (CK) library aims to provide a programming model for writing performance-critical kernels for machine learning workloads across multiple architectures, including GPUs and CPUs, through general-purpose kernel languages such as HIP C++.
CK utilizes two concepts to achieve performance portability and code maintainability:
* A tile-based programming model
* Algorithm complexity reduction for complex ML operators, using an innovative technique we call "Tensor Coordinate Transformation" (a minimal sketch of the idea follows the figure below).
![ALT](/doc/image/ck_component.png "CK Components")
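To make the "Tensor Coordinate Transformation" idea concrete, here is a minimal, self-contained sketch in plain C++. It is not the CK API; all names in the snippet are illustrative only. A merge transform exposes two logical dimensions as one flat dimension, so a kernel can be written against a transformed view of a tensor without materializing that view.

```cpp
#include <array>
#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative only: a "merge" coordinate transform that exposes an H x W image
// as a single flat dimension of length H * W, without copying any data.
struct MergeHW
{
    std::size_t H, W;
    // Map a coordinate in the transformed (merged) view back to the original (H, W) view.
    std::array<std::size_t, 2> to_lower(std::size_t merged) const
    {
        return {merged / W, merged % W};
    }
};

int main()
{
    const std::size_t H = 4, W = 3;
    std::vector<float> image(H * W);
    for(std::size_t i = 0; i < image.size(); ++i)
        image[i] = static_cast<float>(i);

    MergeHW merge{H, W};
    // A "kernel" written against the merged 1-D view; the transform supplies the indexing math.
    float sum = 0.f;
    for(std::size_t i = 0; i < H * W; ++i)
    {
        const auto hw = merge.to_lower(i);
        sum += image[hw[0] * W + hw[1]];
    }
    std::cout << "sum over merged view = " << sum << '\n'; // 0 + 1 + ... + 11 = 66
    return 0;
}
```

Composing several such transforms is what lets CK express complex operators (for example, implicit-GEMM convolution) as coordinate arithmetic rather than as bespoke indexing code in every kernel.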
## Code Structure
The current CK library is structured into 4 layers:
* "Templated Tile Operators" layer
* "Templated Kernel and Invoker" layer
* "Instantiated Kernel and Invoker" layer
* "Client API" layer
![ALT](/doc/image/ck_layer.png "CK Layers")
## Contributors
The list of developers and contributors is here: [Contributors](/CONTRIBUTORS.md)
## Citation
If you use CK, please use the following citations:
* The CK paper will be freely available on arXiv soon: [Realizing Tensor Operators Using Coordinate Transformations and Tile Based Programming](???)
* [CITATION.cff](/CITATION.cff)
## License
CK is released under the MIT license. [License File](/LICENSE)
# Build CK
## Build docker image
```bash
DOCKER_BUILDKIT=1 docker build -t ck:latest -f Dockerfile .
```
## Launch docker
```bash
docker run \
-it \
...@@ -6,47 +45,38 @@ docker run \
--group-add sudo \
-w /root/workspace \
-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
ck:latest \
/bin/bash
```
## Build CK
```bash
mkdir build && cd build
# Need to specify target ID, example below is for gfx908 and gfx90a
cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-O3" \
-D CMAKE_BUILD_TYPE=Release \
-D GPU_TARGETS=gfx908;gfx90a \
..
```
### Build examples and tests
```bash
make -j examples tests
make test
```
Instructions for running each individual example are under [example](/example)
## Build ckProfiler
```bash
make -j ckProfiler
```
Instructions for running ckProfiler are under [profiler](/profiler)
## Install CK
```bash
...@@ -54,13 +84,13 @@ make install
```
## Using CK as pre-built kernel library
Instructions for using CK as a pre-built kernel library are under [client_example](/client_example)
## Caveat
### Kernel Timing and Verification
CK's own kernel timer will warm up the kernel once, and then run it multiple times
to get the average kernel time. For some kernels that use atomic add, this will cause
the output buffer to be accumulated multiple times, causing verification failure.
To work around it, do not use CK's own timer and do verification at the same time.
CK's own timer and verification in each example and ckProfiler can be enabled or
disabled from the command line.
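The following is a minimal host-side sketch (not CK code) of why re-running an atomic-add style kernel without re-initializing the output corrupts verification; the function and buffer names are made up for illustration.

```cpp
#include <cstdio>
#include <vector>

// Stand-in for a GPU kernel that accumulates into its output with atomic adds.
void atomic_add_like_kernel(std::vector<float>& out, const std::vector<float>& in)
{
    for(std::size_t i = 0; i < in.size(); ++i)
        out[i % out.size()] += in[i]; // accumulates on top of whatever is already in out
}

int main()
{
    std::vector<float> in(8, 1.f);
    std::vector<float> out(4, 0.f); // zero-initialized once, before the first launch

    const int num_runs = 1 /*warm-up*/ + 3 /*timed repeats*/;
    for(int r = 0; r < num_runs; ++r)
        atomic_add_like_kernel(out, in); // every extra run adds on top of the previous result

    // A single run would give out[0] == 2; four runs give 8, so a comparison against a
    // one-shot host reference fails unless the timing repeats are disabled.
    std::printf("out[0] = %.1f (reference expects 2.0)\n", out[0]);
    return 0;
}
```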
...@@ -81,8 +81,8 @@ int main(int argc, char* argv[])
auto argument_ptr = op_ptr->MakeArgumentPointer({M, N},      // lengths
                                                {Stride, 1}, // xStrides
                                                {0, 1},      // gammaStrides
                                                {0, 1},      // betaStrides
                                                {Stride, 1}, // yStrides
                                                {1},         // reduceDims
                                                1e-4,
......
...@@ -6,9 +6,10 @@ find_package(composable_kernel 1.0.0 COMPONENTS device_operations)
find_package(hip REQUIRED PATHS /opt/rocm)
message(STATUS "Build with HIP ${hip_VERSION}")
# add all example subdir
file(GLOB dir_list LIST_DIRECTORIES true *)
FOREACH(subdir ${dir_list})
IF(IS_DIRECTORY "${subdir}" AND (NOT "${subdir}" MATCHES "build"))
add_subdirectory(${subdir})
ENDIF()
ENDFOREACH()
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_e_permute_xdl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmEPermuteXdl
// clang-format off
//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
using ReferenceBatchedGemmInstance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
BDataType,
EDataType,
AccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
const int M = 256;
const int N = 128;
const int K = 64;
const int stride_A = K;
const int stride_B = K;
const int batch_stride_A = M * K;
const int batch_stride_B = K * N;
const int G0 = 16;
const int G1 = 8;
const int batch_count = G0 * G1;
// output layout - [G0, M, G1, N]
const int stride_G0 = M * G1 * N;
const int stride_G1 = N;
const int stride_M = G1 * N;
const int stride_N = 1;
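// With the values above (M = 256, N = 128, G1 = 8), the [G0, M, G1, N] output layout gives
// stride_G0 = 256 * 8 * 128 = 262144, stride_M = 8 * 128 = 1024, stride_G1 = 128, stride_N = 1.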
if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
exit(0);
}
// GEMM shape
ck::tensor_operation::device::BatchedGemmEPermuteDesc batched_gemm_e_permute_desc{
G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N};
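// Helper: builds a 3-D host tensor descriptor [batch, row, col] whose strides match either a
// row-major or a column-major per-batch layout, using the given leading and batch strides.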
auto f_host_tensor_descriptor = [](std::size_t batch_count_,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
std::vector<std::size_t>({batch_stride, stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
std::vector<std::size_t>({batch_stride, 1, stride}));
}
};
Tensor<ADataType> a_g_m_k(
f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, ALayout{}));
Tensor<BDataType> b_g_k_n(
f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{}));
auto f_host_e_tensor_descriptor = [](std::size_t G0_,
std::size_t G1_,
std::size_t M_,
std::size_t N_,
std::size_t stride_G0_,
std::size_t stride_G1_,
std::size_t stride_M_,
std::size_t stride_N_) {
return HostTensorDescriptor(
std::vector<std::size_t>({G0_, G1_, M_, N_}),
std::vector<std::size_t>({stride_G0_, stride_G1_, stride_M_, stride_N_}));
};
Tensor<EDataType> e_g0_g1_m_n_host_result(
f_host_e_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N));
Tensor<EDataType> e_g0_g1_m_n_device_result(
f_host_e_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
std::cout << "e_g0_g1_m_n: " << e_g0_g1_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) *
e_g0_g1_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_g_m_k.mData.data());
b_device_buf.ToDevice(b_g_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
// do GEMM
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<EDataType*>(e_device_buf.GetDeviceBuffer()),
M,
N,
K,
stride_A,
stride_B,
batch_stride_A,
batch_stride_B,
batched_gemm_e_permute_desc,
batch_count,
a_element_op,
b_element_op,
cde_element_op);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * batch_count * M * N * K;
std::size_t num_btype = sizeof(ADataType) * batch_count * M * K +
sizeof(BDataType) * batch_count * K * N +
sizeof(EDataType) * batch_count * M * N;
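// ave_time is in ms, so flop / 1e9 / ms gives TFLOP/s and bytes / 1e6 / ms gives GB/s.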
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
bool pass = true;
if(do_verification)
{
e_device_buf.FromDevice(e_g0_g1_m_n_device_result.mData.data());
auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
auto ref_invoker = ref_batched_gemm.MakeInvoker();
Tensor<EDataType> c_g_m_n_host_result = HostTensorDescriptor(
std::vector<std::size_t>({batch_count, M, N}), std::vector<std::size_t>({M * N, N, 1}));
auto ref_argument = ref_batched_gemm.MakeArgument(
a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, cde_element_op);
ref_invoker.Run(ref_argument);
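// Scatter the reference output C[g, m, n] into E[g0, g1, m, n] with g = g0 * G1 + g1,
// matching the permuted layout produced by the device kernel.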
for(int g0 = 0; g0 < G0; g0++)
{
for(int g1 = 0; g1 < G1; g1++)
{
for(int m = 0; m < M; m++)
{
for(int n = 0; n < N; n++)
{
int g = g0 * G1 + g1;
e_g0_g1_m_n_host_result(g0, g1, m, n) = c_g_m_n_host_result(g, m, n);
}
}
}
}
pass = ck::utils::check_err(e_g0_g1_m_n_host_result.mData,
e_g0_g1_m_n_device_result.mData,
"Error: Incorrect results c");
}
return pass ? 0 : 1;
}
...@@ -29,24 +29,27 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
constexpr int Rank = 2;
constexpr int NumReduceDim = 1;
using DeviceInstance =
    ck::tensor_operation::device::DeviceLayernormImpl<XDataType,
                                                      GammaDataType,
                                                      BetaDataType,
                                                      AccDataType,
                                                      YDataType,
                                                      PassThrough,
                                                      Rank,
                                                      NumReduceDim,
                                                      256, // BlockSize
                                                      8,   // ClusterM
                                                      32,  // ClusterK
                                                      1,   // SliceM
                                                      8,   // SliceK
                                                      1,   // SrcVecDim (0=M, 1=K)
                                                      8,   // SrcScalarPerVector
                                                      1,   // GammaVecDim (0=M, 1=K)
                                                      8,   // GammaScalarPerVector
                                                      1,   // BetaVecDim (0=M, 1=K)
                                                      8,   // BetaScalarPerVector
                                                      8>;  // OutScalarPerVector
int main()
{
...@@ -88,8 +91,8 @@ int main()
auto argument_ptr = device_instance.MakeArgumentPointer(
    {M, N},
    std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
    {0, 1},
    {0, 1},
    std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
    {1},
    1e-4,
......
...@@ -137,7 +137,7 @@ int main(int argc, char* argv[])
{
    using InLayout = ctc::G_NW_C;
    using WeiLayout = ctc::G_K_X_C;
    using BiasLayout = ctc::G_K;
    using ResidualLayout = ctc::G_NW_K;
    using OutLayout = ctc::G_NW_K;
...@@ -220,7 +220,7 @@ int main(int argc, char* argv[])
{
    using InLayout = ctc::G_NHW_C;
    using WeiLayout = ctc::G_K_YX_C;
    using BiasLayout = ctc::G_K;
    using ResidualLayout = ctc::G_NHW_K;
    using OutLayout = ctc::G_NHW_K;
...@@ -332,7 +332,7 @@ int main(int argc, char* argv[])
{
    using InLayout = ctc::G_NDHW_C;
    using WeiLayout = ctc::G_K_ZYX_C;
    using BiasLayout = ctc::G_K;
    using ResidualLayout = ctc::G_NDHW_K;
    using OutLayout = ctc::G_NDHW_K;
......
...@@ -137,7 +137,7 @@ int main(int argc, char* argv[])
{
    using InLayout = ctc::G_NW_C;
    using WeiLayout = ctc::G_K_X_C;
    using BiasLayout = ctc::G_K;
    using ResidualLayout = ctc::G_NW_K;
    using OutLayout = ctc::G_NW_K;
...@@ -220,7 +220,7 @@ int main(int argc, char* argv[])
{
    using InLayout = ctc::G_NHW_C;
    using WeiLayout = ctc::G_K_YX_C;
    using BiasLayout = ctc::G_K;
    using ResidualLayout = ctc::G_NHW_K;
    using OutLayout = ctc::G_NHW_K;
...@@ -332,7 +332,7 @@ int main(int argc, char* argv[])
{
    using InLayout = ctc::G_NDHW_C;
    using WeiLayout = ctc::G_K_ZYX_C;
    using BiasLayout = ctc::G_K;
    using ResidualLayout = ctc::G_NDHW_K;
    using OutLayout = ctc::G_NDHW_K;
......
...@@ -137,7 +137,7 @@ int main(int argc, char* argv[])
{
    using InLayout = ctc::G_NW_C;
    using WeiLayout = ctc::G_K_X_C;
    using BiasLayout = ctc::G_K;
    using ResidualLayout = ctc::G_NW_K;
    using OutLayout = ctc::G_NW_K;
...@@ -220,7 +220,7 @@ int main(int argc, char* argv[])
{
    using InLayout = ctc::G_NHW_C;
    using WeiLayout = ctc::G_K_YX_C;
    using BiasLayout = ctc::G_K;
    using ResidualLayout = ctc::G_NHW_K;
    using OutLayout = ctc::G_NHW_K;
...@@ -332,7 +332,7 @@ int main(int argc, char* argv[])
{
    using InLayout = ctc::G_NDHW_C;
    using WeiLayout = ctc::G_K_ZYX_C;
    using BiasLayout = ctc::G_K;
    using ResidualLayout = ctc::G_NDHW_K;
    using OutLayout = ctc::G_NDHW_K;
......
...@@ -137,7 +137,7 @@ int main(int argc, char* argv[])
{
    using InLayout = ctc::G_NW_C;
    using WeiLayout = ctc::G_K_X_C;
    using BiasLayout = ctc::G_K;
    using ResidualLayout = ctc::G_NW_K;
    using OutLayout = ctc::G_NW_K;
...@@ -220,7 +220,7 @@ int main(int argc, char* argv[])
{
    using InLayout = ctc::G_NHW_C;
    using WeiLayout = ctc::G_K_YX_C;
    using BiasLayout = ctc::G_K;
    using ResidualLayout = ctc::G_NHW_K;
    using OutLayout = ctc::G_NHW_K;
...@@ -332,7 +332,7 @@ int main(int argc, char* argv[])
{
    using InLayout = ctc::G_NDHW_C;
    using WeiLayout = ctc::G_K_ZYX_C;
    using BiasLayout = ctc::G_K;
    using ResidualLayout = ctc::G_NDHW_K;
    using OutLayout = ctc::G_NDHW_K;
......
...@@ -137,7 +137,7 @@ int main(int argc, char* argv[])
{
    using InLayout = ctc::G_NW_C;
    using WeiLayout = ctc::G_K_X_C;
    using BiasLayout = ctc::G_K;
    using ResidualLayout = ctc::G_NW_K;
    using OutLayout = ctc::G_NW_K;
...@@ -220,7 +220,7 @@ int main(int argc, char* argv[])
{
    using InLayout = ctc::G_NHW_C;
    using WeiLayout = ctc::G_K_YX_C;
    using BiasLayout = ctc::G_K;
    using ResidualLayout = ctc::G_NHW_K;
    using OutLayout = ctc::G_NHW_K;
...@@ -332,7 +332,7 @@ int main(int argc, char* argv[])
{
    using InLayout = ctc::G_NDHW_C;
    using WeiLayout = ctc::G_K_ZYX_C;
    using BiasLayout = ctc::G_K;
    using ResidualLayout = ctc::G_NDHW_K;
    using OutLayout = ctc::G_NDHW_K;
......
add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_custom_target(example_gemm_scale_softmax_gemm)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16)
add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
...@@ -16,7 +16,8 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g ...@@ -16,7 +16,8 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
...@@ -47,7 +48,9 @@ using CDataType = F16; ...@@ -47,7 +48,9 @@ using CDataType = F16;
using ALayout = Row; using ALayout = Row;
using B0Layout = Col; using B0Layout = Col;
using B1Layout = Row; using B1Layout = Row;
using CLayout = Row;
using CPermuteNumDims_G_M_O =
S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
using AElementOp = PassThrough; using AElementOp = PassThrough;
using B0ElementOp = PassThrough; using B0ElementOp = PassThrough;
...@@ -55,65 +58,67 @@ using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; ...@@ -55,65 +58,67 @@ using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough; using B1ElementOp = PassThrough;
using CElementOp = PassThrough; using CElementOp = PassThrough;
static constexpr auto MNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< using DeviceGemmInstance =
ALayout, ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
B0Layout, ALayout,
B1Layout, B0Layout,
CLayout, B1Layout,
ADataType, CPermuteNumDims_G_M_O,
B0DataType, ADataType,
B1DataType, B0DataType,
CDataType, B1DataType,
AccDataType, CDataType,
CShuffleDataType, AccDataType,
AElementOp, CShuffleDataType,
B0ElementOp, AElementOp,
Acc0ElementOp, B0ElementOp,
B1ElementOp, Acc0ElementOp,
CElementOp, B1ElementOp,
MNPadding, CElementOp,
1, GemmSpec,
256, 1,
128, // MPerBlock 256,
128, // NPerBlock 128, // MPerBlock
32, // KPerBlock 128, // NPerBlock
64, // Gemm1NPerBlock 32, // KPerBlock
32, // Gemm1KPerBlock 64, // Gemm1NPerBlock
8, // AK1 32, // Gemm1KPerBlock
8, // BK1 8, // AK1
2, // B1K1 8, // BK1
32, // MPerXDL 2, // B1K1
32, // NPerXDL 32, // MPerXDL
1, // MXdlPerWave 32, // NPerXDL
4, // NXdlPerWave 1, // MXdlPerWave
2, // Gemm1NXdlPerWave 4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransfer 2, // Gemm1NXdlPerWave
S<1, 0, 2>, S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
2, S<1, 0, 2>,
8, 2,
8, 8,
true, 8,
S<4, 64, 1>, // BBlockTransfer true,
S<1, 0, 2>, S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
2, S<1, 0, 2>,
8, 2,
8, 8,
true, 8,
S<16, 16, 1>, // B1BlockTransfer true,
S<0, 2, 1>, S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>, S<0, 2, 1>,
1, S<0, 2, 1>,
4, 1,
2, 4,
false, 2,
1, // CShuffleMXdlPerWavePerShuffle false,
2, // CShuffleNXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 2, // CShuffleNXdlPerWavePerShuffle
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
true>; // MaskOutUpperTriangle
// Ref Gemm0: fp16 in, fp32 out // Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType, using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
...@@ -143,22 +148,26 @@ int main(int argc, char* argv[]) ...@@ -143,22 +148,26 @@ int main(int argc, char* argv[])
int init_method = 1; int init_method = 1;
bool time_kernel = false; bool time_kernel = false;
// GEMM shape // GEMM shape for A/B0/B1/C
ck::index_t M = 1020; // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck::index_t N = 1020; ck::index_t M = 512;
ck::index_t N = 512;
ck::index_t K = 64; ck::index_t K = 64;
ck::index_t O = 128; ck::index_t O = 128;
ck::index_t BatchCount = 4;
ck::index_t StrideA = -1; ck::index_t StrideA = -1;
ck::index_t StrideB0 = -1; ck::index_t StrideB0 = -1;
ck::index_t StrideB1 = -1; ck::index_t StrideB1 = -1;
ck::index_t StrideC = -1;
ck::index_t BatchStrideA = -1; ck::index_t BatchStrideA = -1;
ck::index_t BatchStrideB0 = -1; ck::index_t BatchStrideB0 = -1;
ck::index_t BatchStrideB1 = -1; ck::index_t BatchStrideB1 = -1;
ck::index_t BatchStrideC = -1;
float alpha = 1; float alpha = 1;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t G0 = 7;
ck::index_t G1 = 13;
if(argc == 1) if(argc == 1)
{ {
// use default case // use default case
...@@ -169,74 +178,51 @@ int main(int argc, char* argv[]) ...@@ -169,74 +178,51 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 9) else if(argc == 11)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); M = std::stoi(argv[4]);
N = std::stoi(argv[5]); N = std::stoi(argv[5]);
K = std::stoi(argv[6]); K = std::stoi(argv[6]);
O = std::stoi(argv[7]); O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
BatchCount = std::stoi(argv[8]); alpha = std::stof(argv[10]);
}
else if(argc == 18)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
BatchCount = std::stoi(argv[8]);
StrideA = std::stoi(argv[9]);
StrideB0 = std::stoi(argv[10]);
StrideB1 = std::stoi(argv[11]);
StrideC = std::stoi(argv[12]);
BatchStrideA = std::stoi(argv[13]);
BatchStrideB0 = std::stoi(argv[14]);
BatchStrideB1 = std::stoi(argv[15]);
BatchStrideC = std::stoi(argv[16]);
alpha = std::stof(argv[17]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 16: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, " printf("arg4 to 11: M, N, K, O, G0, G1\n");
"BatchStrideB0, BatchStrideB1, BatchStrideC\n"); printf("arg10: scale (alpha)\n");
printf("arg17: scale (alpha)\n");
exit(0); exit(0);
} }
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M; const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K; const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N; const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? O : M;
StrideA = (StrideA < 0) ? DefaultStrideA : StrideA; StrideA = (StrideA < 0) ? DefaultStrideA : StrideA;
StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0; StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1; StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
StrideC = (StrideC < 0) ? DefaultStrideC : StrideC;
const int DefaultBatchStrideA = (ck::is_same_v<ALayout, Col> ? K : M) * StrideA; const int DefaultBatchStrideA = (ck::is_same_v<ALayout, Col> ? K : M) * StrideA;
const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0; const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1; const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
const int DefaultBatchStrideC = (ck::is_same_v<CLayout, Col> ? O : M) * StrideC;
BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA; BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA;
BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0; BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1; BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
BatchStrideC = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC;
const int BatchCount = G0 * G1;
auto f_host_tensor_descriptor = [](std::size_t batch_count, auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row, std::size_t row,
...@@ -263,15 +249,17 @@ int main(int argc, char* argv[]) ...@@ -263,15 +249,17 @@ int main(int argc, char* argv[])
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{})); f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
Tensor<B1DataType> b1_g_n_o( Tensor<B1DataType> b1_g_n_o(
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{})); f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
Tensor<CDataType> c_g_m_o_host_result( Tensor<CDataType> c_gs_ms_os_host_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{})); std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
Tensor<CDataType> c_g_m_o_device_result( std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{})); Tensor<CDataType> c_gs_ms_os_device_result(
std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl; std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl; std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl; std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
...@@ -300,8 +288,8 @@ int main(int argc, char* argv[]) ...@@ -300,8 +288,8 @@ int main(int argc, char* argv[])
DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()); DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize()); DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize()); DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
DeviceMem c_g_m_o_device_buf(sizeof(CDataType) * DeviceMem c_gs_ms_os_device_buf(sizeof(CDataType) *
c_g_m_o_device_result.mDesc.GetElementSpaceSize()); c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data()); a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data()); b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
...@@ -320,20 +308,20 @@ int main(int argc, char* argv[]) ...@@ -320,20 +308,20 @@ int main(int argc, char* argv[])
gemm.MakeArgument(static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()), gemm.MakeArgument(static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()), static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()), static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_gs_ms_os_device_buf.GetDeviceBuffer()),
M, M,
N, N,
K, K,
O, O,
BatchCount, BatchCount,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
StrideA, StrideA,
StrideB0, StrideB0,
StrideB1, StrideB1,
StrideC,
BatchStrideA, BatchStrideA,
BatchStrideB0, BatchStrideB0,
BatchStrideB1, BatchStrideB1,
BatchStrideC,
a_element_op, a_element_op,
b0_element_op, b0_element_op,
acc0_element_op, acc0_element_op,
...@@ -361,26 +349,37 @@ int main(int argc, char* argv[]) ...@@ -361,26 +349,37 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl; << gemm.GetTypeString() << std::endl;
c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
if(do_verification) if(do_verification)
{ {
c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
// Output of Gemm0 is input A of Gemm1 // Output of Gemm0 is input A of Gemm1
Tensor<AccDataType> acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); Tensor<AccDataType> acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<CDataType> c_g_m_o_host_result(std::vector<int>{BatchCount, M, O},
std::vector<int>{M * O, O, 1});
auto ref_gemm0 = ReferenceGemm0Instance{}; auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument( auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op); a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
// gemm 0
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
// mask out upper triangle
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(idx[1] < idx[2])
self(idx) = -ck::NumericLimits<float>::Infinity();
});
auto ref_softmax = ReferenceSoftmaxInstance{}; auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker(); auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}); auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});
// softmax
ref_softmax_invoker.Run(ref_softmax_argument); ref_softmax_invoker.Run(ref_softmax_argument);
auto ref_gemm1 = ReferenceGemm1Instance{}; auto ref_gemm1 = ReferenceGemm1Instance{};
...@@ -388,9 +387,22 @@ int main(int argc, char* argv[]) ...@@ -388,9 +387,22 @@ int main(int argc, char* argv[])
auto ref_gemm1_argument = ref_gemm1.MakeArgument( auto ref_gemm1_argument = ref_gemm1.MakeArgument(
a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op); a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);
// gemm1
ref_gemm1_invoker.Run(ref_gemm1_argument); ref_gemm1_invoker.Run(ref_gemm1_argument);
return ck::utils::check_err(c_g_m_o_device_result.mData, c_g_m_o_host_result.mData) ? 0 : 1; // permute
c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
});
return ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData)
? 0
: 1;
} }
return 0; return 0;
......