"src/include/threadwise_direct_convolution.hpp" did not exist on "6790b8f3cc4fd32a9d9a43c6c9d80b826d969980"
Commit f76c0072 authored by Jakub Piasecki

Merge remote-tracking branch 'origin/develop' into jakpiase/ggemm_multid_two_stage

parents f27c50a7 9c052804
...@@ -3,6 +3,7 @@ ARG DEBIAN_FRONTEND=noninteractive
ARG ROCMVERSION=6.0
ARG compiler_version=""
ARG compiler_commit=""
ARG CK_SCCACHE=""
RUN set -xe
...@@ -32,16 +33,18 @@ RUN if [ "$ROCMVERSION" != "6.1" ]; then \
RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list"
RUN amdgpu-install -y --usecase=rocm --no-dkms
## Sccache binary built from source for ROCm
## Sccache binary built from source for ROCm, only install if CK_SCCACHE is defined
ARG SCCACHE_REPO_URL=http://compute-artifactory.amd.com/artifactory/rocm-generic-experimental/rocm-sccache
ENV SCCACHE_INSTALL_LOCATION=/usr/local/.cargo/bin
RUN mkdir -p ${SCCACHE_INSTALL_LOCATION} && \
curl ${SCCACHE_REPO_URL}/portable/0.2.16/sccache-0.2.16-alpha.1-rocm --output ${SCCACHE_INSTALL_LOCATION}/sccache && \
chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache
ENV PATH=$PATH:${SCCACHE_INSTALL_LOCATION}
ENV CK_SCCACHE=$CK_SCCACHE
RUN if [ "$CK_SCCACHE" != "" ]; then \
mkdir -p ${SCCACHE_INSTALL_LOCATION} && \
curl ${SCCACHE_REPO_URL}/portable/0.2.16/sccache-0.2.16-alpha.1-rocm --output ${SCCACHE_INSTALL_LOCATION}/sccache && \
chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache; \
fi
# Install dependencies
# hipTensor requires rocm-llvm-dev for rocm versions > 6.0.1
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
build-essential \
cmake \
...@@ -61,7 +64,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
python3-dev \
python3-pip \
redis \
rocm-llvm-dev \
sshpass \
stunnel \
software-properties-common \
...@@ -75,6 +77,10 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
apt-get clean && \
rm -rf /var/lib/apt/lists/*
# hipTensor requires rocm-llvm-dev for rocm versions > 6.0.1
RUN if [ "$ROCMVERSION" = "6.1" ]; then \
sh -c "apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated rocm-llvm-dev"; \
fi
# Update the cmake to version 3.27.5
RUN pip install --upgrade cmake==3.27.5
......
...@@ -104,7 +104,7 @@ def getDockerImage(Map conf=[:]){
env.DOCKER_BUILDKIT=1
def prefixpath = conf.get("prefixpath", "/opt/rocm")
def no_cache = conf.get("no_cache", false)
def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
if(no_cache)
{
dockerArgs = dockerArgs + " --no-cache "
...@@ -134,7 +134,7 @@ def buildDocker(install_prefix){
checkout scm
def image_name = getDockerImageName()
echo "Building Docker for ${image_name}"
def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
echo "Build Args: ${dockerArgs}"
try{
...@@ -311,7 +311,7 @@ def buildHipClangJob(Map conf=[:]){
if (conf.get("enforce_xnack_on", false)) {
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
}
def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
}
...@@ -367,9 +367,6 @@ def runCKProfiler(Map conf=[:]){
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
}
def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
}
def variant = env.STAGE_NAME
def retimage
......
...@@ -12,6 +12,11 @@ if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES)
target_link_libraries(client_conv3d_fwd_fp8 PRIVATE composable_kernel::device_conv_operations)
endif()
if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
add_executable(client_conv3d_fwd_bf8 conv3d_fwd_bf8.cpp)
target_link_libraries(client_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations)
endif()
if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES) if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES)
add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp) add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp)
target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_conv_operations) target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_conv_operations)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
using InDataType = ck::bf8_t;
using WeiDataType = ck::bf8_t;
using OutDataType = ck::f8_t;
using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
static constexpr ck::index_t NumDimSpatial = 3;
static constexpr ck::index_t G = 1;
static constexpr ck::index_t N = 64;
static constexpr ck::index_t K = 128;
static constexpr ck::index_t C = 64;
static constexpr ck::index_t Z = 3;
static constexpr ck::index_t Y = 3;
static constexpr ck::index_t X = 3;
static constexpr ck::index_t Di = 28;
static constexpr ck::index_t Hi = 28;
static constexpr ck::index_t Wi = 3;
static constexpr ck::index_t Do = 28;
static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 3;
int main()
{
return run_grouped_conv_fwd<NumDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout,
3,
ck::bf8_t>(
{N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K})
? EXIT_SUCCESS
: EXIT_FAILURE;
}
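The spatial sizes used in the bf8 client example above are self-consistent under the usual forward-convolution size formula, assuming a stride of 1, dilation of 1 and padding of 1 (an assumption; the actual defaults come from the shared common.hpp helper, which is not shown in this diff). A minimal standalone sketch of that check:

// Hypothetical helper, not part of this change: output spatial size of a
// forward convolution for the given input size, filter size and padding.
constexpr int conv_out_size(int in, int filter, int pad, int stride = 1, int dilation = 1)
{
    return (in + 2 * pad - dilation * (filter - 1) - 1) / stride + 1;
}

// 3x3x3 filter, padding 1, stride 1: Do/Ho stay at 28 and Wo stays at 3.
static_assert(conv_out_size(28, 3, 1) == 28, "Do == Di and Ho == Hi");
static_assert(conv_out_size(3, 3, 1) == 3, "Wo == Wi");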
rocm-docs-core==0.36.0
rocm-docs-core==0.37.0
sphinxcontrib-bibtex==2.6.2
...@@ -96,9 +96,7 @@ pygments==2.15.0
# pydata-sphinx-theme
# sphinx
pyjwt[crypto]==2.6.0
# via
# pygithub
# pyjwt
# via pygithub
pynacl==1.5.0
# via pygithub
pytz==2023.3.post1
...@@ -113,7 +111,7 @@ requests==2.31.0
# via
# pygithub
# sphinx
rocm-docs-core==0.36.0
rocm-docs-core==0.37.0
# via -r requirements.in
six==1.16.0
# via
......
...@@ -7,6 +7,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp)
add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp)
add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp)
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
set(target 1)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
using InDataType = ck::bf8_t;
using WeiDataType = ck::bf8_t;
using AccDataType = float;
using CShuffleDataType = ck::f8_t;
using OutDataType = ck::f8_t;
using ComputeType = ck::bf8_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataType,
WeiDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<>,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, //
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1,
1,
S<1, 32, 1, 8>,
8,
ComputeType>;
#include "run_convnd_fwd_example.inc"
int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; }
...@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
...@@ -20,15 +20,20 @@ using F32 = float;
using ADataType = F16;
using BDataType = F16;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceElementwisePermuteInstance =
ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
PassThrough, // Elementwise op
4, // NumDim
8, // MPerThread
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<1>>; // OutScalarPerVectorSeq
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
PassThrough, // Elementwise
4, // NumDim
256, // BlockSize
128, // M0PerBlock
128, // M1PerBlock
8, // M0PerThread
8, // M1PerThread
ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
......
...@@ -7,7 +7,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
...@@ -21,26 +21,23 @@ using F32 = float;
using ADataType = F16;
using BDataType = F16;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using UnaryOp = ck::tensor_operation::element_wise::UnarySquare;
using Scale = ck::tensor_operation::element_wise::Scale;
using DeviceElementwisePermuteInstance =
ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
PassThrough, // ElementwiseOp
UnaryOp, // UnaryOp
Scale, // Scalar
4, // NumDim
8, // MPerThread
ck::Sequence<1>, // InScalarPerVectorSeq
ck::Sequence<1>>; // OutScalarPerVectorSeq
using UnaryOp = ck::tensor_operation::element_wise::Scale;
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
UnaryOp, // UnaryOp
4, // NumDim
256, // BlockSize
128, // M0PerBlock
128, // M1PerBlock
8, // M0PerThread
8, // M1PerThread
ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq
template <typename HostTensorA, typename HostTensorB, typename FunctorA, typename FunctorB>
void host_elementwise4D(HostTensorB& B_nhwc,
const HostTensorA& A_nchw,
FunctorA functor_a,
FunctorB functor_b,
float scale)
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
{
std::size_t N = A_nchw.mDesc.GetLengths()[0];
std::size_t C = A_nchw.mDesc.GetLengths()[1];
...@@ -51,11 +48,8 @@ void host_elementwise4D(HostTensorB& B_nhwc,
for(std::size_t c = 0; c < C; ++c)
for(std::size_t n = 0; n < N; ++n)
{
ADataType tmp_val;
auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)];
functor_b(tmp_val, a_val);
functor_a(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)],
scale * tmp_val);
functor(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], a_val);
}
}
...@@ -104,14 +98,8 @@ int main()
ck::ranges::copy(nchw, ab_lengths.begin());
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(ab_lengths,
{a_strides},
{b_strides},
input,
output,
PassThrough{},
UnaryOp{},
Scale{scale});
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale});
if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
...@@ -143,7 +131,7 @@ int main()
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(nhwc);
host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{}, scale);
host_elementwise4D(host_b, a, UnaryOp{scale});
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
......
...@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
...@@ -20,36 +20,31 @@ using F32 = float;
using ADataType = F16;
using BDataType = F16;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using UnaryOp = ck::tensor_operation::element_wise::UnarySquare;
using Scale = ck::tensor_operation::element_wise::Scale;
using DeviceElementwisePermuteInstance =
ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
PassThrough, // ElementwiseOp
UnaryOp, // UnaryOp
Scale, // Scalar
4, // NumDim
8, // MPerThread
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<1>>; // OutScalarPerVectorSeq
using UnaryOp = ck::tensor_operation::element_wise::Scale;
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
UnaryOp, // UnaryOp
4, // NumDim
256, // BlockSize
128, // M0PerBlock
128, // M1PerBlock
8, // M0PerThread
8, // M1PerThread
ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq
template <typename HostTensorA, typename HostTensorB, typename FunctorA, typename FunctorB>
void host_elementwise4D(HostTensorB& B_nhwc,
const HostTensorA& A_nchw,
FunctorA functor_a,
FunctorB functor_b,
float scale)
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
{
for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n)
for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c)
for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h)
for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w)
{
ADataType tmp_val;
auto a_val = A_nchw(n, c, h, w);
functor_b(tmp_val, a_val);
functor_a(B_nhwc(n, h, w, c), scale * tmp_val);
functor(B_nhwc(n, h, w, c), a_val);
}
}
...@@ -86,14 +81,8 @@ int main()
ck::ranges::copy(nchw, ab_lengths.begin());
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(ab_lengths,
{a_strides},
{b_strides},
input,
output,
PassThrough{},
UnaryOp{},
Scale{scale});
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale});
if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
...@@ -125,7 +114,7 @@ int main()
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(nhwc);
host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{}, scale);
host_elementwise4D(host_b, a, UnaryOp{scale});
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
......
...@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
...@@ -20,26 +20,23 @@ using F32 = float;
using ADataType = F32;
using BDataType = F32;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using UnaryOp = ck::tensor_operation::element_wise::UnarySquare;
using Scale = ck::tensor_operation::element_wise::Scale;
using DeviceElementwisePermuteInstance =
ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
PassThrough, // ElementwiseOp
UnaryOp, // UnaryOp
Scale, // Scalar
4, // NumDim
1, // MPerThread
ck::Sequence<1>, // InScalarPerVectorSeq
ck::Sequence<1>>; // OutScalarPerVectorSeq
using UnaryOp = ck::tensor_operation::element_wise::Scale;
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
UnaryOp, // UnaryOp
4, // NumDim
256, // BlockSize
128, // M0PerBlock
128, // M1PerBlock
8, // M0PerThread
8, // M1PerThread
ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
ck::Sequence<1>, // InScalarPerVectorSeq
ck::Sequence<1>>; // OutScalarPerVectorSeq
template <typename HostTensorA, typename HostTensorB, typename FunctorA, typename FunctorB>
void host_elementwise4D(HostTensorB& B_nhwc,
const HostTensorA& A_nchw,
FunctorA functor_a,
FunctorB functor_b,
float scale)
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
{
std::size_t N = A_nchw.mDesc.GetLengths()[0];
std::size_t C = A_nchw.mDesc.GetLengths()[1];
...@@ -50,11 +47,8 @@ void host_elementwise4D(HostTensorB& B_nhwc,
for(std::size_t c = 0; c < C; ++c)
for(std::size_t n = 0; n < N; ++n)
{
ADataType tmp_val;
auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)];
functor_b(tmp_val, a_val);
functor_a(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)],
scale * tmp_val);
functor(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], a_val);
}
}
...@@ -104,14 +98,8 @@ int main()
ck::ranges::copy(nchw, ab_lengths.begin());
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(ab_lengths,
{a_strides},
{b_strides},
input,
output,
PassThrough{},
UnaryOp{},
Scale{scale});
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale});
if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
...@@ -143,7 +131,7 @@ int main()
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(nhwc);
host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{}, scale);
host_elementwise4D(host_b, a, UnaryOp{scale});
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
......
...@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
...@@ -20,36 +20,31 @@ using F32 = float;
using ADataType = F32;
using BDataType = F32;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using UnaryOp = ck::tensor_operation::element_wise::UnarySquare;
using Scale = ck::tensor_operation::element_wise::Scale;
using DeviceElementwisePermuteInstance =
ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
PassThrough, // ElementwiseOp
UnaryOp, // UnaryOp
Scale, // Scalar
4, // NumDim
8, // MPerThread
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<1>>; // OutScalarPerVectorSeq
using UnaryOp = ck::tensor_operation::element_wise::Scale;
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
UnaryOp, // UnaryOp
4, // NumDim
256, // BlockSize
128, // M0PerBlock
128, // M1PerBlock
8, // M0PerThread
8, // M1PerThread
ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq
template <typename HostTensorA, typename HostTensorB, typename FunctorA, typename FunctorB>
void host_elementwise4D(HostTensorB& B_nhwc,
const HostTensorA& A_nchw,
FunctorA functor_a,
FunctorB functor_b,
float scale)
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
{
for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n)
for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c)
for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h)
for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w)
{
ADataType tmp_val;
auto a_val = A_nchw(n, c, h, w);
functor_b(tmp_val, a_val);
functor_a(B_nhwc(n, h, w, c), scale * tmp_val);
functor(B_nhwc(n, h, w, c), a_val);
}
}
...@@ -86,14 +81,8 @@ int main()
ck::ranges::copy(nchw, ab_lengths.begin());
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(ab_lengths,
{a_strides},
{b_strides},
input,
output,
PassThrough{},
UnaryOp{},
Scale{scale});
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale});
if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
...@@ -125,7 +114,7 @@ int main()
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(nhwc);
host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{}, scale);
host_elementwise4D(host_b, a, UnaryOp{scale});
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp"
namespace ck {
/**
* @brief Blockwise data transfer
*
* This version does following things to avoid scratch memory issue
* 1. Use StaticallyIndexedArray instead of C array for thread buffer
* 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
* 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
*
*/
template <typename ThreadGroup,
typename ElementwiseOperation,
typename DstInMemOps, // Sequence
typename BlockSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcDatas,
typename DstDatas,
typename SrcDescs,
typename DstDescs,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
typename SrcsScalarPerVector, // Sequence
typename DstsScalarPerVector, // Sequence
typename SrcsScalarStrideInVector, // Sequence
typename DstsScalarStrideInVector, // Sequence
typename ThreadTransferSrcsResetCoordinateAfterRun, // Sequence
typename ThreadTransferDstsResetCoordinateAfterRun, // Sequence
index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v4r2
{
static constexpr index_t nDim =
remove_reference_t<tuple_element_t<0, SrcDescs>>::GetNumOfDimension();
static constexpr index_t nSrc = SrcDescs::Size();
static constexpr index_t nDst = DstDescs::Size();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
__device__ constexpr ThreadGroupTensorSliceTransfer_v4r2(
const SrcDescs& src_descs,
const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
const DstDescs& dst_descs,
const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
const ElementwiseOperation& element_op)
: threadwise_transfer_(src_descs,
StaticallyIndexedArray<Index, nSrc>{},
dst_descs,
StaticallyIndexedArray<Index, nDst>{},
element_op)
{
static_assert(nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == SrcDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_for<0, nSrc, 1>{}([&](auto src_i) {
static_assert(nDim ==
remove_cvref_t<tuple_element_t<src_i, SrcDescs>>::GetNumOfDimension(),
"wrong! nDim not consistent");
});
static_for<0, nDst, 1>{}([&](auto dst_i) {
static_assert(nDim ==
remove_cvref_t<tuple_element_t<dst_i, DstDescs>>::GetNumOfDimension(),
"wrong! nDim not consistent");
});
static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"wrong! ThreadGroup::GetNumOfThread() too small");
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
const auto src_thread_slice_origins = generate_tuple(
[&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; },
Number<nSrc>{});
const auto dst_thread_slice_origins = generate_tuple(
[&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; },
Number<nDst>{});
threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins);
}
}
template <typename SrcBuffers, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
}
}
template <typename DstBuffers, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDescs& dst_descs,
DstBuffers& dst_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id);
}
}
template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
__device__ void Run(const SrcDescs& src_descs,
const SrcBuffer& src_bufs,
const DstDescs& dst_descs,
DstBuffer& dst_bufs,
Number<ThreadScratchId> thread_scratch_id)
{
RunRead(src_descs, src_bufs, thread_scratch_id);
RunWrite(dst_descs, dst_bufs, thread_scratch_id);
}
__device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_descs, step);
}
}
__device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_descs, step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v3r2<decltype(thread_slice_lengths),
ElementwiseOperation,
DstInMemOps,
SrcDatas,
DstDatas,
SrcDescs,
DstDescs,
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim,
DstVectorDim,
SrcsScalarPerVector,
DstsScalarPerVector,
SrcsScalarStrideInVector,
DstsScalarStrideInVector,
ThreadTransferSrcsResetCoordinateAfterRun,
ThreadTransferDstsResetCoordinateAfterRun,
NumThreadScratch>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/stream_utility.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,
index_t NumDim,
index_t BlockSize,
index_t M0PerBlock,
index_t M1PerBlock,
index_t M0PerThread,
index_t M1PerThread,
typename ThreadClusterArrangeOrder,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct DeviceElementwiseImpl
: public DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>
{
static constexpr int NumInput = InDataTypeTuple::Size();
static constexpr int NumOutput = OutDataTypeTuple::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size(),
"Tuple size is inconsistent with the number of in/out!");
static auto GenerateInDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
return static_cast<const DataType*>(nullptr);
},
Number<NumInput>{});
};
static auto GenerateOutDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(nullptr);
},
Number<NumOutput>{});
};
using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple());
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
static index_t GetLowestStrideDim(const std::array<index_t, NumDim>& strides)
{
index_t most_continous_dim = NumDim - 1;
index_t most_continous_dim_stride = strides[most_continous_dim];
for(index_t dim = 0; dim < NumDim; dim++)
{
if(strides[dim] < most_continous_dim_stride)
{
most_continous_dim_stride = strides[dim];
most_continous_dim = dim;
}
}
return most_continous_dim;
}
template <typename InOutDescriptor>
static auto PadInputOutputDescriptor(const InOutDescriptor& desc)
{
const auto M0 = desc.GetLength(I0);
const auto M1 = desc.GetLength(I1);
const auto pad_M0 = math::integer_divide_ceil(M0, M0PerThread) * M0PerThread - M0;
const auto pad_M1 = math::integer_divide_ceil(M1, M1PerThread) * M1PerThread - M1;
const auto padded_desc = transform_tensor_descriptor(
desc,
make_tuple(make_right_pad_transform(M0, pad_M0), make_right_pad_transform(M1, pad_M1)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return padded_desc;
}
static auto GenerateBatchDimsLenghtsTuple(const std::array<index_t, NumDim>& lengths,
const index_t M0_dim,
const index_t M1_dim)
{
// Generate batch dims, they will be merged to M0
// Add one more dim than needed in case that M0 is equal to M1
// If M0 is equal to M1, then will be one more batch dim
std::array<index_t, NumDim - 1> batch_dims;
index_t batch_dim = 0;
for(index_t i = 0; i < NumDim; i++)
{
if(i != M0_dim && i != M1_dim)
{
batch_dims[batch_dim] = lengths[i];
batch_dim++;
}
}
// Add dummy dim if M0_dim is not equal to M1_dim
if(M0_dim != M1_dim && NumDim >= 2)
batch_dims[NumDim - 2] = 1;
return generate_tuple([&](auto I) { return batch_dims[I]; }, Number<NumDim - 1>{});
}
static auto MakeDescriptor(const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& in_strides,
const std::array<index_t, NumDim>& out_strides,
const std::array<index_t, NumDim>& desc_strides)
{
const auto M0_dim = GetLowestStrideDim(out_strides);
const auto M1_dim = GetLowestStrideDim(in_strides);
// If M0_dim is equal to M1_dim, then make M0_dim dummy
const auto M0 = M0_dim == M1_dim ? I1 : lengths[M0_dim];
const auto M1 = lengths[M1_dim];
const auto M0_stride = M0_dim == M1_dim ? I1 : desc_strides[M0_dim];
const auto M1_stride = desc_strides[M1_dim];
const auto batch_dims_lenghts = GenerateBatchDimsLenghtsTuple(lengths, M0_dim, M1_dim);
const auto batch_dims_strides = GenerateBatchDimsLenghtsTuple(desc_strides, M0_dim, M1_dim);
const auto desc = make_naive_tensor_descriptor(
concat_tuple(batch_dims_lenghts, make_tuple(M0), make_tuple(M1)),
concat_tuple(batch_dims_strides, make_tuple(M0_stride), make_tuple(M1_stride)));
// Merged batch dims with M0
const auto transforms =
make_tuple(make_merge_transform(concat_tuple(batch_dims_lenghts, make_tuple(M0))),
make_pass_through_transform(M1));
using BatchElemsSequence =
typename arithmetic_sequence_gen<0, decltype(batch_dims_lenghts)::Size() + 1, 1>::type;
const auto lower_dims = make_tuple(BatchElemsSequence{}, Sequence<NumDim>{});
const auto upper_dims = make_tuple(Sequence<0>{}, Sequence<1>{});
// desc: (merged_dims + M0, M1)
auto merged_desc = transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
return PadInputOutputDescriptor(merged_desc);
}
template <index_t NumTensors>
static auto GenerateInOutGridDescTuple()
{
std::array<index_t, NumDim> ones;
for(index_t d = 0; d < NumDim; d++)
{
ones[d] = 1;
}
return generate_tuple([&](auto) { return MakeDescriptor(ones, ones, ones, ones); },
Number<NumTensors>{});
};
using InGridDescTuple = decltype(GenerateInOutGridDescTuple<NumInput>());
using OutGridDescTuple = decltype(GenerateInOutGridDescTuple<NumOutput>());
using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt<M0PerBlock, M1PerBlock>;
using GridwiseElementwiseOp = GridwiseElementwise<InGridDescTuple,
OutGridDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
Block2TileMap,
ElementwiseOperation,
BlockSize,
M0PerBlock,
M1PerBlock,
M0PerThread,
M1PerThread,
ThreadClusterArrangeOrder,
InScalarPerVectorSeq,
OutScalarPerVectorSeq,
false>;
using GridwiseElementwiseOpSameInOutVectorDim = GridwiseElementwise<InGridDescTuple,
OutGridDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
Block2TileMap,
ElementwiseOperation,
BlockSize,
M0PerBlock,
M1PerBlock,
M0PerThread,
M1PerThread,
ThreadClusterArrangeOrder,
InScalarPerVectorSeq,
OutScalarPerVectorSeq,
true>;
struct Argument : public BaseArgument
{
Argument(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op)
: lengths_(lengths),
inStridesArray_(inStridesArray),
outStridesArray_(outStridesArray),
elementwise_op_(elementwise_op)
{
in_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
return static_cast<const DataType*>(in_dev_buffers[I.value]);
},
Number<NumInput>{});
out_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(out_dev_buffers[I.value]);
},
Number<NumOutput>{});
}
InDataTypePointerTuple in_dev_buffers_;
OutDataTypePointerTuple out_dev_buffers_;
std::array<index_t, NumDim> lengths_;
std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
ElementwiseOperation elementwise_op_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
auto in_grid_desc_tuple = generate_tuple(
[&](auto src_i) {
// Use Strides from first tensor to assert that M0 dim and
// M1 dim are the same for each tensor.
return MakeDescriptor(arg.lengths_,
arg.inStridesArray_[I0],
arg.outStridesArray_[I0],
arg.inStridesArray_[src_i]);
},
Number<NumInput>{});
auto out_grid_desc_tuple = generate_tuple(
[&](auto dst_i) {
return MakeDescriptor(arg.lengths_,
arg.inStridesArray_[I0],
arg.outStridesArray_[I0],
arg.outStridesArray_[dst_i]);
},
Number<NumOutput>{});
const index_t M0 = in_grid_desc_tuple.At(I0).GetLength(Number<I0>{});
const index_t M1 = in_grid_desc_tuple.At(I0).GetLength(Number<I1>{});
const auto block_2_tile_map = Block2TileMap(M0, M1);
const index_t grid_size = block_2_tile_map.CalculateGridSize(M0, M1);
const bool in_out_same_vector_dim = GetLowestStrideDim(arg.inStridesArray_[I0]) ==
GetLowestStrideDim(arg.outStridesArray_[I0]);
const auto kernel = in_out_same_vector_dim
? kernel_elementwise<GridwiseElementwiseOpSameInOutVectorDim,
InGridDescTuple,
OutGridDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
Block2TileMap,
ElementwiseOperation>
: kernel_elementwise<GridwiseElementwiseOp,
InGridDescTuple,
OutGridDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
Block2TileMap,
ElementwiseOperation>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
in_grid_desc_tuple,
out_grid_desc_tuple,
arg.in_dev_buffers_,
arg.out_dev_buffers_,
block_2_tile_map,
arg.elementwise_op_);
return elapsed_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
const index_t M0_dim = GetLowestStrideDim(arg.inStridesArray_[I0]);
const index_t M1_dim = GetLowestStrideDim(arg.outStridesArray_[I0]);
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& strides,
index_t scalarPerVector,
index_t M_dim) {
if(scalarPerVector == 1)
{
return true;
}
if(strides[M_dim] == 1 && lengths[M_dim] % scalarPerVector == 0)
{
return true;
}
return false;
};
bool is_valid = true;
static_for<0, NumInput, 1>{}([&](auto I) {
static_assert(M0PerThread % InScalarPerVectorSeq::At(I) == 0 &&
M1PerThread % InScalarPerVectorSeq::At(I) == 0);
is_valid &= IsScalarPerVectorValid(
arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I), M0_dim);
});
static_for<0, NumOutput, 1>{}([&](auto I) {
static_assert(M0PerThread % OutScalarPerVectorSeq::At(I) == 0 &&
M1PerThread % OutScalarPerVectorSeq::At(I) == 0);
is_valid &= IsScalarPerVectorValid(
arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I), M1_dim);
});
return is_valid;
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto
MakeArgument(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op)
{
return Argument{lengths,
inStridesArray,
outStridesArray,
in_dev_buffers,
out_dev_buffers,
elementwise_op};
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op) override
{
return std::make_unique<Argument>(lengths,
inStridesArray,
outStridesArray,
in_dev_buffers,
out_dev_buffers,
elementwise_op);
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceElementwiseImpl<";
str << NumDim << ", ";
str << BlockSize << ", ";
str << M0PerBlock << ", ";
str << M1PerBlock << ", ";
str << M0PerThread << ", ";
str << M1PerThread << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
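For reference, the updated example_elementwise_permute sources earlier in this diff instantiate the implementation above with a 256-thread block and an 8x8 per-thread tile. A minimal sketch of such an instantiation (the F16 types and the Scale element-wise op are taken from those examples, not mandated by this header):

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using ADataType = ck::half_t;
using BDataType = ck::half_t;
using UnaryOp = ck::tensor_operation::element_wise::Scale;

// 4D permute instance: one input tensor, one output tensor, Scale applied element-wise.
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
UnaryOp, // ElementwiseOperation
4, // NumDim
256, // BlockSize
128, // M0PerBlock
128, // M1PerBlock
8, // M0PerThread
8, // M1PerThread
ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq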
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"
#include "ck/utility/common_header.hpp"
namespace ck {
template <typename GridwiseElementwiseFunctor,
typename InGridDescTuple,
typename OutGridDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename Block2TileMap,
typename ElementwiseOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_elementwise(const InGridDescTuple in_grid_desc_tuple,
const OutGridDescTuple out_grid_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const Block2TileMap block_2_tile_map,
const ElementwiseOperation elementwise_op)
{
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,
out_grid_desc_tuple,
p_in_global_tuple,
p_out_global_tuple,
block_2_tile_map,
elementwise_op);
}
template <typename InGridDescTuple,
typename OutGridDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename Block2TileMap,
typename ElementwiseOperation,
index_t BlockSize,
index_t M0PerBlock,
index_t M1PerBlock,
index_t M0PerThread,
index_t M1PerThread,
typename ThreadClusterArrangeOrder,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq,
bool InOutSameVectorDim>
struct GridwiseElementwise
{
static constexpr index_t NumInput = InDataTypePointerTuple::Size();
static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size() &&
NumInput == InGridDescTuple::Size() && NumOutput == OutGridDescTuple::Size(),
"Tuple size is inconsistent with the number of in/out!");
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
using PassThroughOp = tensor_operation::element_wise::PassThrough;
__device__ static void Run(const InGridDescTuple& in_grid_desc_tuple,
const OutGridDescTuple& out_grid_desc_tuple,
const InDataTypePointerTuple& p_in_global_tuple,
const OutDataTypePointerTuple& p_out_global_tuple,
const Block2TileMap& block_2_tile_map,
const ElementwiseOperation& elementwise_op)
{
constexpr auto src_datas = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return DataType{};
},
Number<NumInput>{});
constexpr auto dst_datas = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return DataType{};
},
Number<NumOutput>{});
const auto in_global_buf_tuple = generate_tuple(
[&](auto I) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global_tuple[I], in_grid_desc_tuple[I].GetElementSpaceSize());
},
Number<NumInput>{});
auto out_global_buf_tuple = generate_tuple(
[&](auto I) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global_tuple[I], out_grid_desc_tuple[I].GetElementSpaceSize());
},
Number<NumOutput>{});
const auto block_work_idx =
block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
const index_t m0_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock);
const index_t m1_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * M1PerBlock);
const auto thread_grid_offset =
make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// If src and dst have same vector dim, then:
// M0 dim - for src and dst vector load/store
// else:
// M0 dim - for dst vector load
// M1 dim - for src vector store
using SrcDimAccessOrder = Sequence<0, 1>;
using DstDimAccessOrder =
std::conditional_t<InOutSameVectorDim, Sequence<0, 1>, Sequence<1, 0>>;
using SrcVectorDim = Number<1>;
using DstVectorDim = std::conditional_t<InOutSameVectorDim, Number<1>, Number<0>>;
using ThreadClusterLengths =
Sequence<Number<M0PerBlock / M0PerThread>{}, Number<M1PerBlock / M1PerThread>{}>;
auto global_to_global_transfer = ThreadGroupTensorSliceTransfer_v4r2<
ThisThreadBlock,
ElementwiseOperation,
uniform_sequence_gen_t<NumOutput, static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
Sequence<M0PerBlock, M1PerBlock>,
ThreadClusterLengths,
ThreadClusterArrangeOrder,
decltype(src_datas),
decltype(dst_datas),
InGridDescTuple,
OutGridDescTuple,
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim{},
DstVectorDim{},
InScalarPerVectorSeq,
OutScalarPerVectorSeq,
uniform_sequence_gen_t<NumInput, 1>,
uniform_sequence_gen_t<NumOutput, 1>,
uniform_sequence_gen_t<NumInput, false>,
uniform_sequence_gen_t<NumOutput, false>>{in_grid_desc_tuple,
thread_grid_offset,
out_grid_desc_tuple,
thread_grid_offset,
elementwise_op};
global_to_global_transfer.Run(
in_grid_desc_tuple, in_global_buf_tuple, out_grid_desc_tuple, out_global_buf_tuple, I0);
}
};
} // namespace ck
...@@ -17,6 +17,10 @@ namespace instance {
using F8 = ck::f8_t;
#endif
#ifdef CK_ENABLE_BF8
using BF8 = ck::bf8_t;
#endif
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
...@@ -250,6 +254,42 @@ using device_grouped_conv_fwd_xdl_f8_instances = std::tuple<
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_bf8_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ComputeType|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#ifdef CK_ENABLE_BF8
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8>,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8>
#endif
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
...
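The device_grouped_conv_fwd_xdl_bf8_instances alias above is a std::tuple whose elements are fully specialized kernel types; elsewhere in the library such tuples are expanded into a runtime list of type-erased operation pointers. A minimal standalone analogue of that expansion is sketched below; BaseOp, InstanceA/InstanceB and register_instances are invented names for illustration and do not come from CK.

// Standalone analogue of expanding a tuple of instance types into a runtime list (not CK code).
#include <iostream>
#include <memory>
#include <string>
#include <tuple>
#include <vector>

struct BaseOp
{
    virtual ~BaseOp() = default;
    virtual std::string name() const = 0;
};

// Two stand-in "instances"; in CK these would be the specialized device kernels.
struct InstanceA : BaseOp { std::string name() const override { return "InstanceA"; } };
struct InstanceB : BaseOp { std::string name() const override { return "InstanceB"; } };

// Append one default-constructed object per type in the tuple to the instance list.
template <typename InstanceTuple>
void register_instances(std::vector<std::unique_ptr<BaseOp>>& ops)
{
    std::apply(
        [&](auto... instances) {
            (ops.push_back(std::make_unique<decltype(instances)>()), ...);
        },
        InstanceTuple{});
}

int main()
{
    using MyInstances = std::tuple<InstanceA, InstanceB>;
    std::vector<std::unique_ptr<BaseOp>> ops;
    register_instances<MyInstances>(ops);
    for(const auto& op : ops)
        std::cout << op->name() << '\n'; // InstanceA, InstanceB
    return 0;
}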
...@@ -744,6 +744,23 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances(
                                                                F8>>>& instances);
#endif
#ifdef CK_ENABLE_BF8
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
BF8,
BF8,
Empty_Tuple,
F8,
PassThrough,
PassThrough,
PassThrough,
BF8>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
...@@ -1159,6 +1176,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
            add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances(op_ptrs);
        }
#endif
#ifdef CK_ENABLE_BF8
if constexpr(is_same_v<InDataType, ck::bf8_t> && is_same_v<WeiDataType, ck::bf8_t> &&
is_same_v<OutDataType, ck::f8_t> && is_same_v<ComputeType, ck::bf8_t>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
        if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                     is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
...
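The factory body above selects which add_device_..._instances function to call purely from the template data types, using if constexpr with is_same_v so that only the matching branch is compiled. The self-contained sketch below shows that dispatch pattern in isolation; the f8_t/bf8_t structs and add_* functions here are illustrative placeholders, not the CK definitions.

// Standalone illustration of compile-time type dispatch (not the CK factory).
#include <iostream>
#include <type_traits>

// Illustrative stand-ins for the data types; real code would use ck::f8_t / ck::bf8_t.
struct f8_t  {};
struct bf8_t {};

void add_bf8_instances() { std::cout << "registered bf8 -> f8 instances\n"; }
void add_f8_instances()  { std::cout << "registered f8 -> f8 instances\n"; }

template <typename InDataType, typename WeiDataType, typename OutDataType, typename ComputeType>
void add_instances_for()
{
    // Only the branch whose types match is instantiated and compiled in.
    if constexpr(std::is_same_v<InDataType, bf8_t> && std::is_same_v<WeiDataType, bf8_t> &&
                 std::is_same_v<OutDataType, f8_t> && std::is_same_v<ComputeType, bf8_t>)
    {
        add_bf8_instances();
    }
    else if constexpr(std::is_same_v<InDataType, f8_t> && std::is_same_v<WeiDataType, f8_t> &&
                      std::is_same_v<OutDataType, f8_t> && std::is_same_v<ComputeType, f8_t>)
    {
        add_f8_instances();
    }
}

int main()
{
    add_instances_for<bf8_t, bf8_t, f8_t, bf8_t>(); // registered bf8 -> f8 instances
    return 0;
}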
...@@ -7,7 +7,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_elementwise_scale.hpp"
+#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
...@@ -19,125 +19,67 @@ namespace instance {
#ifdef CK_ENABLE_FP16
void add_device_permute_scale_1d_f16_instances(
-    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>,
-                                                  ck::Tuple<F16>,
-                                                  PassThrough,
-                                                  element_wise::UnarySquare,
-                                                  Scale,
-                                                  1>>>&);
+    std::vector<std::unique_ptr<
+        DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, element_wise::Scale, 1>>>&);
void add_device_permute_scale_2d_f16_instances(
-    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>,
-                                                  ck::Tuple<F16>,
-                                                  PassThrough,
-                                                  element_wise::UnarySquare,
-                                                  Scale,
-                                                  2>>>&);
+    std::vector<std::unique_ptr<
+        DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, element_wise::Scale, 2>>>&);
void add_device_permute_scale_3d_f16_instances(
-    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>,
-                                                  ck::Tuple<F16>,
-                                                  PassThrough,
-                                                  element_wise::UnarySquare,
-                                                  Scale,
-                                                  3>>>&);
+    std::vector<std::unique_ptr<
+        DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, element_wise::Scale, 3>>>&);
void add_device_permute_scale_4d_f16_instances(
-    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>,
-                                                  ck::Tuple<F16>,
-                                                  PassThrough,
-                                                  element_wise::UnarySquare,
-                                                  Scale,
-                                                  4>>>&);
+    std::vector<std::unique_ptr<
+        DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, element_wise::Scale, 4>>>&);
void add_device_permute_scale_5d_f16_instances(
-    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>,
-                                                  ck::Tuple<F16>,
-                                                  PassThrough,
-                                                  element_wise::UnarySquare,
-                                                  Scale,
-                                                  5>>>&);
+    std::vector<std::unique_ptr<
+        DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, element_wise::Scale, 5>>>&);
void add_device_permute_scale_6d_f16_instances(
-    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>,
-                                                  ck::Tuple<F16>,
-                                                  PassThrough,
-                                                  element_wise::UnarySquare,
-                                                  Scale,
-                                                  6>>>&);
+    std::vector<std::unique_ptr<
+        DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, element_wise::Scale, 6>>>&);
#endif
#ifdef CK_ENABLE_FP32
void add_device_permute_scale_1d_f32_instances(
-    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>,
-                                                  ck::Tuple<F32>,
-                                                  PassThrough,
-                                                  element_wise::UnarySquare,
-                                                  Scale,
-                                                  1>>>&);
+    std::vector<std::unique_ptr<
+        DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, element_wise::Scale, 1>>>&);
void add_device_permute_scale_2d_f32_instances(
-    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>,
-                                                  ck::Tuple<F32>,
-                                                  PassThrough,
-                                                  element_wise::UnarySquare,
-                                                  Scale,
-                                                  2>>>&);
+    std::vector<std::unique_ptr<
+        DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, element_wise::Scale, 2>>>&);
void add_device_permute_scale_3d_f32_instances(
-    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>,
-                                                  ck::Tuple<F32>,
-                                                  PassThrough,
-                                                  element_wise::UnarySquare,
-                                                  Scale,
-                                                  3>>>&);
+    std::vector<std::unique_ptr<
+        DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, element_wise::Scale, 3>>>&);
void add_device_permute_scale_4d_f32_instances(
-    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>,
-                                                  ck::Tuple<F32>,
-                                                  PassThrough,
-                                                  element_wise::UnarySquare,
-                                                  Scale,
-                                                  4>>>&);
+    std::vector<std::unique_ptr<
+        DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, element_wise::Scale, 4>>>&);
void add_device_permute_scale_5d_f32_instances(
-    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>,
-                                                  ck::Tuple<F32>,
-                                                  PassThrough,
-                                                  element_wise::UnarySquare,
-                                                  Scale,
-                                                  5>>>&);
+    std::vector<std::unique_ptr<
        DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, element_wise::Scale, 5>>>&);
void add_device_permute_scale_6d_f32_instances(
-    std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>,
-                                                  ck::Tuple<F32>,
-                                                  PassThrough,
-                                                  element_wise::UnarySquare,
-                                                  Scale,
-                                                  6>>>&);
+    std::vector<std::unique_ptr<
+        DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, element_wise::Scale, 6>>>&);
#endif
template <typename InDataTypeTuple,
          typename OutDataTypeTuple,
          typename ElementwiseOperation,
-          typename UnaryOperation,
-          typename Scale,
          index_t NumDim>
struct DeviceOperationInstanceFactory<
-    ck::tensor_operation::device::DeviceElementwise<InDataTypeTuple,
-                                                    OutDataTypeTuple,
-                                                    ElementwiseOperation,
-                                                    UnaryOperation,
-                                                    Scale,
-                                                    NumDim>>
+    ck::tensor_operation::device::
+        DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>>
{
-    using DeviceOp = DeviceElementwise<InDataTypeTuple,
-                                       OutDataTypeTuple,
-                                       ElementwiseOperation,
-                                       UnaryOperation,
-                                       Scale,
-                                       NumDim>;
+    using DeviceOp =
+        DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>;
    static auto GetInstances()
    {
...
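The hunk above also shows the shape of the instance-factory pattern: DeviceOperationInstanceFactory is specialized for a supported operation signature, and its GetInstances() returns a vector of unique_ptr to the abstract device operation. The sketch below reproduces that structure in a standalone form; DeviceElementwiseImpl, the simplified DeviceElementwise interface and the Scale struct are assumptions made for the example, not the CK types.

// Standalone sketch of the instance-factory specialization pattern (not the CK factory).
#include <memory>
#include <vector>

// Primary template is left undefined; only supported operation signatures get a specialization.
template <typename DeviceOp>
struct DeviceOperationInstanceFactory;

// Stand-in for a device operation interface parameterized on an elementwise functor and rank.
template <typename ElementwiseOperation, int NumDim>
struct DeviceElementwise
{
    virtual ~DeviceElementwise() = default;
};

// Stand-in for the elementwise scale functor.
struct Scale
{
    float scale_;
};

// A made-up concrete instance for the Scale / rank-4 signature, parameterized by block size.
template <int BlockSize>
struct DeviceElementwiseImpl final : DeviceElementwise<Scale, 4>
{
};

template <>
struct DeviceOperationInstanceFactory<DeviceElementwise<Scale, 4>>
{
    using DeviceOp = DeviceElementwise<Scale, 4>;

    static auto GetInstances()
    {
        // Collect one type-erased pointer per registered instance.
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
        op_ptrs.push_back(std::make_unique<DeviceElementwiseImpl<64>>());
        op_ptrs.push_back(std::make_unique<DeviceElementwiseImpl<128>>());
        return op_ptrs;
    }
};

int main()
{
    auto ops = DeviceOperationInstanceFactory<DeviceElementwise<Scale, 4>>::GetInstances();
    return ops.size() == 2 ? 0 : 1;
}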