Unverified Commit 29dcb956 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge pull request #33 from ROCm/lwpck-1292

Merge from the public repo.
parents 29deceb6 cbcc844e
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <cstdlib>
......@@ -67,6 +70,8 @@ int main()
float scale = 1.f;
auto i = 0;
std::mt19937 gen(11939);
std::uniform_int_distribution<int> dis(0, 1);
for(std::size_t w = 0; w < a.mDesc.GetLengths()[3]; ++w)
for(std::size_t h = 0; h < a.mDesc.GetLengths()[2]; ++h)
for(std::size_t c = 0; c < a.mDesc.GetLengths()[1]; ++c)
......@@ -74,7 +79,7 @@ int main()
{
a.mData[(n * nchw[1] * nchw[2] * nchw[3]) + (c * nchw[2] * nchw[3]) +
(h * nchw[3]) + w] = i;
i++;
i = dis(gen);
}
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <cstdlib>
......
......@@ -32,6 +32,8 @@ std::vector<ck::index_t> f_tensor_strides_ncdhw(ck::index_t N_,
return {C_ * D * H * W, D * H * W, H * W, W, 1_uz};
else if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NDHWC>::value)
return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_};
throw std::runtime_error("Pool3d_fwd: problem with layout. ");
return {0, 0, 0, 0, 0};
};
template <typename TensorLayout>
......@@ -53,6 +55,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
return HostTensorDescriptor({N_, C_, D, H, W},
{D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_});
}
throw std::runtime_error("Pool3d_fwd: problem with layout. ");
return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0});
};
template <typename DevicePoolFwdInstance,
......
......@@ -26,6 +26,8 @@ std::vector<ck::index_t> f_tensor_strides_ncdhw(ck::index_t N_,
return {C_ * D * H * W, D * H * W, H * W, W, 1_uz};
else if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NDHWC>::value)
return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_};
throw std::runtime_error("Avgpool3d_bwd: problem with layout. ");
return {0, 0, 0, 0, 0};
};
template <typename TensorLayout>
......@@ -47,6 +49,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
return HostTensorDescriptor({N_, C_, D, H, W},
{D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_});
}
throw std::runtime_error("Avgpool3d_bwd: problem with layout. ");
return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0});
};
template <typename DevicePoolBwdInstance,
......
add_example_executable(example_layernorm2d_bwd_fp32 layernorm2d_bwd_fp32.cpp)
......@@ -15,16 +15,17 @@
#include "ck/library/utility/literals.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp"
using DYDataType = ck::half_t;
using XDataType = ck::half_t;
using GammaDataType = ck::half_t;
using DYDataType = float;
using XDataType = float;
using GammaDataType = float;
using MeanInvStdDataType = float;
using DGammaDataType = ck::half_t;
using DBetaDataType = ck::half_t;
using DXDataType = ck::half_t;
using DGammaDataType = float;
using DBetaDataType = float;
using DXDataType = float;
using ComputeDataType = float;
constexpr int Rank = 2;
......@@ -39,6 +40,7 @@ constexpr int NumReduceDim = 1;
// inv_std: [M, 1]
// Output shape
// dx: [M, N]
// dgamma: [1, N]
// dbeta: [1, N]
......@@ -46,8 +48,34 @@ constexpr int NumReduceDim = 1;
// dbeta = reduce_sum(dy, axis=0)
// [CAUSION]
// In DeviceNormalizationBwdGammaBetaImpl, M is invarient dimension, K is reduced dimension
// Hence, M in this example and DeviceNormalizationBwdGammaBetaImpl is different
// In DeviceNormalizationBwdDataImpl & DeviceNormalizationBwdGammaBetaImpl, M is Invariant
// dimension, K is reduced dimension Hence, M in this example and
// DeviceNormalizationBwdGammaBetaImpl is different
using XDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdDataImpl<
DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
ComputeDataType,
DXDataType,
Rank,
NumReduceDim,
256, // BlockSize
8, // MThreadClusterSize
32, // KThreadClusterSize
1, // MThreadSliceSize
4, // KThreadSliceSize
true, // IsDYFastestDimReduced
4, // DYSrcVectorSize
true, // IsXFastestDimReduced
4, // XSrcVectorSize
true, // IsGammaFastestDimReduced
4, // GammaSrcVectorSize
false, // IsMeanInvStdFastestDimReduced
1, // MeanInvStdSrcVectorSize
true, // IsDXFastestDimReduced
4>; // DXDstVectorSize
using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdGammaBetaImpl<
DYDataType,
XDataType,
......@@ -58,18 +86,18 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio
Rank,
NumReduceDim,
256, // BlockSize
8, // ClusterInvarient
32, // ClusterReduce
8, // SliceInvarient
1, // SliceReduce
8, // MThreadClusterSize
32, // KThreadClusterSize
4, // MThreadSliceSize
1, // KThreadSliceSize
false, // IsDYFastestDimReduced
8, // DYSrcVectorSize
4, // DYSrcVectorSize
false, // IsXFastestDimReduced
8, // XSrcVectorSize
4, // XSrcVectorSize
true, // IsMeanInvStdFastestDimReduced
1, // MeanInvStdSrcVectorSize
1, // DGammaDstVectorSize
1>; // DBetaDstVectorSize
4, // DGammaDstVectorSize
4>; // DBetaDstVectorSize
int main()
{
......@@ -96,16 +124,48 @@ int main()
DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize());
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize());
DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize());
DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize());
DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize());
DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize());
dy_dev.ToDevice(dy.mData.data());
x_dev.ToDevice(x.mData.data());
gamma_dev.ToDevice(gamma.mData.data());
mean_dev.ToDevice(mean.mData.data());
inv_std_dev.ToDevice(inv_std.mData.data());
// backward x
auto x_device_instance = XDeviceInstance{};
auto x_argument_ptr = x_device_instance.MakeArgumentPointer({M, N}, // lengths
{N, 1}, // dyStrides
{N, 1}, // xStrides
{0, 1}, // gammaStrides
{1, 0}, // meanStrides
{1, 0}, // invStdStrides
{N, 1}, // dxStrides
{1}, // reduceDims
dy_dev.GetDeviceBuffer(),
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
mean_dev.GetDeviceBuffer(),
inv_std_dev.GetDeviceBuffer(),
dx_dev.GetDeviceBuffer());
if(!x_device_instance.IsSupportedArgument(x_argument_ptr.get()))
{
std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__
<< std::endl;
return 1;
};
auto x_invoker_ptr = x_device_instance.MakeInvokerPointer();
x_invoker_ptr->Run(x_argument_ptr.get(), StreamConfig{nullptr, time_kernel});
// backward gamma & beta
auto gamma_beta_device_instance = GammaBetaDeviceInstance{};
auto gamma_beta_argument_ptr =
gamma_beta_device_instance.MakeArgumentPointer({M, N}, // inLengths
......@@ -126,7 +186,8 @@ int main()
if(!gamma_beta_device_instance.IsSupportedArgument(gamma_beta_argument_ptr.get()))
{
std::cout << "The runtime parameters are not supported" << std::endl;
std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__
<< std::endl;
return 1;
};
......@@ -156,9 +217,11 @@ int main()
dgamma_dev.FromDevice(dgamma.mData.data());
dbeta_dev.FromDevice(dbeta.mData.data());
dx_dev.FromDevice(dx.mData.data());
pass &= ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3);
pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3);
pass &= ck::utils::check_err(dx, host_dx, "Error: Incorrect dx", 1e-3, 1e-3);
}
return (pass ? 0 : 1);
......
add_example_executable(example_layernorm2d_bwd_fp16 layernorm2d_bwd_fp16.cpp)
add_example_executable(example_groupnorm_bwd_fp16 groupnorm_bwd_fp16.cpp)
add_example_executable(example_groupnorm_bwd_fp32 groupnorm_bwd_fp32.cpp)
......@@ -15,23 +15,58 @@
#include "ck/library/utility/literals.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp"
using DYDataType = ck::half_t;
using XDataType = ck::half_t;
using GammaDataType = ck::half_t;
using DYDataType = float;
using XDataType = float;
using GammaDataType = float;
using MeanInvStdDataType = float;
using DGammaDataType = ck::half_t;
using DBetaDataType = ck::half_t;
using DXDataType = ck::half_t;
using DGammaDataType = float;
using DBetaDataType = float;
using DXDataType = float;
using ComputeDataType = float;
constexpr int Rank = 5;
constexpr int NumReduceDim = 3;
// Grouprnorm
// kernel: M , K
// kernel 1: M , K
// dy: N, H, W, G, C -> N * G, H * W * C
// x: N, H, W, G, C -> N * G, H * W * C
// gamma: 1, 1, 1, G, C -> 1 * G, 1 * 1 * C
// mean: N, 1, 1, G, 1 -> N * G, 1 * 1 * 1
// rstd: N, 1, 1, G, 1 -> N * G, 1 * 1 * 1
// dx: N, H, W, G, C -> N * G, H * W * C
using XDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdDataImpl<
DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
ComputeDataType,
DXDataType,
Rank,
NumReduceDim,
256, // BlockSize
8, // MThreadClusterSize
32, // KThreadClusterSize
1, // MThreadSliceSize
4, // KThreadSliceSize
true, // IsDYFastestDimReduced
4, // DYSrcVectorSize
true, // IsXFastestDimReduced
4, // XSrcVectorSize
true, // IsGammaFastestDimReduced
4, // GammaSrcVectorSize
false, // IsMeanInvStdFastestDimReduced
1, // MeanInvStdSrcVectorSize
true, // IsDXFastestDimReduced
4>; // DXDstVectorSize
// kernel 2: M , K
// dy: N, H, W, G, C -> G * C, N * H * W
// x: N, H, W, G, C -> G * C, N * H * W
// mean: N, 1, 1, G, 1 -> G * 1, N * 1 * 1
......@@ -52,18 +87,18 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio
Rank,
NumReduceDim,
256, // BlockSize
8, // ClusterInvarient
8, // ClusterInvariant
32, // ClusterReduce
8, // SliceInvarient
4, // SliceInvariant
1, // SliceReduce
false, // IsDYFastestDimReduced
8, // DYSrcVectorSize
4, // DYSrcVectorSize
false, // IsXFastestDimReduced
8, // XSrcVectorSize
4, // XSrcVectorSize
false, // IsMeanInvStdFastestDimReduced
1, // MeanInvStdSrcVectorSize
1, // DGammaDstVectorSize
1>; // DBetaDstVectorSize
4, // DGammaDstVectorSize
4>; // DBetaDstVectorSize
int main()
{
......@@ -93,20 +128,55 @@ int main()
DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize());
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize());
DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize());
DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize());
DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize());
DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize());
dy_dev.ToDevice(dy.mData.data());
x_dev.ToDevice(x.mData.data());
gamma_dev.ToDevice(gamma.mData.data());
mean_dev.ToDevice(mean.mData.data());
inv_std_dev.ToDevice(inv_std.mData.data());
std::vector<ck::index_t> dyStrides{dy.mDesc.GetStrides().begin(), dy.mDesc.GetStrides().end()};
std::vector<ck::index_t> xStrides{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()};
std::vector<ck::index_t> gammaStrides = {0, 0, 0, C, 1};
std::vector<ck::index_t> meanStrides = {G, 0, 0, 1, 0};
std::vector<ck::index_t> invStdStrides = {G, 0, 0, 1, 0};
std::vector<ck::index_t> dxStrides{dx.mDesc.GetStrides().begin(), dx.mDesc.GetStrides().end()};
// backward x
auto x_device_instance = XDeviceInstance{};
auto x_argument_ptr = x_device_instance.MakeArgumentPointer({N, H, W, G, C}, // lengths
dyStrides, // dyStrides
xStrides, // xStrides
gammaStrides, // gammaStrides
meanStrides, // meanStrides
invStdStrides, // invStdStrides
dxStrides, // dxStrides
{1, 2, 4}, // reduceDims
dy_dev.GetDeviceBuffer(),
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
mean_dev.GetDeviceBuffer(),
inv_std_dev.GetDeviceBuffer(),
dx_dev.GetDeviceBuffer());
if(!x_device_instance.IsSupportedArgument(x_argument_ptr.get()))
{
std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__
<< std::endl;
return 1;
};
auto x_invoker_ptr = x_device_instance.MakeInvokerPointer();
x_invoker_ptr->Run(x_argument_ptr.get(), StreamConfig{nullptr, time_kernel});
// backward gamma & beta
auto gamma_beta_device_instance = GammaBetaDeviceInstance{};
auto gamma_beta_argument_ptr =
......@@ -128,7 +198,8 @@ int main()
if(!gamma_beta_device_instance.IsSupportedArgument(gamma_beta_argument_ptr.get()))
{
std::cout << "The runtime parameters are not supported" << std::endl;
std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__
<< std::endl;
return 1;
};
......@@ -158,9 +229,11 @@ int main()
dgamma_dev.FromDevice(dgamma.mData.data());
dbeta_dev.FromDevice(dbeta.mData.data());
dx_dev.FromDevice(dx.mData.data());
pass &= ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3);
pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3);
pass &= ck::utils::check_err(dx, host_dx, "Error: Incorrect dx", 1e-3, 1e-3);
}
return (pass ? 0 : 1);
......
......@@ -42,6 +42,8 @@ foreach(gpu IN LISTS GPU_TARGETS)
# ScaleAdd ScaleAdd Relu
add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16)
add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16)
set(target 1)
endif()
endforeach()
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <type_traits>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
constexpr ck::index_t NDimSpatial = 3;
using InDataType = ck::half_t;
using WeiDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = ck::half_t;
using OutDataType = ck::half_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
using BiasLayout = ck::tensor_layout::convolution::G_K;
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::ScaleAddScaleAddRelu;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <typename OutElementOp>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<OutLayout, BiasLayout>,
OutLayout,
InDataType,
WeiDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<OutDataType, OutDataType>,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, //
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1,
1,
S<1, 32, 1, 8>,
8>;
using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
namespace {
// Use custom implementation to pass two more tensors for post op
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementOp,
typename WeiElementOp,
typename OutElementOp,
typename DeviceConvNDFwdInstance>
bool run_grouped_conv_fwd(bool do_verification,
int init_method,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
const HostTensorDescriptor& in_g_n_c_wis_desc,
const HostTensorDescriptor& wei_g_k_c_xs_desc,
const HostTensorDescriptor& out_g_n_k_wos_desc,
const InElementOp& in_element_op,
const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op)
{
constexpr ck::index_t NumDs = 2;
const ck::index_t G = out_g_n_k_wos_desc.GetLengths()[0];
const ck::index_t K = out_g_n_k_wos_desc.GetLengths()[2];
// Logical broadcast bias (we have to pass bias lengths in the same format as output - GNKDHW)
std::array<ck::index_t, NDimSpatial + 3> bias_g_k_lengths;
std::array<ck::index_t, NDimSpatial + 3> bias_g_k_strides;
// Fill other lenghts than G,K with 1 and strides with 0
bias_g_k_lengths.fill(1);
bias_g_k_strides.fill(0);
bias_g_k_lengths[0] = G;
bias_g_k_lengths[2] = K;
bias_g_k_strides[0] = K; // stride to G
bias_g_k_strides[2] = 1; // stride to K
const auto broadcasted_bias_desc = HostTensorDescriptor(bias_g_k_lengths, bias_g_k_strides);
// y = relu ( alpha1 * conv(x) + alpha2 * z + bias )
Tensor<InDataType> in(in_g_n_c_wis_desc);
Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
Tensor<OutDataType> out_host(out_g_n_k_wos_desc);
Tensor<OutDataType> out_device(out_g_n_k_wos_desc);
std::array<Tensor<OutDataType>, NumDs> d_tensors = {Tensor<OutDataType>(out_g_n_k_wos_desc),
Tensor<OutDataType>(broadcasted_bias_desc)};
std::cout << "in: " << in.mDesc << std::endl;
std::cout << "wei: " << wei.mDesc << std::endl;
std::cout << "out: " << out_host.mDesc << std::endl;
std::cout << "z_tensor: " << d_tensors[0].mDesc << std::endl;
std::cout << "bias_tensor: " << d_tensors[1].mDesc << std::endl;
// Make sure that we allocated only G * K values for bias
assert(static_cast<ck::index_t>(d_tensors[1].mData.size()) == G * K);
switch(init_method)
{
case 0: break;
case 1:
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 2});
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-2, 2});
d_tensors[0].GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
d_tensors[1].GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
break;
default:
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-1.0, 1.0});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.05, 0.05});
d_tensors[0].GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.05, 0.05});
d_tensors[1].GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.05, 0.05});
}
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize());
DeviceMem z_buf(sizeof(OutDataType) * d_tensors[0].mDesc.GetElementSpaceSize());
DeviceMem bias_buf(sizeof(OutDataType) * d_tensors[1].mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei.mData.data());
z_buf.ToDevice(d_tensors[0].mData.data());
bias_buf.ToDevice(d_tensors[1].mData.data());
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{};
std::array<ck::index_t, NDimSpatial> input_right_pads{};
auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths);
copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides);
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads);
copy(conv_param.input_right_pads_, input_right_pads);
const std::array<const void*, NumDs> ds = {z_buf.GetDeviceBuffer(), bias_buf.GetDeviceBuffer()};
auto conv = DeviceConvNDFwdInstance{};
auto invoker = conv.MakeInvoker();
auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
wei_device_buf.GetDeviceBuffer(),
ds,
out_device_buf.GetDeviceBuffer(),
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
std::array<std::array<ck::index_t, NDimSpatial + 3>, NumDs>{
e_g_n_k_wos_lengths, bias_g_k_lengths},
std::array<std::array<ck::index_t, NDimSpatial + 3>, NumDs>{
e_g_n_k_wos_strides, bias_g_k_strides},
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
if(!conv.IsSupportedArgument(argument))
{
throw std::runtime_error("The device op with the specified compilation parameters does "
"not support this convolution problem.");
}
float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = conv_param.GetFlops() + G * K +
conv_param.GetOutputByte<OutDataType>() / sizeof(OutDataType);
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>() +
G * K * sizeof(OutDataType) + conv_param.GetOutputByte<OutDataType>();
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< conv.GetTypeString() << std::endl;
if(do_verification)
{
auto ref_conv =
ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
0, /*Num A Elementwise Tensors*/
0, /*Num B Elementwise Tensors*/
NumDs>();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in,
wei,
out_host,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
in_element_op,
wei_element_op,
out_element_op,
{},
{},
d_tensors);
ref_invoker.Run(ref_argument);
out_device_buf.FromDevice(out_device.mData.data());
return ck::utils::check_err(out_device, out_host, "Error: incorrect results!");
}
return true;
}
} // namespace
#include "run_convnd_fwd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
......@@ -24,7 +24,7 @@ bool run_convnd_fwd_example(int argc, char* argv[])
// Following shapes are selected to avoid overflow. Expect inf in case of
// size increase for some elementwise ops.
ck::utils::conv::ConvParam conv_param{
3, 1, 16, 128, 8, {3, 3, 3}, {17, 17, 17}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}};
3, 2, 16, 128, 8, {3, 3, 3}, {17, 17, 17}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}};
if(argc == 1)
{
......
......@@ -44,16 +44,30 @@
#define CK_USE_WAVES_PER_EU 0
#endif
// define general macros for various architectures
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__
#endif
#if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__)
#define __gfx101__
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__)
#define __gfx103__
#endif
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
#define __gfx11__
#endif
// buffer resource
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_BUFFER_RESOURCE_3RD_DWORD -1
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__) // for GPU code
defined(__gfx90a__) || defined(__gfx94__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx1030__) // for GPU code
#elif defined(__gfx103__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
#elif defined(__gfx11__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif
......@@ -61,12 +75,12 @@
#ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing
#elif defined(__gfx803__) || defined(__gfx900__) // for GPU code
#define CK_USE_AMD_V_MAC_F32
#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // for GPU code
#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx103__) || \
defined(__gfx94__) // for GPU code
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
#elif defined(__gfx11__)
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8_GFX11
......@@ -75,23 +89,22 @@
// MFMA instruction
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_MFMA
#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__) // for GPU code
#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code
#define CK_USE_AMD_MFMA
#endif
#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#if(defined(__gfx90a__) || defined(__gfx94__))
#define CK_USE_AMD_MFMA_BF16_1K_OP
#endif
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx94__)
#define CK_USE_AMD_MFMA_GFX940
#endif
// WMMA instruction
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_WMMA
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
#elif defined(__gfx11__) // for GPU code
#define CK_USE_AMD_WMMA
#endif
......@@ -107,15 +120,13 @@
// buffer atomic add: floating point
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__) // for GPU code
#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#else // for GPU code
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
#endif
#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__)) // for GPU code
#if(defined(__gfx90a__) || defined(__gfx94__)) // for GPU code
#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1
#else
#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0
......@@ -134,6 +145,9 @@
// inner product using V_DOT with DPP8 modifiers
#define CK_USE_AMD_V_DOT_DPP8_INLINE_ASM 1
// LDS direct loads using inline assembly
#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 1
// set stochastic rounding as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION 1
......@@ -215,7 +229,7 @@
// denorm test fix, required to work around dissue
#ifndef CK_WORKAROUND_DENORM_FIX
#define CK_WORKAROUND_DENORM_FIX 0
#elif
#else
// enable only on MI200
#define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
#endif // CK_WORKAROUND_DENORM_FIX
......
......@@ -26,7 +26,7 @@ inline std::string get_device_name()
}
const std::string raw_name(props.gcnArchName);
// https://github.com/ROCmSoftwarePlatform/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40
// https://github.com/ROCm/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40
static std::map<std::string, std::string> device_name_map = {
{"Ellesmere", "gfx803"},
{"Baffin", "gfx803"},
......@@ -65,4 +65,23 @@ inline bool is_lds_direct_load_supported()
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942";
}
inline bool is_navi1_supported()
{
return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" ||
ck::get_device_name() == "gfx1012";
}
inline bool is_navi2_supported()
{
return ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1031" ||
ck::get_device_name() == "gfx1032" || ck::get_device_name() == "gfx1034" ||
ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036";
}
inline bool is_navi3_supported()
{
return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103";
}
} // namespace ck
......@@ -12,21 +12,23 @@ inline void hip_check_error(hipError_t x)
if(x != hipSuccess)
{
std::ostringstream ss;
ss << "HIP runtime error: " << hipGetErrorString(x) << ". " << __FILE__ << ": " << __LINE__
<< "in function: " << __func__;
ss << "HIP runtime error: " << hipGetErrorString(x) << ". "
<< "hip_check_error.hpp"
<< ": " << __LINE__ << "in function: " << __func__;
throw std::runtime_error(ss.str());
}
}
#define HIP_CHECK_ERROR(retval_or_funcall) \
do \
{ \
hipError_t _tmpVal = retval_or_funcall; \
if(_tmpVal != hipSuccess) \
{ \
std::ostringstream ostr; \
ostr << "HIP Function Failed (" << __FILE__ << "," << __LINE__ << ") " \
<< hipGetErrorString(_tmpVal); \
throw std::runtime_error(ostr.str()); \
} \
#define HIP_CHECK_ERROR(retval_or_funcall) \
do \
{ \
hipError_t _tmpVal = retval_or_funcall; \
if(_tmpVal != hipSuccess) \
{ \
std::ostringstream ostr; \
ostr << "HIP Function Failed (" \
<< "hip_check_error.hpp" \
<< "," << __LINE__ << ") " << hipGetErrorString(_tmpVal); \
throw std::runtime_error(ostr.str()); \
} \
} while(0)
......@@ -30,7 +30,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
block_dim.y,
block_dim.z);
printf("Warm up 1 time\n");
printf("Warm up %d times\n", stream_config.cold_niters_);
#endif
// warm up
for(int i = 0; i < stream_config.cold_niters_; ++i)
......@@ -103,14 +103,17 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
block_dim.y,
block_dim.z);
printf("Warm up 1 time\n");
printf("Warm up %d times\n", stream_config.cold_niters_);
#endif
// warm up
preprocess();
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
for(int i = 0; i < stream_config.cold_niters_; ++i)
{
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
}
const int nrepeat = 10;
const int nrepeat = stream_config.nrepeat_;
#if DEBUG_LOG
printf("Start running %d times...\n", nrepeat);
#endif
......
......@@ -11,6 +11,6 @@ struct StreamConfig
hipStream_t stream_id_ = nullptr;
bool time_kernel_ = false;
int log_level_ = 0;
int cold_niters_ = 1;
int nrepeat_ = 10;
int cold_niters_ = 5;
int nrepeat_ = 50;
};
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
// Double LDS buffer
// Prefetech 2 stage
// Local prefetch 1 stage
namespace ck {
template <index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t ABufferLoadWidth,
index_t BBufferLoadWidth,
index_t ALDSWriteWidth,
index_t BLDSWriteWidth,
index_t ALDSReadWidth,
index_t BLDSReadWidth,
index_t MRepeat,
index_t NRepeat,
index_t MPerXDL,
index_t NPerXDL,
index_t KPerXDL>
struct BlockwiseGemmXdlops_pipeline_hotloop_inst
{
static constexpr index_t WaveSize = 64;
static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL);
static constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth);
static constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth);
static constexpr index_t A_LDS_Write_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth);
static constexpr index_t B_LDS_Write_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth);
static constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth);
static constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * MPerBlock * KPerBlock / (BlockSize * BLDSReadWidth);
static constexpr index_t C_MFMA_Inst_Num =
MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
static constexpr auto Print()
{
printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n",
BlockSize,
WaveSize,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
KPerXDL);
printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
"%d, %d\n C MFMA inst: %d\n",
A_Buffer_Load_Inst_Num,
B_Buffer_Load_Inst_Num,
A_LDS_Write_Inst_Num,
B_LDS_Write_Inst_Num,
A_LDS_Read_Inst_Num,
B_LDS_Read_Inst_Num,
C_MFMA_Inst_Num);
}
};
template <
index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool TransposeC = false,
index_t AMmaKStride =
KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops,
index_t BMmaKStride =
KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops>
struct BlockwiseGemmXdlops_pipeline_v4
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm =
XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t KRepeat = KPerThread / KPack;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
using HotLoopInstList = BlockwiseGemmXdlops_pipeline_hotloop_inst<BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
A_K1,
B_K1,
A_K1,
B_K1,
KPack,
KPack,
MRepeat,
NRepeat,
MPerXDL,
NPerXDL,
xdlops_gemm.KPerXdlops>;
static_assert(KPerThread % KPack == 0,
"Wrong KPack setting; try increasing KPerThread or decreasing KPack");
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
xdlops_gemm.GetRegSizePerXdlops(),
true>
c_thread_buf_;
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
__device__ static auto GetWaveIdx()
{
const index_t thread_id = ThisThreadBlock::GetThreadId();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto CalculateAThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
return make_tuple(
m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
}
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__
BlockwiseGemmXdlops_pipeline_v4(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
// HotLoopInstList::Print();
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_block_desc_g_m0_n0_m1_n1_m2_n2);
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
}
template <typename CGridDesc_G_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
{
const auto G = c_grid_desc_g_m_n.GetLength(I0);
const auto M = c_grid_desc_g_m_n.GetLength(I1);
const auto N = c_grid_desc_g_m_n.GetLength(I2);
const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_g_m_n,
make_tuple(make_pass_through_transform(G),
make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_grid_desc_g_m0_n0_m1_n1_m2_n2);
}
__device__ static constexpr auto HotLoopScheduler()
{
// schedule
constexpr auto num_ds_read_inst =
HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num;
constexpr auto num_ds_write_inst =
HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num;
;
constexpr auto num_buffer_load_inst =
HotLoopInstList::A_Buffer_Load_Inst_Num + HotLoopInstList::B_Buffer_Load_Inst_Num;
;
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto num_issue = num_buffer_load_inst;
static_for<0, num_issue, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(
0x100, num_ds_read_inst / num_buffer_load_inst, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(
0x200, num_ds_write_inst / num_buffer_load_inst, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(
0x008, num_mfma_inst / num_buffer_load_inst - 3, 0); // MFMA
});
}
template <index_t stage>
__device__ static constexpr auto TailScheduler()
{
}
template <>
__device__ static constexpr auto TailScheduler<1>()
{
// schedule
constexpr auto num_ds_read_inst =
HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num;
constexpr auto num_ds_write_inst =
HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num;
;
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto num_issue = num_ds_write_inst;
static_for<0, num_issue, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(
0x100, num_ds_read_inst / num_ds_write_inst - 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(
0x008, num_mfma_inst / num_ds_write_inst - 3, 0); // MFMA
});
}
template <>
__device__ static constexpr auto TailScheduler<2>()
{
// schedule
constexpr auto num_ds_read_inst =
HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num;
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto num_issue = num_ds_read_inst;
static_for<0, num_issue, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(
0x008, num_mfma_inst / num_ds_read_inst, 0); // MFMA
});
}
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
template <bool HasMainLoop,
index_t TailNum,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
index_t num_loop) const
{
__builtin_amdgcn_sched_barrier(0);
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());
StaticallyIndexedArray<decltype(a_thread_buf), Number<2>{}> a_thread_bufs;
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
// Inst List:
// ds_read_b128: 16
// ds_write_b128: 8
// buffer_load_dwordx4: 16
// v_mfma: 0
// -------------------------------------------------------------------------------------------
// Global prefetch 1th, Fill Ping LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0));
// Local prefetch 1th, Fill Ping Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(I0));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(I0),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(I0));
});
});
});
// Global prefetch 2th, Fill Pong LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1));
// Global prefetch 3rd
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
// This hot loop has two legacy loopover, to implement the double local buffer strategy
do
{
// -------------------------------------------------------------------------------------------
using PingP1 = Number<0>;
using PongP1 = Number<1>;
// MFMA: Ping Reg
// DS_WRITE: To Ping LDS
// DS_READ: Pong LDS to Pong Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP1{}),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(PongP1{}));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP1{}),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(PongP1{}));
});
});
});
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{}));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{}));
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP1{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP1{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
// -------------------------------------------------------------------------------------------
using PingP2 = Number<1>;
using PongP2 = Number<0>;
// MFMA: Pong Reg
// DS_WRITE: To Pong LDS
// DS_READ: Ping LDS to Ping Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP2{}),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(PongP2{}));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP2{}),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(PongP2{}));
});
});
});
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP2{}));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP2{}));
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP2{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP2{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
i += 2;
} while(i < (num_loop - 3));
}
// tail
if constexpr(TailNum == 3)
{
using PingP1 = Number<0>;
using PongP1 = Number<1>;
// MFMA: Ping Reg
// DS_WRITE: To Ping LDS
// DS_READ: Pong LDS to Pong Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP1{}),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(PongP1{}));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP1{}),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(PongP1{}));
});
});
});
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{}));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{}));
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP1{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP1{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
TailScheduler<1>();
__builtin_amdgcn_sched_barrier(0);
// -------------------------------------------------------------------------------------------
using PingP2 = Number<1>;
using PongP2 = Number<0>;
// MFMA: Pong Reg
// DS_WRITE: To Pong LDS
// DS_READ: Ping LDS to Ping Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP2{}),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(PongP2{}));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP2{}),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(PongP2{}));
});
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP2{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP2{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
TailScheduler<2>();
__builtin_amdgcn_sched_barrier(0);
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PongP2{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PongP2{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// 64 v_mfma
__builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA
__builtin_amdgcn_sched_barrier(0);
}
else if constexpr(TailNum == 2)
{
using PingP1 = Number<0>;
using PongP1 = Number<1>;
// MFMA: Ping Reg
// DS_WRITE: To Ping LDS
// DS_READ: Pong LDS to Pong Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP1{}),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(PongP1{}));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP1{}),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(PongP1{}));
});
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP1{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP1{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
TailScheduler<2>();
__builtin_amdgcn_sched_barrier(0);
// -------------------------------------------------------------------------------------------
using PingP2 = Number<1>;
// MFMA: Pong Reg
// DS_WRITE: To Pong LDS
// DS_READ: Ping LDS to Ping Reg
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP2{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP2{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// 64 v_mfma
__builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA
__builtin_amdgcn_sched_barrier(0);
}
}
protected:
// M1, N1 as double buffer index
// Read buffer + Compute buffer
// A[M0, M1, M2, KPack]
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
make_tuple(
Number<KPack>{}, Number<KPack * MRepeat * KPack>{}, Number<MRepeat * KPack>{}, I1));
// B[N0, N1, N2, KPack]
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
make_tuple(
Number<KPack>{}, Number<KPack * MRepeat * KPack>{}, Number<MRepeat * KPack>{}, I1));
// C[M, N, NumRegXdlops]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
B_K1>;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
} // namespace ck
......@@ -59,7 +59,9 @@ struct BaseOperator
virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const
virtual void SetWorkSpacePointer(BaseArgument* p_arg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const
{
assert(p_arg);
p_arg->p_workspace_ = p_workspace;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename DXDataType,
index_t Rank,
index_t NumReduceDim>
struct DeviceNormalizationBwdData : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<index_t> lengths,
const std::vector<index_t> dyStrides,
const std::vector<index_t> xStrides,
const std::vector<index_t> gammaStrides,
const std::vector<index_t> meanStrides,
const std::vector<index_t> invStdStrides,
const std::vector<index_t> dxStrides,
const std::vector<index_t> reduceDims,
const void* p_dy,
const void* p_x,
const void* p_gamma,
const void* p_mean,
const void* p_invStd,
void* p_dx) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename DXDataType,
index_t Rank,
index_t NumReduceDim>
using DeviceNormalizationBwdDataPtr = std::unique_ptr<DeviceNormalizationBwdData<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DXDataType,
Rank,
NumReduceDim>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment