"git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "9d30f9f8b3836e8d617eadf63a71d8363ff56c7e"
Commit aa5859e4 authored by Chao Liu

Merge remote-tracking branch 'origin/develop' into wavelet_model

parents 9bd6cc0e 5ee30459
@@ -9,9 +9,9 @@
 #include "ck/tensor_operation/gpu/device/device_binary_elementwise.hpp"
 #include "ck/library/utility/check_err.hpp"
-#include "ck/library/host_tensor/device_memory.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
-#include "ck/library/host_tensor/host_tensor_generator.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
 using F16 = ck::half_t;
 using F32 = float;
@@ -92,9 +92,9 @@ int main()
 a_m_n.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
 b_n.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
-DeviceMem a_m_n_device_buf(sizeof(ABDataType) * a_m_n.mDesc.GetElementSpace());
-DeviceMem b_n_device_buf(sizeof(ABDataType) * b_n.mDesc.GetElementSpace());
-DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace());
+DeviceMem a_m_n_device_buf(sizeof(ABDataType) * a_m_n.mDesc.GetElementSpaceSize());
+DeviceMem b_n_device_buf(sizeof(ABDataType) * b_n.mDesc.GetElementSpaceSize());
+DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpaceSize());
 a_m_n_device_buf.ToDevice(a_m_n.mData.data());
 b_n_device_buf.ToDevice(b_n.mData.data());
...
@@ -9,9 +9,9 @@
 #include "ck/tensor_operation/gpu/device/device_binary_elementwise.hpp"
 #include "ck/library/utility/check_err.hpp"
-#include "ck/library/host_tensor/device_memory.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
-#include "ck/library/host_tensor/host_tensor_generator.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
 using F16 = ck::half_t;
 using F32 = float;
@@ -74,9 +74,9 @@ int main()
 a_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
 b_m_n_k.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
-DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpace());
-DeviceMem b_m_n_k_device_buf(sizeof(ABDataType) * b_m_n_k.mDesc.GetElementSpace());
-DeviceMem c_m_n_k_device_buf(sizeof(CDataType) * c_m_n_k.mDesc.GetElementSpace());
+DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpaceSize());
+DeviceMem b_m_n_k_device_buf(sizeof(ABDataType) * b_m_n_k.mDesc.GetElementSpaceSize());
+DeviceMem c_m_n_k_device_buf(sizeof(CDataType) * c_m_n_k.mDesc.GetElementSpaceSize());
 a_m_device_buf.ToDevice(a_m.mData.data());
 b_m_n_k_device_buf.ToDevice(b_m_n_k.mData.data());
...
@@ -8,9 +8,9 @@
 #include "ck/tensor_operation/gpu/device/device_binary_elementwise.hpp"
 #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
-#include "ck/library/host_tensor/device_memory.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
-#include "ck/library/host_tensor/host_tensor_generator.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
 using F16 = ck::half_t;
 using F32 = float;
@@ -72,9 +72,9 @@ int main()
 a_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
 b_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
-DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpace());
-DeviceMem b_m_device_buf(sizeof(ABDataType) * b_m.mDesc.GetElementSpace());
-DeviceMem c_m_device_buf(sizeof(CDataType) * c_m.mDesc.GetElementSpace());
+DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpaceSize());
+DeviceMem b_m_device_buf(sizeof(ABDataType) * b_m.mDesc.GetElementSpaceSize());
+DeviceMem c_m_device_buf(sizeof(CDataType) * c_m.mDesc.GetElementSpaceSize());
 a_m_device_buf.ToDevice(a_m.mData.data());
 b_m_device_buf.ToDevice(b_m.mData.data());
...
@@ -9,9 +9,9 @@
 #include "ck/tensor_operation/gpu/device/device_binary_elementwise.hpp"
 #include "ck/library/utility/check_err.hpp"
-#include "ck/library/host_tensor/device_memory.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
-#include "ck/library/host_tensor/host_tensor_generator.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
 using F16 = ck::half_t;
 using F32 = float;
@@ -74,9 +74,9 @@ int main()
 a.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
 b.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
-DeviceMem a_device_buf(sizeof(ABDataType) * a.mDesc.GetElementSpace());
-DeviceMem b_device_buf(sizeof(ABDataType) * b.mDesc.GetElementSpace());
-DeviceMem c_device_buf(sizeof(CDataType) * c.mDesc.GetElementSpace());
+DeviceMem a_device_buf(sizeof(ABDataType) * a.mDesc.GetElementSpaceSize());
+DeviceMem b_device_buf(sizeof(ABDataType) * b.mDesc.GetElementSpaceSize());
+DeviceMem c_device_buf(sizeof(CDataType) * c.mDesc.GetElementSpaceSize());
 a_device_buf.ToDevice(a.mData.data());
 b_device_buf.ToDevice(b.mData.data());
...
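The four hunks above make the same two mechanical changes in each binary-elementwise example: the host utility headers move from ck/library/host_tensor/ to ck/library/utility/, and HostTensorDescriptor::GetElementSpace() is renamed to GetElementSpaceSize(). A minimal sketch of the updated allocation pattern (illustrative only; the tensor name, element type, and M/N sizes are assumptions, not taken from the diff):

#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"

// hypothetical M x N host tensor
Tensor<float> a_m_n(HostTensorDescriptor(std::vector<std::size_t>{M, N}));
// element space (in elements) is now queried via GetElementSpaceSize()
DeviceMem a_m_n_device_buf(sizeof(float) * a_m_n.mDesc.GetElementSpaceSize());
a_m_n_device_buf.ToDevice(a_m_n.mData.data());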
add_example_executable(example_convnd_bwd_weight_xdl_fp16 convnd_bwd_weight_xdl_fp16.cpp)
add_example_executable(example_convnd_bwd_weight_xdl_bf16 convnd_bwd_weight_xdl_bf16.cpp)
target_link_libraries(example_convnd_bwd_weight_xdl_fp16 PRIVATE utility)
target_link_libraries(example_convnd_bwd_weight_xdl_bf16 PRIVATE utility)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp"
void print_helper_msg()
{
std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=no, 1=yes)\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
}
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementOp,
typename WeiElementOp,
typename OutElementOp,
typename DeviceConvBwdWeightInstance>
int run_conv_bwd_weight(bool do_verification,
int init_method,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
const HostTensorDescriptor& in_g_n_c_wis_desc,
const HostTensorDescriptor& wei_g_k_c_xs_desc,
const HostTensorDescriptor& out_g_n_k_wos_desc,
const InElementOp& in_element_op,
const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op,
ck::index_t split_k)
{
Tensor<InDataType> in(in_g_n_c_wis_desc);
Tensor<WeiDataType> wei_host_result(wei_g_k_c_xs_desc);
Tensor<WeiDataType> wei_device_result(wei_g_k_c_xs_desc);
Tensor<OutDataType> out(out_g_n_k_wos_desc);
std::cout << "in: " << in.mDesc << std::endl;
std::cout << "wei: " << wei_host_result.mDesc << std::endl;
std::cout << "out: " << out.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
out.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
break;
default:
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5});
}
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_device_result.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(in.mData.data());
out_device_buf.ToDevice(out.mData.data());
// init to 0
wei_device_buf.SetZero();
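// (with split_k > 1 the kernel accumulates partial weight results into this buffer,
// so it must start from zero)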
// do GEMM
auto conv = DeviceConvBwdWeightInstance{};
auto invoker = conv.MakeInvoker();
auto argument = conv.MakeArgument(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
conv_param.N_,
conv_param.K_,
conv_param.C_,
conv_param.input_spatial_lengths_,
conv_param.filter_spatial_lengths_,
conv_param.output_spatial_lengths_,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
in_element_op,
wei_element_op,
out_element_op,
split_k);
if(!conv.IsSupportedArgument(argument))
{
std::cout << "wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
<< std::endl;
return 1;
}
float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = conv_param.GetFlops();
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();
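// avg_time is in ms, so flop / 1e9 / ms = TFLOP/s and bytes / 1e6 / ms = GB/s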
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< conv.GetTypeString() << std::endl;
if(do_verification)
{
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>{};
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in,
wei_host_result,
out,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
InElementOp{},
WeiElementOp{},
OutElementOp{});
ref_invoker.Run(ref_argument);
wei_device_buf.FromDevice(wei_device_result.mData.data());
return ck::utils::check_err(wei_device_result.mData, wei_host_result.mData) ? 0 : 1;
}
return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_bwd_weight_common.hpp"
#include "ck/tensor_operation/gpu/device/device_convnd_bwd_weight_nwc_kxc_nwk_xdl_cshuffle.hpp"
using InDataType = ck::bhalf_t;
// bf16 kernel use fp32 atomic add to accumulate Weight tensor into global memory
using WeiDataType = float;
using OutDataType = ck::bhalf_t;
using AccDataType = float;
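// (a separate pass can convert the fp32 weights back to bf16 afterwards; the bf16
// split-k example further below defines a UnaryTypeConvert elementwise kernel for this)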
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvBwdWeightDefault =
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
template <ck::index_t NDimSpatial>
using DeviceConvndBwdWeightInstance =
ck::tensor_operation::device::DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle<
NDimSpatial, // NDimSpatial
InDataType, // InDataType
WeiDataType, // WeiDataType
OutDataType, // OutDataType
AccDataType, // AccDataType
InElementOp, // InElementwiseOperation
WeiElementOp, // WeiElementwiseOperation
OutElementOp, // OutElementwiseOperation
ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
4, // K0PerBlock
8, // K1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
2, // NXdlPerWave
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
2, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM
S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
2, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
4>; // CBlockTransferScalarPerVector_NWaveNPerXdl
int main(int argc, char* argv[])
{
namespace ctc = ck::tensor_layout::convolution;
print_helper_msg();
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
ck::utils::conv::ConvParam conv_param{
2, 1, 32, 256, 1024, {3, 3}, {14, 14}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
ck::index_t split_k = 4;
if(argc == 1)
{
// use default
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv);
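// split_k is read from the final slot: argv[5 + 3 + 6 * num_dim_spatial - 1]
// = argv[7 + 6 * num_dim_spatial], i.e. argv[19] for a 2-D problem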
split_k = std::stoi(argv[5 + 3 + 6 * num_dim_spatial - 1]);
split_k = std::max(1, split_k);
}
const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{};
if(conv_param.num_dim_spatial_ == 1)
{
using InLayout = ctc::GNWC;
using WeiLayout = ctc::GKXC;
using OutLayout = ctc::GNWK;
const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
conv_param);
const auto wei_g_k_c_xs_desc =
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
conv_param);
const auto out_g_n_k_wos_desc =
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
conv_param);
return run_conv_bwd_weight<1,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvndBwdWeightInstance<1>>(do_verification,
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op,
out_element_op,
split_k);
}
else if(conv_param.num_dim_spatial_ == 2)
{
using InLayout = ctc::GNHWC;
using WeiLayout = ctc::GKYXC;
using OutLayout = ctc::GNHWK;
const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
conv_param);
const auto wei_g_k_c_xs_desc =
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
conv_param);
const auto out_g_n_k_wos_desc =
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
conv_param);
return run_conv_bwd_weight<2,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvndBwdWeightInstance<2>>(do_verification,
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op,
out_element_op,
split_k);
}
else if(conv_param.num_dim_spatial_ == 3)
{
using InLayout = ctc::GNDHWC;
using WeiLayout = ctc::GKZYXC;
using OutLayout = ctc::GNDHWK;
const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
conv_param);
const auto wei_g_k_c_xs_desc =
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
conv_param);
const auto out_g_n_k_wos_desc =
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
conv_param);
return run_conv_bwd_weight<3,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvndBwdWeightInstance<3>>(do_verification,
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op,
out_element_op,
split_k);
}
return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_bwd_weight_common.hpp"
#include "ck/tensor_operation/gpu/device/device_convnd_bwd_weight_nwc_kxc_nwk_xdl_cshuffle.hpp"
using InDataType = ck::half_t;
using WeiDataType = ck::half_t;
using OutDataType = ck::half_t;
using AccDataType = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvBwdWeightDefault =
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
template <ck::index_t NDimSpatial>
using DeviceConvndBwdWeightInstance =
ck::tensor_operation::device::DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle<
NDimSpatial, // NDimSpatial
InDataType, // InDataType
WeiDataType, // WeiDataType
OutDataType, // OutDataType
AccDataType, // AccDataType
InElementOp, // InElementwiseOperation
WeiElementOp, // WeiElementwiseOperation
OutElementOp, // OutElementwiseOperation
ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
4, // K0PerBlock
8, // K1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
2, // NXdlPerWave
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
2, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM
S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
2, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
int main(int argc, char* argv[])
{
namespace ctc = ck::tensor_layout::convolution;
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
ck::utils::conv::ConvParam conv_param{
2, 1, 32, 256, 1024, {3, 3}, {14, 14}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
ck::index_t split_k = 4;
if(argc == 1)
{
// use default
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv);
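// split_k is read from the final slot: argv[5 + 3 + 6 * num_dim_spatial - 1]
// = argv[7 + 6 * num_dim_spatial], i.e. argv[19] for a 2-D problem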
split_k = std::stoi(argv[5 + 3 + 6 * num_dim_spatial - 1]);
split_k = std::max(1, split_k);
}
const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{};
if(conv_param.num_dim_spatial_ == 1)
{
using InLayout = ctc::GNWC;
using WeiLayout = ctc::GKXC;
using OutLayout = ctc::GNWK;
const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
conv_param);
const auto wei_g_k_c_xs_desc =
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
conv_param);
const auto out_g_n_k_wos_desc =
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
conv_param);
return run_conv_bwd_weight<1,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvndBwdWeightInstance<1>>(do_verification,
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op,
out_element_op,
split_k);
}
else if(conv_param.num_dim_spatial_ == 2)
{
using InLayout = ctc::GNHWC;
using WeiLayout = ctc::GKYXC;
using OutLayout = ctc::GNHWK;
const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
conv_param);
const auto wei_g_k_c_xs_desc =
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
conv_param);
const auto out_g_n_k_wos_desc =
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
conv_param);
return run_conv_bwd_weight<2,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvndBwdWeightInstance<2>>(do_verification,
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op,
out_element_op,
split_k);
}
else if(conv_param.num_dim_spatial_ == 3)
{
using InLayout = ctc::GNDHWC;
using WeiLayout = ctc::GKZYXC;
using OutLayout = ctc::GNDHWK;
const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
conv_param);
const auto wei_g_k_c_xs_desc =
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
conv_param);
const auto out_g_n_k_wos_desc =
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
conv_param);
return run_conv_bwd_weight<3,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvndBwdWeightInstance<3>>(do_verification,
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op,
out_element_op,
split_k);
}
return 0;
}
add_example_executable(example_convnd_bwd_weight_xdl convnd_bwd_weight_xdl.cpp)
add_example_executable(example_convnd_bwd_weight_xdl_bf16_splitk convnd_bwd_weight_xdl_bf16_splitk.cpp)
target_link_libraries(example_convnd_bwd_weight_xdl PRIVATE conv_util)
target_link_libraries(example_convnd_bwd_weight_xdl_bf16_splitk PRIVATE conv_util)
\ No newline at end of file
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/conv_util.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp"
using InDataType = ck::half_t;
using WeiDataType = ck::half_t;
using OutDataType = ck::half_t;
using AccDataType = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvBwdWeightDefault =
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
using DeviceConvBwdWeightBasePtr =
ck::tensor_operation::device::DeviceConvBwdWeightPtr<InElementOp, WeiElementOp, OutElementOp>;
// clang-format off
template <ck::index_t NumDimSpatial>
using DeviceConvndBwdWeightInstance = ck::tensor_operation::device::
DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
InDataType, // InDataType
WeiDataType, // WeiDataType
OutDataType, // OutDataType
AccDataType, // AccDataType
InElementOp, // InElementwiseOperation
WeiElementOp, // WeiElementwiseOperation
OutElementOp, // OutElementwiseOperation
ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
NumDimSpatial, // NumDimSpatial
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
4, // K0PerBlock
8, // K1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
2, // NXdlPerWave
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
2, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM
S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
2, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
template <ck::index_t NumDimSpatial>
using ReferenceConvBwdWeightInstance =
ck::tensor_operation::host::ReferenceConvBwdWeight<InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
NumDimSpatial>;
void print_use_msg()
{
std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n"
<< "arg3: time kernel (0=n0, 1=yes)\n"
<< "arg4: is show log (0=no, 1=yes)\n"
<< "arg5: split-k \n"
<< "arg6: N spatial dimensions (default 2)\n"
<< "Following arguments (depending on number of spatial dims):\n"
<< " N, K, C, \n"
<< " <filter spatial dimensions>, (ie Y, X for 2D)\n"
<< " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
<< " <strides>, (ie Sy, Sx for 2D)\n"
<< " <dilations>, (ie Dy, Dx for 2D)\n"
<< " <left padding>, (ie LeftPy, LeftPx for 2D)\n"
<< " <right padding>, (ie RightPy, RightPx for 2D)\n"
<< std::endl;
}
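// Hypothetical 2-D invocation consistent with parse_conv_params below (all values illustrative):
//   ./example_convnd_bwd_weight_xdl 1 1 0 0 1 2  128 256 192  3 3  71 71  2 2  1 1  1 1  1 1
// i.e. verify, init, time-kernel, log, split-k, ndims, then N K C, Y X, Hi Wi, strides, dilations, left pads, right pads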
ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[])
{
// (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right)
ck::utils::conv::ConvParams params;
int arg_idx = 7;
params.num_dim_spatial_ = num_dim_spatial;
params.N_ = std::stoi(argv[arg_idx++]);
params.K_ = std::stoi(argv[arg_idx++]);
params.C_ = std::stoi(argv[arg_idx++]);
params.filter_spatial_lengths_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
}
params.input_spatial_lengths_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
}
params.conv_filter_strides_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]);
}
params.conv_filter_dilations_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]);
}
params.input_left_pads_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
params.input_left_pads_[i] = std::stoi(argv[arg_idx++]);
}
params.input_right_pads_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
params.input_right_pads_[i] = std::stoi(argv[arg_idx++]);
}
return params;
}
DeviceConvBwdWeightBasePtr get_conv_instance(int num_dim_spatial)
{
switch(num_dim_spatial)
{
case 3: {
return std::make_unique<DeviceConvndBwdWeightInstance<3>>();
}
case 2: {
return std::make_unique<DeviceConvndBwdWeightInstance<2>>();
}
case 1: {
return std::make_unique<DeviceConvndBwdWeightInstance<1>>();
}
default: {
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
}
}
}
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
int num_dim_spatial = 2;
int do_log = 0;
int split_k = 1;
ck::utils::conv::ConvParams params;
params.C_ = 128;
if(argc == 6)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
do_log = std::stoi(argv[4]);
split_k = std::stoi(argv[5]);
}
else if(argc > 6)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
do_log = std::stoi(argv[4]);
split_k = std::stoi(argv[5]);
num_dim_spatial = std::stoi(argv[6]);
// check args number
int conv_args = 3 + num_dim_spatial * 6;
int cmdline_nargs = conv_args + 7;
if(cmdline_nargs != argc)
{
print_use_msg();
exit(1);
}
params = parse_conv_params(num_dim_spatial, argv);
}
else if(argc != 1)
{
print_use_msg();
exit(1);
}
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N_),
static_cast<std::size_t>(params.C_)};
input_dims.insert(std::end(input_dims),
std::begin(params.input_spatial_lengths_),
std::end(params.input_spatial_lengths_));
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K_),
static_cast<std::size_t>(params.C_)};
filter_dims.insert(std::end(filter_dims),
std::begin(params.filter_spatial_lengths_),
std::end(params.filter_spatial_lengths_));
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N_),
static_cast<std::size_t>(params.K_)};
output_dims.insert(std::end(output_dims),
std::begin(output_spatial_lengths),
std::end(output_spatial_lengths));
Tensor<InDataType> in_n_c_hi_wi(
ck::utils::conv::get_input_host_tensor_descriptor(input_dims, num_dim_spatial));
Tensor<WeiDataType> wei_k_c_y_x_host_result(
ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial));
Tensor<WeiDataType> wei_k_c_y_x_device_result(
ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial));
Tensor<OutDataType> out_n_k_ho_wo(
ck::utils::conv::get_output_host_tensor_descriptor(output_dims, num_dim_spatial));
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl;
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_device_result.mDesc << std::endl;
std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl;
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl;
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_host_result.mDesc << std::endl;
std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 2});
break;
default:
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{1});
}
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) *
wei_k_c_y_x_device_result.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace());
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
// reset weight buffer to zero
wei_device_buf.SetZero();
// do GEMM
auto conv = get_conv_instance(num_dim_spatial);
auto invoker = conv->MakeInvokerPointer();
auto argument =
conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
params.N_,
params.K_,
params.C_,
params.input_spatial_lengths_,
params.filter_spatial_lengths_,
output_spatial_lengths,
params.conv_filter_strides_,
params.conv_filter_dilations_,
params.input_left_pads_,
params.input_right_pads_,
InElementOp{},
WeiElementOp{},
OutElementOp{},
split_k);
float ave_time = 0.f;
if(!conv->IsSupportedArgument(argument.get()))
{
std::cout << "wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
<< std::endl;
return 1;
}
ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = ck::utils::conv::get_flops(
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
std::size_t num_btype = ck::utils::conv::get_btype<InDataType, WeiDataType, OutDataType>(
params.N_,
params.C_,
params.K_,
params.input_spatial_lengths_,
params.filter_spatial_lengths_,
output_spatial_lengths);
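// ave_time is in ms, so flop / 1e9 / ms = TFLOP/s and bytes / 1e6 / ms = GB/s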
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
if(do_verification)
{
auto verify_f = [&](const auto& ref_conv) {
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi,
wei_k_c_y_x_host_result,
out_n_k_ho_wo,
params.conv_filter_strides_,
params.conv_filter_dilations_,
params.input_left_pads_,
params.input_right_pads_,
InElementOp{},
WeiElementOp{},
OutElementOp{});
ref_invoker.Run(ref_argument);
wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data());
if(do_log)
{
LogRangeAsType<float>(std::cout << "out: ", out_n_k_ho_wo.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "in : ", in_n_c_hi_wi.mData, ",") << std::endl;
LogRangeAsType<float>(
std::cout << "wei_device(after): ", wei_k_c_y_x_device_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",")
<< std::endl;
}
return ck::utils::check_err(wei_k_c_y_x_device_result.mData,
wei_k_c_y_x_host_result.mData)
? 0
: 1;
};
switch(num_dim_spatial)
{
case 3: {
auto ref_conv = ReferenceConvBwdWeightInstance<3>();
return verify_f(ref_conv);
}
case 2: {
auto ref_conv = ReferenceConvBwdWeightInstance<2>();
return verify_f(ref_conv);
}
case 1: {
auto ref_conv = ReferenceConvBwdWeightInstance<1>();
return verify_f(ref_conv);
}
default: {
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
}
}
}
return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp"
#include "ck/tensor_operation/gpu/device/device_unary_elementwise.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/conv_util.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp"
using InDataType = ck::bhalf_t;
using WeiDataType = ck::bhalf_t;
using OutDataType = ck::bhalf_t;
using AccDataType = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
using UnaryTypeConvert = ck::tensor_operation::element_wise::UnaryTypeConvert<ck::bhalf_t, float>;
using DeviceUnaryElementwiseTypeConvertInstance = ck::tensor_operation::device::
DeviceUnaryElementwise<AccDataType, WeiDataType, UnaryTypeConvert, 1, 4>;
static constexpr auto ConvBwdWeightDefault =
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
using DeviceConvBwdWeightBasePtr =
ck::tensor_operation::device::DeviceConvBwdWeightPtr<InElementOp, WeiElementOp, OutElementOp>;
// clang-format off
template <ck::index_t NumDimSpatial>
using DeviceConvndBwdWeightInstance_bf16_splitk = ck::tensor_operation::device::
DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
InDataType, // InDataType
AccDataType, // WeiDataType
OutDataType, // OutDataType
AccDataType, // AccDataType
InElementOp, // InElementwiseOperation
WeiElementOp, // WeiElementwiseOperation
OutElementOp, // OutElementwiseOperation
ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
NumDimSpatial, // NumDimSpatial
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
4, // K0PerBlock
8, // K1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
2, // NXdlPerWave
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
2, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM
S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
2, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
4>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
template <ck::index_t NumDimSpatial>
using ReferenceConvBwdWeightInstance =
ck::tensor_operation::host::ReferenceConvBwdWeight<InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
NumDimSpatial>;
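// host-side helper that applies a value-returning elementwise functor to every element
// of A and stores the result in B, e.g. to convert fp32 split-k results to bf16 on the host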
template <typename HostTensorB, typename HostTensorA, typename Functor>
void host_elementwise(HostTensorB& B,
const HostTensorA& A,
const std::vector<std::size_t>& shape,
Functor functor)
{
size_t tensor_size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>{});
std::cout << __LINE__ << ":" << tensor_size << ", " << A.mData[0] << std::endl;
for(std::size_t n = 0; n < tensor_size; ++n)
{
B.mData[n] = functor(A.mData[n]);
}
}
void print_use_msg()
{
std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n"
<< "arg3: time kernel (0=n0, 1=yes)\n"
<< "arg4: is show log (0=no, 1=yes)\n"
<< "arg5: split-k : in this example split-k must be larger than 1\n"
<< "arg6: N spatial dimensions (default 2)\n"
<< "Following arguments (depending on number of spatial dims):\n"
<< " N, K, C, \n"
<< " <filter spatial dimensions>, (ie Y, X for 2D)\n"
<< " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
<< " <strides>, (ie Sy, Sx for 2D)\n"
<< " <dilations>, (ie Dy, Dx for 2D)\n"
<< " <left padding>, (ie LeftPy, LeftPx for 2D)\n"
<< " <right padding>, (ie RightPy, RightPx for 2D)\n"
<< std::endl;
}
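// Hypothetical 2-D invocation consistent with parse_conv_params below (values illustrative;
// note that split-k must be > 1 in this example):
//   ./example_convnd_bwd_weight_xdl_bf16_splitk 1 1 0 0 2 2  128 256 192  3 3  71 71  2 2  1 1  1 1  1 1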
ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[])
{
// (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right)
ck::utils::conv::ConvParams params;
int arg_idx = 7;
params.num_dim_spatial_ = num_dim_spatial;
params.N_ = std::stoi(argv[arg_idx++]);
params.K_ = std::stoi(argv[arg_idx++]);
params.C_ = std::stoi(argv[arg_idx++]);
params.filter_spatial_lengths_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
}
params.input_spatial_lengths_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
}
params.conv_filter_strides_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]);
}
params.conv_filter_dilations_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]);
}
params.input_left_pads_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
params.input_left_pads_[i] = std::stoi(argv[arg_idx++]);
}
params.input_right_pads_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
params.input_right_pads_[i] = std::stoi(argv[arg_idx++]);
}
return params;
}
DeviceConvBwdWeightBasePtr get_conv_instance(int num_dim_spatial)
{
switch(num_dim_spatial)
{
case 3: {
return std::make_unique<DeviceConvndBwdWeightInstance_bf16_splitk<3>>();
}
case 2: {
return std::make_unique<DeviceConvndBwdWeightInstance_bf16_splitk<2>>();
}
case 1: {
return std::make_unique<DeviceConvndBwdWeightInstance_bf16_splitk<1>>();
}
default: {
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
}
}
}
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
int num_dim_spatial = 2;
int do_log = 0;
int split_k = 2;
ck::utils::conv::ConvParams params;
params.C_ = 128;
if(argc == 6)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
do_log = std::stoi(argv[4]);
split_k = std::stoi(argv[5]);
}
else if(argc > 6)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
do_log = std::stoi(argv[4]);
split_k = std::stoi(argv[5]);
num_dim_spatial = std::stoi(argv[6]);
// check args number
int conv_args = 3 + num_dim_spatial * 6;
int cmdline_nargs = conv_args + 7;
if(cmdline_nargs != argc)
{
print_use_msg();
exit(1);
}
params = parse_conv_params(num_dim_spatial, argv);
}
else if(argc != 1)
{
print_use_msg();
exit(1);
}
if(split_k <= 1)
{
print_use_msg();
exit(1);
}
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N_),
static_cast<std::size_t>(params.C_)};
input_dims.insert(std::end(input_dims),
std::begin(params.input_spatial_lengths_),
std::end(params.input_spatial_lengths_));
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K_),
static_cast<std::size_t>(params.C_)};
filter_dims.insert(std::end(filter_dims),
std::begin(params.filter_spatial_lengths_),
std::end(params.filter_spatial_lengths_));
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N_),
static_cast<std::size_t>(params.K_)};
output_dims.insert(std::end(output_dims),
std::begin(output_spatial_lengths),
std::end(output_spatial_lengths));
Tensor<InDataType> in_n_c_hi_wi(
ck::utils::conv::get_input_host_tensor_descriptor(input_dims, num_dim_spatial));
Tensor<WeiDataType> wei_k_c_y_x_host_result(
ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial));
Tensor<WeiDataType> wei_k_c_y_x_device_result(
ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial));
Tensor<OutDataType> out_n_k_ho_wo(
ck::utils::conv::get_output_host_tensor_descriptor(output_dims, num_dim_spatial));
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl;
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_device_result.mDesc << std::endl;
std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl;
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl;
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_host_result.mDesc << std::endl;
std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 2});
break;
default:
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{1});
}
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) *
wei_k_c_y_x_device_result.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace());
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
// reset weight buffer to zero
wei_device_buf.SetZero();
// do GEMM
auto conv = get_conv_instance(num_dim_spatial);
auto invoker = conv->MakeInvokerPointer();
auto argument =
conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
params.N_,
params.K_,
params.C_,
params.input_spatial_lengths_,
params.filter_spatial_lengths_,
output_spatial_lengths,
params.conv_filter_strides_,
params.conv_filter_dilations_,
params.input_left_pads_,
params.input_right_pads_,
InElementOp{},
WeiElementOp{},
OutElementOp{},
split_k);
// alloc work space
size_t bwd_weight_workspace_size = conv->GetWorkSpaceSize(argument.get());
if(bwd_weight_workspace_size <= 0)
{
print_use_msg();
exit(1);
}
float conv_ave_time = 0.f;
DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size);
wei_work_space_device_buf.SetZero();
conv->SetWorkSpacePointer(argument.get(), wei_work_space_device_buf.GetDeviceBuffer());
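// the workspace receives the fp32 partial weight results (the instance above is built with
// AccDataType as its weight type); a type-convert kernel such as the
// DeviceUnaryElementwiseTypeConvertInstance defined above can then turn them into bf16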
if(!conv->IsSupportedArgument(argument.get()))
{
std::cout << "wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
<< std::endl;
return 1;
}
conv_ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = ck::utils::conv::get_flops(
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
std::size_t num_btype = ck::utils::conv::get_btype<InDataType, WeiDataType, OutDataType>(
params.N_,
params.C_,
params.K_,
params.input_spatial_lengths_,
params.filter_spatial_lengths_,
output_spatial_lengths);
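// conv_ave_time is in ms, so flop / 1e9 / ms = TFLOP/s and bytes / 1e6 / ms = GB/s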
float tflops = static_cast<float>(flop) / 1.E9 / conv_ave_time;
float gb_per_sec = num_btype / 1.E6 / conv_ave_time;
std::cout << "Perf: conv: " << conv_ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s" << std::endl;
if(do_verification)
{
auto verify_f = [&](const auto& ref_conv) {
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi,
wei_k_c_y_x_host_result,
out_n_k_ho_wo,
params.conv_filter_strides_,
params.conv_filter_dilations_,
params.input_left_pads_,
params.input_right_pads_,
InElementOp{},
WeiElementOp{},
OutElementOp{});
ref_invoker.Run(ref_argument);
wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data());
if(do_log)
{
LogRangeAsType<float>(std::cout << "out: ", out_n_k_ho_wo.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "in : ", in_n_c_hi_wi.mData, ",") << std::endl;
LogRangeAsType<float>(
std::cout << "wei_device(after): ", wei_k_c_y_x_device_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",")
<< std::endl;
}
return ck::utils::check_err(wei_k_c_y_x_device_result.mData,
wei_k_c_y_x_host_result.mData)
? 0
: 1;
};
switch(num_dim_spatial)
{
case 3: {
auto ref_conv = ReferenceConvBwdWeightInstance<3>();
verify_f(ref_conv);
break;
}
case 2: {
auto ref_conv = ReferenceConvBwdWeightInstance<2>();
verify_f(ref_conv);
break;
}
case 1: {
auto ref_conv = ReferenceConvBwdWeightInstance<1>();
verify_f(ref_conv);
break;
}
default: {
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
}
}
}
return 0;
}
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_fp16 gemm_bias_relu_add_layernorm_xdl_fp16.cpp)
add_example_executable(example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp)
add_example_executable(example_gemm_xdl_layernorm_single_kernel_fp16 gemm_xdl_layernorm_single_kernel_fp16.cpp)
...@@ -9,13 +9,13 @@ ...@@ -9,13 +9,13 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp" #include "ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/host_tensor/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
...@@ -28,57 +28,64 @@ using F32 = float; ...@@ -28,57 +28,64 @@ using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
// DataType
using ADataType = F16; using ADataType = F16;
using BDataType = F16; using BDataType = F16;
using CDataType = F16;
using BiasDataType = F32;
using D0DataType = F16;
using GemmAccDataType = F32; using GemmAccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F16;
using D1DataType = F16;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = F16;
using ReduceAccDataType = F32; using ReduceAccDataType = F32;
using ReduceDataType = F32; using R0DataType = F32;
using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>; using R1DataType = F32;
using RsDataType = ck::Tuple<R0DataType, R1DataType>;
using GammaDataType = F16; using GammaDataType = F16;
using BetaDataType = F16; using BetaDataType = F16;
using LayerNormOutDataType = F16; using LayerNormOutDataType = F16;
using NormalizeComputeDataType = F32; using NormalizeComputeDataType = F32;
using ALayout = ck::tensor_layout::gemm::RowMajor; // Layout
using BLayout = ck::tensor_layout::gemm::ColumnMajor; using ALayout = Row;
using CLayout = ck::tensor_layout::gemm::RowMajor; using BLayout = Col;
using D1Layout = Row;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ELayout = D1Layout;
using AElementOp = PassThrough;
using BElementOp = PassThrough; // Elementwise op
using CElementOp = ck::tensor_operation::element_wise::Relu; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using D0ElementOp = PassThrough; using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
using ReduceSumOp = ck::reduce::Add; using Square = ck::tensor_operation::element_wise::UnarySquare;
using ReduceOps = ck::Tuple<ReduceSumOp, ReduceSumOp>; using Div = ck::tensor_operation::element_wise::UnaryDivide;
using AElementOp = PassThrough;
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = PassThrough;
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide; using CDEElementOp = AddReluAdd;
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; using QsElementOp = ck::Tuple<PassThrough, Square>;
using ReduceInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>; using RsElementOp = ck::Tuple<Div, Div>;
using ReduceOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
// ReduceOp
using ReduceGlobalMemOps = using R0ThreadReduceOp = ck::reduce::Add;
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd, using R1ThreadReduceOp = ck::reduce::Add;
ck::InMemoryDataOperationEnum::AtomicAdd>; using RsThreadReduceOp = ck::Tuple<R0ThreadReduceOp, R1ThreadReduceOp>;
static constexpr auto GemmSpecialization = static constexpr auto R0GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd;
ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto R1GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd;
using RsGlobalReduceOp = ck::InMemoryDataOperationEnumSequence<R0GlobalReduceOp, R1GlobalReduceOp>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off // clang-format off
using DeviceGemmBiasAddReduceInstance = ck::tensor_operation::device::DeviceGemmBiasAddReduce_Xdl_CShuffle using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| C1| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| //######| ALayout| BLayout| ELayout| AData| BData| GemmAccData| CShuffle| DsData| EData| ReduceAccData| RsData| A| B| CDE| Qs| Rs| Thread| Global| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDRThreadTransfer| CDE| RThreadTransfer|
//######| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| //######| | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ReduceThreadTransfer| DstScalarPerVector|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //######| | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock| |
< Row, Col, Row, F16, F16, F16, F32, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, D0ElementOp, ReduceOps,ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; < ALayout, BLayout, ELayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>;
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType, BDataType,
CDataType, EDataType,
GemmAccDataType, GemmAccDataType,
AElementOp, AElementOp,
BElementOp, BElementOp,
...@@ -88,9 +95,9 @@ using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize; ...@@ -88,9 +95,9 @@ using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize;
// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y // A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y
using DeviceNormalizeInstance = using DeviceNormalizeInstance =
ck::tensor_operation::device::Device5AryElementwise<CDataType, ck::tensor_operation::device::Device5AryElementwise<EDataType,
ReduceDataType, R0DataType,
ReduceDataType, R1DataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
LayerNormOutDataType, LayerNormOutDataType,
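For reference, Normalize combines those five inputs with the standard layernorm formula (a sketch; the epsilon is the functor's internal stabilizer, not an argument visible here):

    y(m,n) = \gamma(n) \cdot \frac{x(m,n) - E[x](m)}{\sqrt{E[x^2](m) - E[x](m)^2 + \epsilon}} + \beta(n)

using the identity Var[x] = E[x^2] - E[x]^2 for the per-row variance.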
...@@ -124,41 +131,31 @@ auto f_host_tensor_descriptor2d = ...@@ -124,41 +131,31 @@ auto f_host_tensor_descriptor2d =
} }
}; };
template <typename CDataType,
typename ReduceDataType,
typename AccDataType,
typename BiasDataType,
typename D0DataType,
typename A_functor,
typename B_functor,
typename C_functor,
typename C1_functor>
void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n, void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
const Tensor<ADataType>& a_m_k, const Tensor<ADataType>& a_m_k,
const Tensor<ADataType>& b_k_n, const Tensor<BDataType>& b_k_n,
const Tensor<BiasDataType>& bias_n, const Tensor<D0DataType>& bias_n,
const Tensor<D0DataType>& c1_m_n, const Tensor<D1DataType>& d1_m_n,
const Tensor<GammaDataType>& gamma_n, const Tensor<GammaDataType>& gamma_n,
const Tensor<GammaDataType>& beta_n, const Tensor<BetaDataType>& beta_n,
A_functor a_element_op, AElementOp a_element_op,
B_functor b_element_op, BElementOp b_element_op,
C_functor c_element_op, CDEElementOp cde_element_op,
C1_functor c1_element_op,
int M, int M,
int N) int N)
{ {
int StrideC = N; int StrideE = N;
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); Tensor<EDataType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
Tensor<ReduceDataType> mean_m(f_host_tensor_descriptor1d(M, 1)); Tensor<R0DataType> mean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<ReduceDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1)); Tensor<R1DataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
auto averageOpInst = UnaryDivElementOp{N}; auto averageOpInst = Div{N};
auto ref_gemm = ReferenceGemmInstance{}; auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); ref_gemm.MakeArgument(a_m_k, b_k_n, e_m_n, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
...@@ -166,38 +163,32 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n, ...@@ -166,38 +163,32 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
AccDataType acc = auto acc = ck::type_convert<GemmAccDataType>(e_m_n(m, n));
static_cast<AccDataType>(c_m_n(m, n)) + static_cast<AccDataType>(bias_n(n)); cde_element_op(e_m_n(m, n), acc, bias_n(n), d1_m_n(m, n));
AccDataType c1 = static_cast<AccDataType>(c1_m_n(m, n));
c_element_op(acc, acc);
c1_element_op(c1, c1);
acc += c1;
c_m_n(m, n) = static_cast<CDataType>(acc);
} }
// reduce_mean and reduce_square_mean // reduce_mean and reduce_square_mean
auto reduceSumOpInst = ReduceSumOp{}; auto r0Op = R0ThreadReduceOp{};
auto r1Op = R1ThreadReduceOp{};
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
auto mean_acc = reduceSumOpInst.GetIdentityValue<AccDataType>(); auto mean_acc = r0Op.GetIdentityValue<ReduceAccDataType>();
auto square_mean_acc = reduceSumOpInst.GetIdentityValue<AccDataType>(); auto mean_square_acc = r1Op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
AccDataType c_val = ck::type_convert<AccDataType>(c_m_n(m, n)); auto e_val = ck::type_convert<ReduceAccDataType>(e_m_n(m, n));
AccDataType square_c_val = 0; ReduceAccDataType square_e_val = 0;
UnarySquareElementOp{}(square_c_val, c_val); Square{}(square_e_val, e_val);
reduceSumOpInst(mean_acc, c_val); r0Op(mean_acc, e_val);
reduceSumOpInst(square_mean_acc, square_c_val); r1Op(mean_square_acc, square_e_val);
} }
averageOpInst(mean_acc, mean_acc); averageOpInst(mean_acc, mean_acc);
averageOpInst(square_mean_acc, square_mean_acc); averageOpInst(mean_square_acc, mean_square_acc);
mean_m(m) = ck::type_convert<ReduceDataType>(mean_acc); mean_m(m) = ck::type_convert<R0DataType>(mean_acc);
meanSquare_m(m) = ck::type_convert<ReduceDataType>(square_mean_acc); meanSquare_m(m) = ck::type_convert<R1DataType>(mean_square_acc);
} }
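Spelled out, the loop above computes the per-row moments that the device kernel's fused reductions produce:

    E[x](m) = \frac{1}{N} \sum_{n=0}^{N-1} e(m,n), \qquad E[x^2](m) = \frac{1}{N} \sum_{n=0}^{N-1} e(m,n)^2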
// LayerNorm // LayerNorm
...@@ -206,24 +197,25 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n, ...@@ -206,24 +197,25 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
{ {
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
AccDataType out_acc = 0; NormalizeComputeDataType out_acc = 0;
layerNormInst(out_acc, layerNormInst(out_acc,
static_cast<AccDataType>(c_m_n(m, n)), ck::type_convert<NormalizeComputeDataType>(e_m_n(m, n)),
static_cast<AccDataType>(mean_m(m)), ck::type_convert<NormalizeComputeDataType>(mean_m(m)),
static_cast<AccDataType>(meanSquare_m(m)), ck::type_convert<NormalizeComputeDataType>(meanSquare_m(m)),
static_cast<AccDataType>(gamma_n(n)), ck::type_convert<NormalizeComputeDataType>(gamma_n(n)),
static_cast<AccDataType>(beta_n(n))); ck::type_convert<NormalizeComputeDataType>(beta_n(n)));
out_m_n(m, n) = static_cast<ReduceDataType>(out_acc); out_m_n(m, n) = ck::type_convert<LayerNormOutDataType>(out_acc);
} }
} }
} }
template <typename ADataType, template <typename ADataType,
typename BDataType, typename BDataType,
typename CDataType, typename EDataType,
typename BiasDataType,
typename D0DataType, typename D0DataType,
typename ReduceDataType, typename D1DataType,
typename R0DataType,
typename R1DataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename NormalizeDataType> typename NormalizeDataType>
...@@ -231,12 +223,12 @@ void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M, ...@@ -231,12 +223,12 @@ void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M,
{ {
std::size_t gemm_flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N; std::size_t gemm_flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(BiasDataType) * M * N + sizeof(EDataType) * M * N + sizeof(D0DataType) * M * N +
sizeof(D0DataType) * M * N + sizeof(ReduceDataType) * M + sizeof(D0DataType) * M * N + sizeof(R0DataType) * M +
sizeof(ReduceDataType) * M; sizeof(R1DataType) * M;
std::size_t normalize_num_byte = sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M + std::size_t normalize_num_byte = sizeof(EDataType) * M * N + sizeof(R0DataType) * M +
sizeof(ReduceDataType) * M + sizeof(GammaDataType) * N + sizeof(R1DataType) * M + sizeof(GammaDataType) * N +
sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N; sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N;
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time; float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
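Note on units: gemm_reduce_time is in milliseconds, so dividing the flop count by 1.E9 and by that time yields TFLOPS directly,

    \frac{\text{flop}}{10^9 \cdot \text{ms}} = \frac{\text{flop}}{10^{12} \cdot \text{s}} = \text{TFLOPS}

and by the same reasoning bytes / 1.E6 / ms comes out in GB/s.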
...@@ -259,98 +251,90 @@ int main() ...@@ -259,98 +251,90 @@ int main()
ck::index_t StrideA = 1024; ck::index_t StrideA = 1024;
ck::index_t StrideB = 1024; ck::index_t StrideB = 1024;
ck::index_t StrideC = 1024; ck::index_t StrideD0 = 0;
ck::index_t StrideD0 = 1024; ck::index_t StrideD1 = 1024;
ck::index_t StrideE = 1024;
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); Tensor<D0DataType> bias_n(f_host_tensor_descriptor1d(N, 1));
Tensor<BiasDataType> bias_n(f_host_tensor_descriptor1d(N, 1)); Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor2d(M, N, StrideD1, ELayout{}));
Tensor<D0DataType> c1_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); Tensor<EDataType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
Tensor<ReduceDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1)); Tensor<R0DataType> r0_Mean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<ReduceDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1)); Tensor<R1DataType> r1_MeanSquare_m(f_host_tensor_descriptor1d(M, 1));
Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1)); Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1)); Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
Tensor<LayerNormOutDataType> layerNorm_m_n( Tensor<LayerNormOutDataType> layerNorm_m_n(
f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1}); a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1}); b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
bias_n.GenerateTensorValue(GeneratorTensor_3<BiasDataType>{-1, 1}); bias_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-1, 1});
c1_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-5, 5}); d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{-5, 5});
gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1}); gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1});
beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1}); beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace()); DeviceMem bias_device_buf(sizeof(D0DataType) * bias_n.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(BiasDataType) * bias_n.mDesc.GetElementSpace()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * c1_m_n.mDesc.GetElementSpace()); DeviceMem e_device_buf(sizeof(EDataType) * e_m_n.mDesc.GetElementSpaceSize());
DeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * reduceMean_m.mDesc.GetElementSpace()); DeviceMem r0_Mean_device_buf(sizeof(R0DataType) * r0_Mean_m.mDesc.GetElementSpaceSize());
DeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) * DeviceMem r1_MeanSquare_device_buf(sizeof(R1DataType) *
reduceMeanSquare_m.mDesc.GetElementSpace()); r1_MeanSquare_m.mDesc.GetElementSpaceSize());
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace()); DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpaceSize());
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace()); DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpaceSize());
DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) * DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) *
layerNorm_m_n.mDesc.GetElementSpace()); layerNorm_m_n.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
bias_device_buf.ToDevice(bias_n.mData.data()); bias_device_buf.ToDevice(bias_n.mData.data());
d0_device_buf.ToDevice(c1_m_n.mData.data()); d1_device_buf.ToDevice(d1_m_n.mData.data());
gamma_device_buf.ToDevice(gamma_n.mData.data()); gamma_device_buf.ToDevice(gamma_n.mData.data());
beta_device_buf.ToDevice(beta_n.mData.data()); beta_device_buf.ToDevice(beta_n.mData.data());
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{}; auto cde_element_op = CDEElementOp{};
auto d_element_op = D0ElementOp{}; auto qs_element_op = QsElementOp{};
std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op}; auto rs_element_op = RsElementOp{N, N};
auto passthrough = UnaryIdenticElementOp{};
auto square = UnarySquareElementOp{};
auto div = UnaryDivElementOp{N};
std::array<void*, 2> reduce_in_element_ops = {&passthrough, &square};
std::array<void*, 2> reduce_out_element_ops = {&div, &div};
std::array<void*, 2> p_reduces = {reduceMean_device_buf.GetDeviceBuffer(),
reduceMeanSquare_device_buf.GetDeviceBuffer()};
// Prepare GEMM, reduce_mean, reduce_mean_square // Prepare GEMM, mean, mean_square
auto gemmReduce = DeviceGemmBiasAddReduceInstance{}; auto gemmReduce = DeviceOpInstance{};
auto gemmReduce_invoker = gemmReduce.MakeInvoker(); auto gemmReduce_invoker = gemmReduce.MakeInvoker();
auto gemmReduce_argument = gemmReduce.MakeArgument(a_device_buf.GetDeviceBuffer(), auto gemmReduce_argument = gemmReduce.MakeArgument(
b_device_buf.GetDeviceBuffer(), a_device_buf.GetDeviceBuffer(),
bias_device_buf.GetDeviceBuffer(), b_device_buf.GetDeviceBuffer(),
{d0_device_buf.GetDeviceBuffer()}, {bias_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()},
c_device_buf.GetDeviceBuffer(), e_device_buf.GetDeviceBuffer(),
p_reduces, {r0_Mean_device_buf.GetDeviceBuffer(), r1_MeanSquare_device_buf.GetDeviceBuffer()},
M, M,
N, N,
K, K,
StrideA, StrideA,
StrideB, StrideB,
StrideC, {StrideD0, StrideD1},
{StrideD0}, StrideE,
gemm_element_ops, a_element_op,
{&d_element_op}, b_element_op,
reduce_in_element_ops, cde_element_op,
reduce_out_element_ops); qs_element_op,
rs_element_op);
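A hedged reading of how the two reduction streams are wired, per output row m: QsElementOp (PassThrough, Square) runs before the thread-level Add, and RsElementOp{N, N} applies Div{N} to the accumulated sums, so

    r0(m) = \frac{1}{N} \sum_n e(m,n) = E[x], \qquad r1(m) = \frac{1}{N} \sum_n e(m,n)^2 = E[x^2]

which matches what host_gemm_layernorm computes on the CPU.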
if(!gemmReduce.IsSupportedArgument(gemmReduce_argument)) if(!gemmReduce.IsSupportedArgument(gemmReduce_argument))
{ {
throw std::runtime_error( throw std::runtime_error("wrong! this device_op instance does not support this problem");
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
} }
reduceMean_device_buf.SetZero(); // init reduction buffers to 0
reduceMeanSquare_device_buf.SetZero(); r0_Mean_device_buf.SetZero();
r1_MeanSquare_device_buf.SetZero();
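The SetZero calls are not optional: RsGlobalReduceOp is AtomicAdd, so each workgroup accumulates its partial row reductions into the r0/r1 buffers, roughly like this (a sketch of the semantics with hypothetical names, not the kernel's actual code):

    // per workgroup, for each row m it covers:
    // atomicAdd(&r0_global[m], r0_partial_of_this_workgroup);
    // atomicAdd(&r1_global[m], r1_partial_of_this_workgroup);

The buffers must therefore hold the additive identity (zero) before every launch, or sums from a previous run leak into the result.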
// Prepare LayerNorm // Prepare LayerNorm
std::array<const void*, 5> input = {c_device_buf.GetDeviceBuffer(), std::array<const void*, 5> input = {e_device_buf.GetDeviceBuffer(),
reduceMean_device_buf.GetDeviceBuffer(), r0_Mean_device_buf.GetDeviceBuffer(),
reduceMeanSquare_device_buf.GetDeviceBuffer(), r1_MeanSquare_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(), gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer()}; beta_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {layerNorm_device_buf.GetDeviceBuffer()}; std::array<void*, 1> output = {layerNorm_device_buf.GetDeviceBuffer()};
...@@ -360,12 +344,12 @@ int main() ...@@ -360,12 +344,12 @@ int main()
auto normalize_argument = normalize.MakeArgument(input, auto normalize_argument = normalize.MakeArgument(input,
output, output,
{M, N}, {M, N},
{StrideC, 1}, {StrideE, 1},
{1, 0}, {1, 0},
{1, 0}, {1, 0},
{0, 1}, {0, 1},
{0, 1}, {0, 1},
{StrideC, 1}, {StrideE, 1},
NormalizeFunctor{}); NormalizeFunctor{});
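How the per-input stride pairs broadcast each operand over the {M, N} grid, reading the arguments in order (e, mean, meanSquare, gamma, beta); treat the mapping as an interpretation of Device5AryElementwise rather than a documented contract:

    // e          : {StrideE, 1} -> distinct value per (m, n)
    // mean       : {1, 0}       -> one value per row m, reused along n
    // meanSquare : {1, 0}       -> one value per row m, reused along n
    // gamma      : {0, 1}       -> one value per column n, reused along m
    // beta       : {0, 1}       -> one value per column n, reused along m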
if(!normalize.IsSupportedArgument(normalize_argument)) if(!normalize.IsSupportedArgument(normalize_argument))
...@@ -382,21 +366,20 @@ int main() ...@@ -382,21 +366,20 @@ int main()
{ {
// verification // verification
Tensor<LayerNormOutDataType> host_layerNorm_m_n( Tensor<LayerNormOutDataType> host_layerNorm_m_n(
f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
host_gemm_layernorm<CDataType, ReduceDataType, ReduceAccDataType>(host_layerNorm_m_n, host_gemm_layernorm(host_layerNorm_m_n,
a_m_k, a_m_k,
b_k_n, b_k_n,
bias_n, bias_n,
c1_m_n, d1_m_n,
gamma_n, gamma_n,
beta_n, beta_n,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, cde_element_op,
d_element_op, M,
M, N);
N);
layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data()); layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data());
pass &= ck::utils::check_err(layerNorm_m_n.mData, pass &= ck::utils::check_err(layerNorm_m_n.mData,
...@@ -418,10 +401,11 @@ int main() ...@@ -418,10 +401,11 @@ int main()
if(time_kernel) if(time_kernel)
DumpGemmLayerNormPerf<ADataType, DumpGemmLayerNormPerf<ADataType,
BDataType, BDataType,
CDataType, EDataType,
BiasDataType,
D0DataType, D0DataType,
ReduceDataType, D1DataType,
R0DataType,
R1DataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
LayerNormOutDataType>( LayerNormOutDataType>(
......
...@@ -9,13 +9,13 @@ ...@@ -9,13 +9,13 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp" #include "ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/host_tensor/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
...@@ -28,65 +28,73 @@ using F32 = float; ...@@ -28,65 +28,73 @@ using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
// DataType
using ADataType = F16; using ADataType = F16;
using BDataType = F16; using BDataType = F16;
using CDataType = F16;
using GemmAccDataType = F32; using GemmAccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = F16;
using ReduceAccDataType = F32; using ReduceAccDataType = F32;
using ReduceDataType = F32; using R0DataType = F32;
using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>; using R1DataType = F32;
using RsDataType = ck::Tuple<R0DataType, R1DataType>;
using GammaDataType = F16; using GammaDataType = F16;
using BetaDataType = F16; using BetaDataType = F16;
using LayerNormOutDataType = F16; using LayerNormOutDataType = F16;
using NormalizeComputeDataType = F32; using NormalizeComputeDataType = F32;
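DsDataType is the empty ck::Tuple<> in this variant: there are no auxiliary D tensors, which is why the MakeArgument call further down passes empty braces where the previous example passed bias/d1 pointers and strides:

    // {},   // no D tensor pointers
    // {},   // no D tensor strides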
using ALayout = ck::tensor_layout::gemm::RowMajor; // Layout
using BLayout = ck::tensor_layout::gemm::ColumnMajor; using ALayout = Row;
using CLayout = ck::tensor_layout::gemm::RowMajor; using BLayout = Col;
using D1Layout = Row;
using AElementOp = ck::tensor_operation::element_wise::PassThrough; using ELayout = D1Layout;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough; // Elementwise op
using ReduceSumOp = ck::reduce::Add; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceOps = ck::Tuple<ReduceSumOp, ReduceSumOp>; using Square = ck::tensor_operation::element_wise::UnarySquare;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; using AElementOp = PassThrough;
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide; using BElementOp = PassThrough;
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; using CDEElementOp = PassThrough;
using ReduceInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>; using QsElementOp = ck::Tuple<PassThrough, Square>;
using ReduceOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>; using RsElementOp = ck::Tuple<Div, Div>;
using ReduceGlobalMemOps = // ReduceOp
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd, using R0ThreadReduceOp = ck::reduce::Add;
ck::InMemoryDataOperationEnum::AtomicAdd>; using R1ThreadReduceOp = ck::reduce::Add;
using RsThreadReduceOp = ck::Tuple<R0ThreadReduceOp, R1ThreadReduceOp>;
static constexpr auto GemmSpecialization =
ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto R0GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd;
static constexpr auto R1GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd;
using RsGlobalReduceOp = ck::InMemoryDataOperationEnumSequence<R0GlobalReduceOp, R1GlobalReduceOp>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off // clang-format off
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| //######| ALayout| BLayout| ELayout| AData| BData| GemmAccData| CShuffle| DsData| EData| ReduceAccData| RsData| A| B| CDE| Qs| Rs| Thread| Global| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDRThreadTransfer| CDE| RThreadTransfer|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| //######| | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ReduceThreadTransfer| DstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //######| | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock| |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps,ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; < ALayout, BLayout, ELayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>;
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType, BDataType,
CDataType, EDataType,
GemmAccDataType, GemmAccDataType,
AElementOp, AElementOp,
BElementOp, BElementOp,
CElementOp>; PassThrough>;
using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize; using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize;
// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y // A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y
using DeviceNormalizeInstance = using DeviceNormalizeInstance =
ck::tensor_operation::device::Device5AryElementwise<CDataType, ck::tensor_operation::device::Device5AryElementwise<EDataType,
ReduceDataType, R0DataType,
ReduceDataType, R1DataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
LayerNormOutDataType, LayerNormOutDataType,
...@@ -120,60 +128,54 @@ auto f_host_tensor_descriptor2d = ...@@ -120,60 +128,54 @@ auto f_host_tensor_descriptor2d =
} }
}; };
template <typename CDataType,
typename ReduceDataType,
typename A_functor,
typename B_functor,
typename C_functor>
void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n, void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
const Tensor<ADataType>& a_m_k, const Tensor<ADataType>& a_m_k,
const Tensor<ADataType>& b_k_n, const Tensor<BDataType>& b_k_n,
const Tensor<GammaDataType>& gamma_n, const Tensor<GammaDataType>& gamma_n,
const Tensor<GammaDataType>& beta_n, const Tensor<BetaDataType>& beta_n,
A_functor a_element_op, AElementOp a_element_op,
B_functor b_element_op, BElementOp b_element_op,
C_functor c_element_op, CDEElementOp c_element_op,
int M, int M,
int N) int N)
{ {
using out_type = ck::remove_reference_t<decltype(out_m_n(0, 0))>;
int StrideC = N; int StrideE = N;
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); Tensor<EDataType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
Tensor<ReduceDataType> mean_m(f_host_tensor_descriptor1d(M, 1)); Tensor<R0DataType> mean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<ReduceDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1)); Tensor<R1DataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
auto averageOpInst = UnaryDivElementOp{N}; auto averageOpInst = Div{N};
auto ref_gemm = ReferenceGemmInstance{}; auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, c_element_op); ref_gemm.MakeArgument(a_m_k, b_k_n, e_m_n, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
// reduce_mean and reduce_square_mean // reduce_mean and reduce_square_mean
auto reduceSumOpInst = ReduceSumOp{}; auto r0Op = R0ThreadReduceOp{};
auto r1Op = R1ThreadReduceOp{};
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
auto mean_acc = reduceSumOpInst.GetIdentityValue<ReduceAccDataType>(); auto mean_acc = r0Op.GetIdentityValue<ReduceAccDataType>();
auto square_mean_acc = reduceSumOpInst.GetIdentityValue<ReduceAccDataType>(); auto mean_square_acc = r1Op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
auto c_val = ck::type_convert<ReduceAccDataType>(c_m_n(m, n)); auto e_val = ck::type_convert<ReduceAccDataType>(e_m_n(m, n));
auto square_c_val = reduceSumOpInst.GetIdentityValue<ReduceAccDataType>(); ReduceAccDataType square_e_val = 0;
Square{}(square_e_val, e_val);
UnarySquareElementOp{}(square_c_val, c_val); r0Op(mean_acc, e_val);
r1Op(mean_square_acc, square_e_val);
reduceSumOpInst(mean_acc, c_val);
reduceSumOpInst(square_mean_acc, square_c_val);
} }
averageOpInst(mean_acc, mean_acc); averageOpInst(mean_acc, mean_acc);
averageOpInst(square_mean_acc, square_mean_acc); averageOpInst(mean_square_acc, mean_square_acc);
mean_m(m) = ck::type_convert<ReduceDataType>(mean_acc); mean_m(m) = ck::type_convert<R0DataType>(mean_acc);
meanSquare_m(m) = ck::type_convert<ReduceDataType>(square_mean_acc); meanSquare_m(m) = ck::type_convert<R1DataType>(mean_square_acc);
} }
// LayerNorm // LayerNorm
...@@ -182,22 +184,23 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n, ...@@ -182,22 +184,23 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
{ {
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
float out_f32 = 0; NormalizeComputeDataType out_acc = 0;
layerNormInst(out_f32, layerNormInst(out_acc,
static_cast<float>(c_m_n(m, n)), ck::type_convert<NormalizeComputeDataType>(e_m_n(m, n)),
static_cast<float>(mean_m(m)), ck::type_convert<NormalizeComputeDataType>(mean_m(m)),
static_cast<float>(meanSquare_m(m)), ck::type_convert<NormalizeComputeDataType>(meanSquare_m(m)),
static_cast<float>(gamma_n(n)), ck::type_convert<NormalizeComputeDataType>(gamma_n(n)),
static_cast<float>(beta_n(n))); ck::type_convert<NormalizeComputeDataType>(beta_n(n)));
out_m_n(m, n) = static_cast<out_type>(out_f32); out_m_n(m, n) = ck::type_convert<LayerNormOutDataType>(out_acc);
} }
} }
} }
template <typename ADataType, template <typename ADataType,
typename BDataType, typename BDataType,
typename CDataType, typename EDataType,
typename ReduceDataType, typename R0DataType,
typename R1DataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename NormalizeDataType> typename NormalizeDataType>
...@@ -205,11 +208,11 @@ void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M, ...@@ -205,11 +208,11 @@ void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M,
{ {
std::size_t gemm_flop = std::size_t(2) * M * N * K; std::size_t gemm_flop = std::size_t(2) * M * N * K;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M + sizeof(EDataType) * M * N + sizeof(R0DataType) * M +
sizeof(ReduceDataType) * M; sizeof(R1DataType) * M;
std::size_t normalize_num_btye = sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M + std::size_t normalize_num_btye = sizeof(EDataType) * M * N + sizeof(R0DataType) * M +
sizeof(ReduceDataType) * M + sizeof(GammaDataType) * N + sizeof(R1DataType) * M + sizeof(GammaDataType) * N +
sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N; sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N;
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time; float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
...@@ -232,73 +235,66 @@ int main() ...@@ -232,73 +235,66 @@ int main()
ck::index_t StrideA = 1024; ck::index_t StrideA = 1024;
ck::index_t StrideB = 1024; ck::index_t StrideB = 1024;
ck::index_t StrideC = 1024; ck::index_t StrideE = 1024;
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); Tensor<EDataType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
Tensor<ReduceDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1)); Tensor<R0DataType> r0_Mean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<ReduceDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1)); Tensor<R1DataType> r1_MeanSquare_m(f_host_tensor_descriptor1d(M, 1));
Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1)); Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1)); Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
Tensor<LayerNormOutDataType> layerNorm_m_n( Tensor<LayerNormOutDataType> layerNorm_m_n(
f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1}); a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1}); b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1}); gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1});
beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1}); beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace()); DeviceMem e_device_buf(sizeof(EDataType) * e_m_n.mDesc.GetElementSpaceSize());
DeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * reduceMean_m.mDesc.GetElementSpace()); DeviceMem r0_Mean_device_buf(sizeof(R0DataType) * r0_Mean_m.mDesc.GetElementSpaceSize());
DeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) * DeviceMem r1_MeanSquare_device_buf(sizeof(R1DataType) *
reduceMeanSquare_m.mDesc.GetElementSpace()); r1_MeanSquare_m.mDesc.GetElementSpaceSize());
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace()); DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpaceSize());
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace()); DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpaceSize());
DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) * DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) *
layerNorm_m_n.mDesc.GetElementSpace()); layerNorm_m_n.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
gamma_device_buf.ToDevice(gamma_n.mData.data()); gamma_device_buf.ToDevice(gamma_n.mData.data());
beta_device_buf.ToDevice(beta_n.mData.data()); beta_device_buf.ToDevice(beta_n.mData.data());
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{}; auto cde_element_op = CDEElementOp{};
std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op}; auto qs_element_op = QsElementOp{};
auto rs_element_op = RsElementOp{N, N};
auto passthrough = UnaryIdenticElementOp{};
auto square = UnarySquareElementOp{};
auto div = UnaryDivElementOp{N};
std::array<void*, 2> reduce_in_element_ops = {&passthrough, &square};
std::array<void*, 2> reduce_out_element_ops = {&div, &div};
std::array<void*, 2> p_reduces = {reduceMean_device_buf.GetDeviceBuffer(),
reduceMeanSquare_device_buf.GetDeviceBuffer()};
// Prepare GEMM, reduce_mean, reduce_mean_square // Prepare GEMM, mean, mean_square
auto gemmReduce = DeviceGemmReduceInstance{}; auto gemmReduce = DeviceOpInstance{};
auto gemmReduce_invoker = gemmReduce.MakeInvoker(); auto gemmReduce_invoker = gemmReduce.MakeInvoker();
auto gemmReduce_argument = gemmReduce.MakeArgument(a_device_buf.GetDeviceBuffer(), auto gemmReduce_argument = gemmReduce.MakeArgument(
b_device_buf.GetDeviceBuffer(), a_device_buf.GetDeviceBuffer(),
nullptr, b_device_buf.GetDeviceBuffer(),
{}, {},
c_device_buf.GetDeviceBuffer(), e_device_buf.GetDeviceBuffer(),
p_reduces, {r0_Mean_device_buf.GetDeviceBuffer(), r1_MeanSquare_device_buf.GetDeviceBuffer()},
M, M,
N, N,
K, K,
StrideA, StrideA,
StrideB, StrideB,
StrideC, {},
{}, StrideE,
gemm_element_ops, a_element_op,
{}, b_element_op,
reduce_in_element_ops, cde_element_op,
reduce_out_element_ops); qs_element_op,
rs_element_op);
if(!gemmReduce.IsSupportedArgument(gemmReduce_argument)) if(!gemmReduce.IsSupportedArgument(gemmReduce_argument))
{ {
...@@ -307,13 +303,13 @@ int main() ...@@ -307,13 +303,13 @@ int main()
"not support this GEMM problem"); "not support this GEMM problem");
} }
reduceMean_device_buf.SetZero(); r0_Mean_device_buf.SetZero();
reduceMeanSquare_device_buf.SetZero(); r1_MeanSquare_device_buf.SetZero();
// Prepare LayerNorm // Prepare LayerNorm
std::array<const void*, 5> input = {c_device_buf.GetDeviceBuffer(), std::array<const void*, 5> input = {e_device_buf.GetDeviceBuffer(),
reduceMean_device_buf.GetDeviceBuffer(), r0_Mean_device_buf.GetDeviceBuffer(),
reduceMeanSquare_device_buf.GetDeviceBuffer(), r1_MeanSquare_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(), gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer()}; beta_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {layerNorm_device_buf.GetDeviceBuffer()}; std::array<void*, 1> output = {layerNorm_device_buf.GetDeviceBuffer()};
...@@ -323,12 +319,12 @@ int main() ...@@ -323,12 +319,12 @@ int main()
auto normalize_argument = normalize.MakeArgument(input, auto normalize_argument = normalize.MakeArgument(input,
output, output,
{M, N}, {M, N},
{StrideC, 1}, {StrideE, 1},
{1, 0}, {1, 0},
{1, 0}, {1, 0},
{0, 1}, {0, 1},
{0, 1}, {0, 1},
{StrideC, 1}, {StrideE, 1},
NormalizeFunctor{}); NormalizeFunctor{});
if(!normalize.IsSupportedArgument(normalize_argument)) if(!normalize.IsSupportedArgument(normalize_argument))
...@@ -345,18 +341,18 @@ int main() ...@@ -345,18 +341,18 @@ int main()
{ {
// verification // verification
Tensor<LayerNormOutDataType> host_layerNorm_m_n( Tensor<LayerNormOutDataType> host_layerNorm_m_n(
f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
host_gemm_layernorm<CDataType, ReduceDataType>(host_layerNorm_m_n, host_gemm_layernorm(host_layerNorm_m_n,
a_m_k, a_m_k,
b_k_n, b_k_n,
gamma_n, gamma_n,
beta_n, beta_n,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, cde_element_op,
M, M,
N); N);
layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data()); layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data());
pass &= ck::utils::check_err(layerNorm_m_n.mData, pass &= ck::utils::check_err(layerNorm_m_n.mData,
...@@ -378,8 +374,9 @@ int main() ...@@ -378,8 +374,9 @@ int main()
if(time_kernel) if(time_kernel)
DumpGemmLayerNormPerf<ADataType, DumpGemmLayerNormPerf<ADataType,
BDataType, BDataType,
CDataType, EDataType,
ReduceDataType, R0DataType,
R1DataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
LayerNormOutDataType>( LayerNormOutDataType>(
......
...@@ -4,83 +4,83 @@ ...@@ -4,83 +4,83 @@
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
#include <initializer_list> #include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_xdl_layernorm_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/library/host_tensor/device_memory.hpp" // This example demonstrate a single kernel that runs GEMM layer and laynorm in one fused kernel
#include "ck/library/host_tensor/host_tensor.hpp" //
#include "ck/library/host_tensor/host_tensor_generator.hpp" // The GEMM + Layernorm implementation is a specialized kernel which allows fusing both layers
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" // together given the condition GEMM extents N of MNK is spanned by a single workgroup. For example,
#include "ck/library/utility/check_err.hpp" // a kernel configured with NPerBlock = 128 allows to operate on all GEMM sizes if N <= 128
//
// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta)
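Expanded, with acc(m,n) = acc_element_op((A*B)(m,n) + bias(n)) + add(m,n) and the mean \mu(m) and variance \sigma^2(m) taken over n, the fused kernel computes (a sketch; the epsilon lives inside the layernorm step):

    D(m,n) = \frac{acc(m,n) - \mu(m)}{\sqrt{\sigma^2(m) + \epsilon}} \cdot \gamma(n) + \beta(n)

The single-workgroup condition on N is also why main() below defaults to N = 128: the instance is configured with NPerBlock = 128, so any N <= 128 fits in one workgroup.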
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
using F64 = double;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
using ADataType = F16; using ADataType = F16;
using BDataType = F16; using BDataType = F16;
using CDataType = F16; using CDataType = F16;
using GemmAccDataType = F32; using C0DataType = F16;
using ReduceAccDataType = F32; using AccDataType = F32;
using ReduceDataType = F64; using CShuffleDataType = F16;
using ReducePtrsGlobal = ck::Tuple<ReduceDataType*>;
using ALayout = ck::tensor_layout::gemm::RowMajor; using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor; using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough; struct Relu
using BElementOp = ck::tensor_operation::element_wise::PassThrough; {
using CElementOp = ck::tensor_operation::element_wise::PassThrough; template <typename OutT, typename InT>
using ReduceOps = ck::Tuple<ck::reduce::Max>; __host__ __device__ void operator()(OutT& y, const InT& x) const
using ReduceElementOps = ck::Tuple<ck::tensor_operation::element_wise::PassThrough>; {
using ReduceGlobalMemOps = y = x > 0 ? x : 0;
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicMax>; }
};
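The hand-rolled Relu above follows CK's elementwise-functor convention: the result is written through the first reference parameter. A minimal host-side usage sketch:

    Relu relu;
    float y = 0.f;
    relu(y, -2.5f); // y == 0
    relu(y, 3.0f);  // y == 3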
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
// Elementwise operation that operates on the output of matrix multiplication
// i.e., AccElementOp(A * B + bias)
using AccElementOp = Relu;
// Elementwise operation that operates on the output of layer normalization
using CElementOp = Relu;
static constexpr auto GemmSpecialization = static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off // clang-format off
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmLayerNorm_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| //######| ALayout| BLayout| CLayout| AData| BData| CData| C0Data| GemmAcc| CShuffle| ReduceAcc| A| B| Acc| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| //######| | | | Type| Type| Type| Type| DataType| DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //######| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, ReduceAccDataType, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceElementOps, ReduceElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; < Row, Col, Row, ADataType, BDataType, CDataType, C0DataType, AccDataType, CShuffleDataType, AccDataType, AElementOp, BElementOp, AccElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, S<64, 4>, 4>;
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType, using ReferenceInstance = ck::tensor_operation::host::ReferenceGemmLayernorm<ADataType,
BDataType, BDataType,
CDataType, CDataType,
GemmAccDataType, C0DataType,
AElementOp, AccDataType,
BElementOp, AElementOp,
CElementOp>; BElementOp,
AccElementOp,
template <typename ADataType, typename BDataType, typename CDataType, typename ReduceDataType> CElementOp>;
void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K)
{
std::size_t gemm_flop = std::size_t(2) * M * N * K;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M;
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time;
std::cout << "gemm + reduceMax Perf: " << gemm_reduce_time << " ms, " << tflops << " TFlops, "
<< gemm_gb_per_sec << " GB/s, " << std::endl;
}
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
...@@ -90,12 +90,12 @@ int main(int argc, char* argv[]) ...@@ -90,12 +90,12 @@ int main(int argc, char* argv[])
// GEMM shape // GEMM shape
ck::index_t M = 3840; ck::index_t M = 3840;
ck::index_t N = 4096; ck::index_t N = 128;
ck::index_t K = 4096; ck::index_t K = 4096;
ck::index_t StrideA = 4096; ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096; ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096; ck::index_t StrideC = 128;
if(argc == 1) if(argc == 1)
{ {
...@@ -125,7 +125,7 @@ int main(int argc, char* argv[]) ...@@ -125,7 +125,7 @@ int main(int argc, char* argv[])
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0); exit(0);
} }
...@@ -146,19 +146,21 @@ int main(int argc, char* argv[]) ...@@ -146,19 +146,21 @@ int main(int argc, char* argv[])
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<ReduceDataType> reduce_m_host_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<ReduceDataType> reduce_m_device_result( Tensor<AccDataType> acc_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-        HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
+    Tensor<C0DataType> c0_n_bias(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
+    Tensor<C0DataType> c0_m_n_add(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+    Tensor<C0DataType> c0_n_gamma(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
+    Tensor<C0DataType> c0_n_beta(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));

     std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
     std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
     std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
-    std::cout << "reduce_m: " << reduce_m_host_result.mDesc << std::endl;
+    std::cout << "c0_n_bias: " << c0_n_bias.mDesc << std::endl;
+    std::cout << "c0_m_n_add: " << c0_m_n_add.mDesc << std::endl;
+    std::cout << "c0_n_gamma: " << c0_n_gamma.mDesc << std::endl;
+    std::cout << "c0_n_beta: " << c0_n_beta.mDesc << std::endl;

     switch(init_method)
     {
@@ -167,49 +169,62 @@ int main(int argc, char* argv[])
         a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
         b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
         break;
-    default:
+    case 2:
         a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
         b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         break;
+    default:
+        a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
+        b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
     }

-    DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
-    DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
-    DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
-    DeviceMem reduce_device_buf(sizeof(ReduceDataType) *
-                                reduce_m_device_result.mDesc.GetElementSpace());
+    c0_n_bias.GenerateTensorValue(GeneratorTensor_2<C0DataType>{-5, 5});
+    c0_m_n_add.GenerateTensorValue(GeneratorTensor_2<C0DataType>{-5, 5});
+    c0_n_gamma.GenerateTensorValue(GeneratorTensor_2<C0DataType>{0, 2});
+    c0_n_beta.GenerateTensorValue(GeneratorTensor_2<C0DataType>{0, 5});
+    c_m_n_host_result.GenerateTensorValue(GeneratorTensor_1<CDataType>{0});
+    acc_m_n_host_result.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0});
+
+    DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
+    DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
+    DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
+    DeviceMem c0_bias_buf(sizeof(C0DataType) * c0_n_bias.mDesc.GetElementSpaceSize());
+    DeviceMem c0_add_buf(sizeof(C0DataType) * c0_m_n_add.mDesc.GetElementSpaceSize());
+    DeviceMem c0_gamma_buf(sizeof(C0DataType) * c0_n_gamma.mDesc.GetElementSpaceSize());
+    DeviceMem c0_beta_buf(sizeof(C0DataType) * c0_n_beta.mDesc.GetElementSpaceSize());

     a_device_buf.ToDevice(a_m_k.mData.data());
     b_device_buf.ToDevice(b_k_n.mData.data());
+    c0_bias_buf.ToDevice(c0_n_bias.mData.data());
+    c0_add_buf.ToDevice(c0_m_n_add.mData.data());
+    c0_gamma_buf.ToDevice(c0_n_gamma.mData.data());
+    c0_beta_buf.ToDevice(c0_n_beta.mData.data());

     auto a_element_op = AElementOp{};
     auto b_element_op = BElementOp{};
-    auto c_element_op = CElementOp{};
-    auto reduce_element_op = ReduceElementOps{}[ck::Number<0>{}];
-
-    std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
-    std::array<void*, 1> reduce_element_ops = {&reduce_element_op};
-    std::array<void*, 1> p_reduces = {reduce_device_buf.GetDeviceBuffer()};
+    auto acc_element_op = AccElementOp{};
+    auto c_element_op = CElementOp{};

     // do GEMM
-    auto gemm    = DeviceGemmReduceInstance{};
+    auto gemm    = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();
-    auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
-                                      b_device_buf.GetDeviceBuffer(),
-                                      nullptr,
-                                      {},
-                                      c_device_buf.GetDeviceBuffer(),
-                                      p_reduces,
+    auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
+                                      static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
+                                      static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
+                                      static_cast<C0DataType*>(c0_add_buf.GetDeviceBuffer()),
+                                      static_cast<C0DataType*>(c0_bias_buf.GetDeviceBuffer()),
+                                      static_cast<C0DataType*>(c0_gamma_buf.GetDeviceBuffer()),
+                                      static_cast<C0DataType*>(c0_beta_buf.GetDeviceBuffer()),
                                       M,
                                       N,
                                       K,
                                       StrideA,
                                       StrideB,
                                       StrideC,
-                                      {},
-                                      gemm_element_ops,
-                                      {},
-                                      reduce_element_ops,
-                                      reduce_element_ops);
+                                      a_element_op,
+                                      b_element_op,
+                                      acc_element_op,
+                                      c_element_op);

     if(!gemm.IsSupportedArgument(argument))
     {
@@ -218,59 +233,57 @@ int main(int argc, char* argv[])
                                 "not support this GEMM problem");
     }

-    // [CAUSION]: launch_and_time_kernel will not initialize D.
-    // If we evaluate kernel multiple time but without initialize D. Verification will fail
-    reduce_device_buf.SetValue(ck::NumericLimits<ReduceDataType>::Lowest());
-
-    invoker.Run(argument, StreamConfig{nullptr, false});
-
-    bool pass = true;
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
+
+    // extra 6MN flops due to: bias + add + gamma + beta + norm_sub + norm_div,
+    // excluding reduction steps
+    std::size_t flop = std::size_t(2) * M * N * K + std::size_t(6) * M * N;
+    // extra MN and 3N due to c0_add (MxN), bias (1xN), gamma (1xN), beta (1xN)
+    std::size_t bytes = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
+                        sizeof(CDataType) * 2 * M * N + sizeof(C0DataType) * 3 * N;
+
+    float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
+    float gb_per_sec = bytes / 1.E6 / ave_time;
+
+    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
+              << gemm.GetTypeString() << std::endl;
+
+    bool pass = true;

     if(do_verification)
     {
         c_device_buf.FromDevice(c_m_n_device_result.mData.data());
-        reduce_device_buf.FromDevice(reduce_m_device_result.mData.data());

-        auto ref_gemm    = ReferenceGemmInstance{};
+        auto ref_gemm    = ReferenceInstance{};
         auto ref_invoker = ref_gemm.MakeInvoker();

-        auto ref_argument = ref_gemm.MakeArgument(
-            a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
+        auto ref_argument = ref_gemm.MakeArgument(a_m_k,
+                                                  b_k_n,
+                                                  c_m_n_host_result,
+                                                  c0_n_bias,
+                                                  c0_m_n_add,
+                                                  c0_n_gamma,
+                                                  c0_n_beta,
+                                                  a_element_op,
+                                                  b_element_op,
+                                                  acc_element_op,
+                                                  c_element_op);

         ref_invoker.Run(ref_argument);

-        auto reduce_op = ReduceOps{}[ck::Number<0>{}];
-
-        for(int m = 0; m < M; ++m)
-        {
-            ReduceAccDataType reduce_acc = reduce_op.GetIdentityValue<ReduceAccDataType>();
-
-            for(int n = 0; n < N; ++n)
-            {
-                ReduceAccDataType curr_val =
-                    ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
-                reduce_op(reduce_acc, curr_val);
-            };
-
-            reduce_m_host_result(m) = reduce_acc;
-        }
-
-        pass = ck::utils::check_err(c_m_n_device_result.mData,
-                                    c_m_n_host_result.mData,
-                                    "Error: Incorrect results c") &&
-               ck::utils::check_err(reduce_m_device_result.mData,
-                                    reduce_m_host_result.mData,
-                                    "Error: Incorrect results d",
-                                    1e-3,
-                                    1e-3);
-    }
-
-    if(time_kernel)
-    {
-        float gemm_reduceMax_ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
-
-        DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, ReduceDataType>(
-            gemm_reduceMax_ave_time, M, N, K);
+        if constexpr(std::is_same<CShuffleDataType, F32>::value)
+        {
+            pass &= ck::utils::check_err(
+                c_m_n_device_result.mData, c_m_n_host_result.mData, "Error: Incorrect results c");
+        }
+        else if constexpr(std::is_same<CShuffleDataType, F16>::value)
+        {
+            pass &= ck::utils::check_err(c_m_n_device_result.mData,
+                                         c_m_n_host_result.mData,
+                                         "Error: Incorrect results c",
+                                         1e-2,
+                                         1e-2);
+        }
     }

     return pass ? 0 : 1;
 }
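The 6 MN term in the flop count above comes from the fused epilogue that the new example verifies: a per-column bias, a residual add, and a per-row layernorm scaled by gamma/beta. The following is a minimal host-side sketch of that epilogue; the function name, flat row-major buffer layout, and the eps default are illustrative assumptions, not the semantics of the ReferenceInstance used above:

    #include <cmath>
    #include <vector>

    // Hypothetical helper: applies the assumed bias + residual-add + per-row
    // layernorm (gamma/beta) epilogue to a raw GEMM accumulator stored
    // row-major as acc[m * N + n].
    void layernorm_epilogue(std::vector<float>& acc, const std::vector<float>& bias,
                            const std::vector<float>& add, const std::vector<float>& gamma,
                            const std::vector<float>& beta, int M, int N, float eps = 1e-5f)
    {
        for(int m = 0; m < M; ++m)
        {
            float mean = 0.f;
            for(int n = 0; n < N; ++n)
            {
                acc[m * N + n] += bias[n] + add[m * N + n]; // bias + add
                mean += acc[m * N + n];
            }
            mean /= N;

            float var = 0.f;
            for(int n = 0; n < N; ++n)
            {
                const float d = acc[m * N + n] - mean;
                var += d * d;
            }
            var /= N;

            const float inv_std = 1.f / std::sqrt(var + eps);
            for(int n = 0; n < N; ++n) // gamma, beta, norm_sub, norm_div
                acc[m * N + n] = gamma[n] * (acc[m * N + n] - mean) * inv_std + beta[n];
        }
    }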
+add_custom_target(example_cgemm_xdl)
+add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp)
 add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp)
+add_example_executable(example_cgemm_xdl_fp32 cgemm_xdl_fp32.cpp)
+add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp)
+add_dependencies(example_cgemm_xdl example_cgemm_xdl_bf16)
+add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp16)
+add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp32)
+add_dependencies(example_cgemm_xdl example_cgemm_xdl_int8)
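With the umbrella target added above, all four variants can be built in one step; a usage sketch, assuming an already-configured build tree:

    cmake --build . --target example_cgemm_xdl

Individual variants remain addressable as well, e.g. --target example_cgemm_xdl_fp16.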
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "cgemm_xdl_common.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
using ADataType = BF16;
using BDataType = BF16;
using CDataType = BF16;
using AccDataType = F32;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using ReferenceCGemmInstance = ck::tensor_operation::host::
ReferenceCGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
// clang-format off
using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_CShuffle
<ALayout, // typename ALayout
BLayout, // typename BLayout
CLayout, // typename CLayout
ADataType, // typename ADataType
BDataType, // typename BDataType
CDataType, // typename CDataType
AccDataType, // typename GemmAccDataType
CDataType, // typename CShuffleDataType
PassThrough, // typename AElementwiseOperation
PassThrough, // typename BElementwiseOperation
PassThrough, // typename CElementwiseOperation
GemmDefault, // GemmSpecialization GemmSpec
1, // index_t NumGemmKPrefetchStage
256, // index_t BlockSize
256, // index_t MPerBlock
128, // index_t NPerBlock
32, // index_t KPerBlock
8, // index_t AK1
8, // index_t BK1
32, // index_t MPerXDL
32, // index_t NPerXDL
4, // index_t MXdlPerWave
2, // index_t NXdlPerWave
S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder
2, // index_t ABlockTransferSrcVectorDim
8, // index_t ABlockTransferSrcScalarPerVector
8, // index_t ABlockTransferDstScalarPerVector_AK1
1, // index_t ABlockLdsExtraM
S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder
2, // index_t BBlockTransferSrcVectorDim
8, // index_t BBlockTransferSrcScalarPerVector
8, // index_t BBlockTransferDstScalarPerVector_BK1
1, // index_t BBlockLdsExtraN
1, // index_t CShuffleMXdlPerWavePerShuffle
1, // index_t CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock
// clang-format on
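// Assuming the usual CK tiling, each workgroup of BlockSize = 256 threads owns a
// 256 x 128 (MPerBlock x NPerBlock) tile of C, so the default 3840 x 4096 problem
// below launches (3840 / 256) * (4096 / 128) = 15 * 32 = 480 workgroups, each
// stepping through K = 416 in KPerBlock = 32 chunks (13 main-loop iterations).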
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// CGEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 416;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideC = std::stoi(argv[9]);
}
else
{
std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: run kernel # of times (>1)\n"
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"
<< std::endl;
exit(0);
}
return run_cgemm_xdl<ADataType,
BDataType,
CDataType,
ALayout,
BLayout,
CLayout,
PassThrough,
PassThrough,
PassThrough,
DeviceCGemmInstance,
ReferenceCGemmInstance>(
M, N, K, StrideA, StrideB, StrideC, do_verification, init_method, time_kernel);
}
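A usage sketch, assuming the executable name produced by the CMake snippet above: ./example_cgemm_xdl_bf16 1 1 1 runs the default 3840x4096x416 problem with verification, integer initialization, and kernel timing enabled, while ./example_cgemm_xdl_bf16 1 1 1 3840 4096 416 4096 4096 4096 spells out the shape and strides explicitly.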
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/stream_config.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using BF16 = ck::bhalf_t;
using INT8 = std::int8_t;
using INT32 = std::int32_t;
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DeviceCGemmInstance,
typename ReferenceCGemmInstance>
int run_cgemm_xdl(ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
bool do_verification,
int init_method,
bool time_kernel)
{
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
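    // For RowMajor the strides are {stride, 1}: element (row, col) sits at offset
    // row * stride + col, so a RowMajor M x K matrix with StrideA = K is densely
    // packed. For ColumnMajor the strides are {1, stride}, i.e. row + col * stride.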
Tensor<ADataType> a_m_k_real(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<ADataType> a_m_k_imag(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n_real(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<BDataType> b_k_n_imag(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_real_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_imag_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k_real: " << a_m_k_real.mDesc << std::endl;
std::cout << "a_m_k_imag: " << a_m_k_imag.mDesc << std::endl;
std::cout << "b_k_n_real: " << b_k_n_real.mDesc << std::endl;
std::cout << "b_k_n_imag: " << b_k_n_imag.mDesc << std::endl;
std::cout << "c_m_n_real: " << c_m_n_real_device_result.mDesc << std::endl;
std::cout << "c_m_n_imag: " << c_m_n_imag_device_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_m_k_real.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
a_m_k_imag.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b_k_n_real.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
b_k_n_imag.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
break;
default:
a_m_k_real.GenerateTensorValue(GeneratorTensor_3<ADataType>{-0.5, 0.5});
a_m_k_imag.GenerateTensorValue(GeneratorTensor_3<ADataType>{-0.5, 0.5});
b_k_n_real.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
b_k_n_imag.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
}
auto cgemm = DeviceCGemmInstance{};
DeviceMem a_m_k_real_device_buf(sizeof(ADataType) * a_m_k_real.mDesc.GetElementSpaceSize());
DeviceMem a_m_k_imag_device_buf(sizeof(ADataType) * a_m_k_imag.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_real_device_buf(sizeof(BDataType) * b_k_n_real.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_imag_device_buf(sizeof(BDataType) * b_k_n_imag.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_real_device_buf(sizeof(CDataType) *
c_m_n_real_device_result.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) *
c_m_n_imag_device_result.mDesc.GetElementSpaceSize());
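    // The 4-GEMM decomposition is assumed to need device scratch for intermediate
    // real-valued products before they are combined into the real and imaginary
    // outputs; GetWorkspaceSize reports that footprint for the given shape.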
DeviceMem workspace_device_buf(cgemm.GetWorkspaceSize(M, N, K, StrideA, StrideB, StrideC));
a_m_k_real_device_buf.ToDevice(a_m_k_real.mData.data());
a_m_k_imag_device_buf.ToDevice(a_m_k_imag.mData.data());
b_k_n_real_device_buf.ToDevice(b_k_n_real.mData.data());
b_k_n_imag_device_buf.ToDevice(b_k_n_imag.mData.data());
auto a_element_op = AElementwiseOperation{};
auto b_element_op = BElementwiseOperation{};
auto c_element_op = CElementwiseOperation{};
// do GEMM
auto invoker = cgemm.MakeInvoker();
auto argument =
cgemm.MakeArgument(static_cast<ADataType*>(a_m_k_real_device_buf.GetDeviceBuffer()),
static_cast<ADataType*>(a_m_k_imag_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_real_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_imag_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_real_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_imag_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(workspace_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
if(!cgemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_cgemm with the specified compilation parameters does "
"not support this CGEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
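    // A complex multiply-accumulate expands into four real ones:
    //   c_real += a_real * b_real - a_imag * b_imag
    //   c_imag += a_real * b_imag + a_imag * b_real
    // so the flop count below is 4 * (2 * M * N * K) = 8 * M * N * K, and every
    // operand is moved once for its real and once for its imaginary part, hence
    // the factor of 2 on the byte count.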
std::size_t flop = std::size_t(8) * M * N * K;
std::size_t num_btype =
std::size_t(2) *
(sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< cgemm.GetTypeString() << std::endl;
c_m_n_real_device_buf.FromDevice(c_m_n_real_device_result.mData.data());
c_m_n_imag_device_buf.FromDevice(c_m_n_imag_device_result.mData.data());
if(do_verification)
{
Tensor<CDataType> c_m_n_real_host_result(
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_imag_host_result(
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
auto ref_cgemm = ReferenceCGemmInstance{};
auto ref_invoker = ref_cgemm.MakeInvoker();
auto ref_argument = ref_cgemm.MakeArgument(a_m_k_real,
a_m_k_imag,
b_k_n_real,
b_k_n_imag,
c_m_n_real_host_result,
c_m_n_imag_host_result,
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
bool result = true;
result = ck::utils::check_err(c_m_n_real_device_result.mData,
c_m_n_real_host_result.mData,
"Verification error: incorrect results in real part!",
1e-2f,
1e-1f);
result = result &&
ck::utils::check_err(c_m_n_imag_device_result.mData,
c_m_n_imag_host_result.mData,
"Verification error: incorrect results in imaginary part!",
1e-2f,
1e-1f);
return result ? 0 : 1;
}
return 0;
}
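For intuition about what the ReferenceCGemm verification computes, here is a scalar sketch of a split-complex GEMM. It is an illustration with made-up names and dense row-major buffers (the examples above actually use a column-major B), not the library's implementation:

    #include <cstddef>
    #include <vector>

    // Naive complex GEMM on split real/imaginary row-major buffers:
    // C = A * B with a[m][k], b[k][n], c[m][n] held as separate re/im planes.
    void cgemm_naive(const std::vector<float>& a_re, const std::vector<float>& a_im,
                     const std::vector<float>& b_re, const std::vector<float>& b_im,
                     std::vector<float>& c_re, std::vector<float>& c_im,
                     std::size_t M, std::size_t N, std::size_t K)
    {
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
            {
                float acc_re = 0.f, acc_im = 0.f;
                for(std::size_t k = 0; k < K; ++k)
                {
                    // (ar + i*ai) * (br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br)
                    acc_re += a_re[m * K + k] * b_re[k * N + n] - a_im[m * K + k] * b_im[k * N + n];
                    acc_im += a_re[m * K + k] * b_im[k * N + n] + a_im[m * K + k] * b_re[k * N + n];
                }
                c_re[m * N + n] = acc_re;
                c_im[m * N + n] = acc_im;
            }
    }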
@@ -2,43 +2,30 @@
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

 #include <iostream>
-#include <numeric>
-#include <initializer_list>
-#include <cstdlib>

-#include "ck/ck.hpp"
+#include "cgemm_xdl_common.hpp"

+#include "ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"

-#include "ck/library/utility/check_err.hpp"
-#include "ck/library/host_tensor/device_memory.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
-#include "ck/library/host_tensor/host_tensor_generator.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp"
-
-template <ck::index_t... Is>
-using S = ck::Sequence<Is...>;
-
-using F16 = ck::half_t;
-using F32 = float;
-
-using Row = ck::tensor_layout::gemm::RowMajor;
-using Col = ck::tensor_layout::gemm::ColumnMajor;
-
-using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-
-using ADataType   = F16;
-using BDataType   = F16;
-using CDataType   = F16;
-using AccDataType = F32;
+using ADataType        = F16;
+using BDataType        = F16;
+using CDataType        = F16;
+using AccDataType      = F32;
+using CShuffleDataType = F32;

 using ALayout = ck::tensor_layout::gemm::RowMajor;
 using BLayout = ck::tensor_layout::gemm::ColumnMajor;
 using CLayout = ck::tensor_layout::gemm::RowMajor;

+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

+using ReferenceCGemmInstance = ck::tensor_operation::host::
+    ReferenceCGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
+
 // clang-format off
 using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_CShuffle
 <ALayout, // typename ALayout
@@ -48,7 +35,7 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_CShuffle
 BDataType, // typename BDataType
 CDataType, // typename CDataType
 AccDataType, // typename GemmAccDataType
-CDataType, // typename CShuffleDataType
+CShuffleDataType, // typename CShuffleDataType
 PassThrough, // typename AElementwiseOperation
 PassThrough, // typename BElementwiseOperation
 PassThrough, // typename CElementwiseOperation
@@ -84,9 +71,6 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_CShuffle
 8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock
 // clang-format on

-using ReferenceCGemmInstance = ck::tensor_operation::host::
-    ReferenceCGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
-
 int main(int argc, char* argv[])
 {
@@ -124,155 +108,24 @@ int main(int argc, char* argv[])
     }
     else
     {
-        printf("arg1: verification (0=no, 1=yes)\n");
-        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-        printf("arg3: run kernel # of times (>1)\n");
-        printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
+        std::cout << "arg1: verification (0=no, 1=yes)\n"
+                  << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
+                  << "arg3: run kernel # of times (>1)\n"
+                  << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"
+                  << std::endl;
         exit(0);
     }

-    auto f_host_tensor_descriptor =
-        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
-            if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
-            {
-                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
-                                            std::vector<std::size_t>({stride, 1}));
-            }
-            else
-            {
-                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
-                                            std::vector<std::size_t>({1, stride}));
-            }
-        };
-
-    Tensor<ADataType> a_m_k_real(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
-    Tensor<ADataType> a_m_k_imag(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
-    Tensor<BDataType> b_k_n_real(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
-    Tensor<BDataType> b_k_n_imag(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
-    Tensor<CDataType> c_m_n_real_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-    Tensor<CDataType> c_m_n_imag_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-
-    std::cout << "a_m_k_real: " << a_m_k_real.mDesc << std::endl;
-    std::cout << "a_m_k_imag: " << a_m_k_imag.mDesc << std::endl;
-    std::cout << "b_k_n_real: " << b_k_n_real.mDesc << std::endl;
-    std::cout << "b_k_n_imag: " << b_k_n_imag.mDesc << std::endl;
-    std::cout << "c_m_n_real: " << c_m_n_real_device_result.mDesc << std::endl;
-    std::cout << "c_m_n_imag: " << c_m_n_imag_device_result.mDesc << std::endl;
-
-    switch(init_method)
-    {
-    case 0: break;
-    case 1:
-        a_m_k_real.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
-        a_m_k_imag.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
-        b_k_n_real.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
-        b_k_n_imag.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
-        break;
-    default:
-        a_m_k_real.GenerateTensorValue(GeneratorTensor_3<ADataType>{-0.5, 0.5});
-        a_m_k_imag.GenerateTensorValue(GeneratorTensor_3<ADataType>{-0.5, 0.5});
-        b_k_n_real.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
-        b_k_n_imag.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
-    }
-
-    auto cgemm = DeviceCGemmInstance{};
-
-    DeviceMem a_m_k_real_device_buf(sizeof(ADataType) * a_m_k_real.mDesc.GetElementSpace());
-    DeviceMem a_m_k_imag_device_buf(sizeof(ADataType) * a_m_k_imag.mDesc.GetElementSpace());
-    DeviceMem b_k_n_real_device_buf(sizeof(BDataType) * b_k_n_real.mDesc.GetElementSpace());
-    DeviceMem b_k_n_imag_device_buf(sizeof(BDataType) * b_k_n_imag.mDesc.GetElementSpace());
-    DeviceMem c_m_n_real_device_buf(sizeof(CDataType) *
-                                    c_m_n_real_device_result.mDesc.GetElementSpace());
-    DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) *
-                                    c_m_n_imag_device_result.mDesc.GetElementSpace());
-    DeviceMem workspace_device_buf(cgemm.GetWorkspaceSize(M, N, K, StrideA, StrideB, StrideC));
-
-    a_m_k_real_device_buf.ToDevice(a_m_k_real.mData.data());
-    a_m_k_imag_device_buf.ToDevice(a_m_k_imag.mData.data());
-    b_k_n_real_device_buf.ToDevice(b_k_n_real.mData.data());
-    b_k_n_imag_device_buf.ToDevice(b_k_n_imag.mData.data());
-
-    auto a_element_op = PassThrough{};
-    auto b_element_op = PassThrough{};
-    auto c_element_op = PassThrough{};
-
-    // do GEMM
-    auto invoker = cgemm.MakeInvoker();
-    auto argument =
-        cgemm.MakeArgument(static_cast<ADataType*>(a_m_k_real_device_buf.GetDeviceBuffer()),
-                           static_cast<ADataType*>(a_m_k_imag_device_buf.GetDeviceBuffer()),
-                           static_cast<BDataType*>(b_k_n_real_device_buf.GetDeviceBuffer()),
-                           static_cast<BDataType*>(b_k_n_imag_device_buf.GetDeviceBuffer()),
-                           static_cast<CDataType*>(c_m_n_real_device_buf.GetDeviceBuffer()),
-                           static_cast<CDataType*>(c_m_n_imag_device_buf.GetDeviceBuffer()),
-                           static_cast<CDataType*>(workspace_device_buf.GetDeviceBuffer()),
-                           M,
-                           N,
-                           K,
-                           StrideA,
-                           StrideB,
-                           StrideC,
-                           a_element_op,
-                           b_element_op,
-                           c_element_op);
-
-    if(!cgemm.IsSupportedArgument(argument))
-    {
-        throw std::runtime_error(
-            "wrong! device_cgemm with the specified compilation parameters does "
-            "not support this CGEMM problem");
-    }
-
-    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
-
-    std::size_t flop = std::size_t(8) * M * N * K;
-    std::size_t num_btype =
-        std::size_t(2) *
-        (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N);
-
-    float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
-    float gb_per_sec = num_btype / 1.E6 / ave_time;
-
-    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
-              << cgemm.GetTypeString() << std::endl;
-
-    c_m_n_real_device_buf.FromDevice(c_m_n_real_device_result.mData.data());
-    c_m_n_imag_device_buf.FromDevice(c_m_n_imag_device_result.mData.data());
-
-    if(do_verification)
-    {
-        Tensor<CDataType> c_m_n_real_host_result(
-            f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-        Tensor<CDataType> c_m_n_imag_host_result(
-            f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-
-        auto ref_cgemm    = ReferenceCGemmInstance{};
-        auto ref_invoker  = ref_cgemm.MakeInvoker();
-        auto ref_argument = ref_cgemm.MakeArgument(a_m_k_real,
-                                                   a_m_k_imag,
-                                                   b_k_n_real,
-                                                   b_k_n_imag,
-                                                   c_m_n_real_host_result,
-                                                   c_m_n_imag_host_result,
-                                                   a_element_op,
-                                                   b_element_op,
-                                                   c_element_op);
-        ref_invoker.Run(ref_argument);
-
-        ck::utils::check_err(c_m_n_real_device_result.mData,
-                             c_m_n_real_host_result.mData,
-                             "Verification error: incorrect results in real part!",
-                             1e-2f,
-                             1e-1f);
-        ck::utils::check_err(c_m_n_imag_device_result.mData,
-                             c_m_n_imag_host_result.mData,
-                             "Verification error: incorrect results in imaginary part!",
-                             1e-2f,
-                             1e-1f);
-    }
-
-    return 0;
+    return run_cgemm_xdl<ADataType,
+                         BDataType,
+                         CDataType,
+                         ALayout,
+                         BLayout,
+                         CLayout,
+                         PassThrough,
+                         PassThrough,
+                         PassThrough,
+                         DeviceCGemmInstance,
+                         ReferenceCGemmInstance>(
+        M, N, K, StrideA, StrideB, StrideC, do_verification, init_method, time_kernel);
 }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "cgemm_xdl_common.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
using ADataType = F32;
using BDataType = F32;
using CDataType = F32;
using AccDataType = F32;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using ReferenceCGemmInstance = ck::tensor_operation::host::
ReferenceCGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
// clang-format off
using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_CShuffle
<ALayout, // typename ALayout
BLayout, // typename BLayout
CLayout, // typename CLayout
ADataType, // typename ADataType
BDataType, // typename BDataType
CDataType, // typename CDataType
AccDataType, // typename GemmAccDataType
CDataType, // typename CShuffleDataType
PassThrough, // typename AElementwiseOperation
PassThrough, // typename BElementwiseOperation
PassThrough, // typename CElementwiseOperation
GemmDefault, // GemmSpecialization GemmSpec
1, // index_t NumGemmKPrefetchStage
256, // index_t BlockSize
256, // index_t MPerBlock
128, // index_t NPerBlock
16, // index_t KPerBlock
4, // index_t AK1
4, // index_t BK1
32, // index_t MPerXDL
32, // index_t NPerXDL
4, // index_t MXdlPerWave
2, // index_t NXdlPerWave
S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder
2, // index_t ABlockTransferSrcVectorDim
4, // index_t ABlockTransferSrcScalarPerVector
4, // index_t ABlockTransferDstScalarPerVector_AK1
1, // index_t ABlockLdsExtraM
S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder
2, // index_t BBlockTransferSrcVectorDim
4, // index_t BBlockTransferSrcScalarPerVector
4, // index_t BBlockTransferDstScalarPerVector_BK1
1, // index_t BBlockLdsExtraN
1, // index_t CShuffleMXdlPerWavePerShuffle
1, // index_t CShuffleNXdlPerWavePerShuffle
S<1, 16, 1, 16>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
4>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// CGEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideC = std::stoi(argv[9]);
}
else
{
std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: run kernel # of times (>1)\n"
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"
<< std::endl;
exit(0);
}
return run_cgemm_xdl<ADataType,
BDataType,
CDataType,
ALayout,
BLayout,
CLayout,
PassThrough,
PassThrough,
PassThrough,
DeviceCGemmInstance,
ReferenceCGemmInstance>(
M, N, K, StrideA, StrideB, StrideC, do_verification, init_method, time_kernel);
}