Commit ae1b4ee6 authored by Chao Liu
Browse files

add bias

parent cf95b944
...@@ -114,13 +114,23 @@ int run_conv_fwd(bool do_verification, ...@@ -114,13 +114,23 @@ int run_conv_fwd(bool do_verification,
const auto wei_desc = ck::utils::conv::get_weight_host_tensor_descriptor<WeiLayout>(conv_param); const auto wei_desc = ck::utils::conv::get_weight_host_tensor_descriptor<WeiLayout>(conv_param);
const auto out_desc = ck::utils::conv::get_output_host_tensor_descriptor<OutLayout>(conv_param); const auto out_desc = ck::utils::conv::get_output_host_tensor_descriptor<OutLayout>(conv_param);
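// HACK: the bias descriptor is hardcoded for a 2D NHWK output (lengths N, Ho, Wo, K); the
// strides {0, 0, 0, 1} make every (n, ho, wo) position read the same K-length bias vector.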
const auto bias_desc = HostTensorDescriptor(
std::vector<std::size_t>{static_cast<std::size_t>(conv_param.N_),
static_cast<std::size_t>(conv_param.output_spatial_lengths_[0]),
static_cast<std::size_t>(conv_param.output_spatial_lengths_[1]),
static_cast<std::size_t>(conv_param.K_)},
std::vector<std::size_t>{0, 0, 0, 1});
Tensor<InDataType> in(in_desc); Tensor<InDataType> in(in_desc);
Tensor<WeiDataType> wei(wei_desc); Tensor<WeiDataType> wei(wei_desc);
Tensor<OutDataType> bias(bias_desc);
Tensor<OutDataType> out_host(out_desc); Tensor<OutDataType> out_host(out_desc);
Tensor<OutDataType> out_device(out_desc); Tensor<OutDataType> out_device(out_desc);
std::cout << "in: " << in.mDesc << std::endl; std::cout << "in: " << in.mDesc << std::endl;
std::cout << "wei: " << wei.mDesc << std::endl; std::cout << "wei: " << wei.mDesc << std::endl;
std::cout << "bias: " << bias.mDesc << std::endl;
std::cout << "out: " << out_host.mDesc << std::endl; std::cout << "out: " << out_host.mDesc << std::endl;
switch(init_method) switch(init_method)
...@@ -129,23 +139,28 @@ int run_conv_fwd(bool do_verification, ...@@ -129,23 +139,28 @@ int run_conv_fwd(bool do_verification,
case 1: case 1:
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}); in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5}); wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
bias.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
break; break;
default: default:
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0}); in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5}); wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
bias.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5});
} }
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpace()); DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpace()); DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpace());
DeviceMem bias_device_buf(sizeof(OutDataType) * bias.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpace()); DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpace());
in_device_buf.ToDevice(in.mData.data()); in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei.mData.data()); wei_device_buf.ToDevice(wei.mData.data());
bias_device_buf.ToDevice(bias.mData.data());
// tensor descriptor in NCHW/KXYC/NKHW dimensional order // tensor descriptor in NCHW/KXYC/NKHW dimensional order
HostTensorDescriptor in_n_c_wis_desc = in_desc; HostTensorDescriptor in_n_c_wis_desc = in_desc;
HostTensorDescriptor wei_k_c_xs_desc = wei_desc; HostTensorDescriptor wei_k_c_xs_desc = wei_desc;
HostTensorDescriptor out_n_k_wos_desc = out_desc; HostTensorDescriptor bias_n_k_wos_desc = bias_desc;
HostTensorDescriptor out_n_k_wos_desc = out_desc;
// input // input
if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NWC>) if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NWC>)
...@@ -186,22 +201,33 @@ int run_conv_fwd(bool do_verification, ...@@ -186,22 +201,33 @@ int run_conv_fwd(bool do_verification,
{ {
out_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old( out_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
out_desc, std::vector<std::size_t>{0, 2, 1}); out_desc, std::vector<std::size_t>{0, 2, 1});
bias_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
bias_desc, std::vector<std::size_t>{0, 2, 1});
} }
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NHWK>) else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NHWK>)
{ {
out_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old( out_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
out_desc, std::vector<std::size_t>{0, 3, 1, 2}); out_desc, std::vector<std::size_t>{0, 3, 1, 2});
bias_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
bias_desc, std::vector<std::size_t>{0, 3, 1, 2});
} }
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWK>) else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWK>)
{ {
out_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old( out_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
out_desc, std::vector<std::size_t>{0, 4, 1, 2, 3}); out_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
bias_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
bias_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
} }
std::array<ck::index_t, NDimSpatial + 2> a_n_c_wis_lengths{}; std::array<ck::index_t, NDimSpatial + 2> a_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 2> a_n_c_wis_strides{}; std::array<ck::index_t, NDimSpatial + 2> a_n_c_wis_strides{};
std::array<ck::index_t, NDimSpatial + 2> b_k_c_xs_lengths{}; std::array<ck::index_t, NDimSpatial + 2> b_k_c_xs_lengths{};
std::array<ck::index_t, NDimSpatial + 2> b_k_c_xs_strides{}; std::array<ck::index_t, NDimSpatial + 2> b_k_c_xs_strides{};
std::array<ck::index_t, NDimSpatial + 2> d_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 2> d_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial + 2> e_n_k_wos_lengths{}; std::array<ck::index_t, NDimSpatial + 2> e_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 2> e_n_k_wos_strides{}; std::array<ck::index_t, NDimSpatial + 2> e_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{}; std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
...@@ -215,6 +241,8 @@ int run_conv_fwd(bool do_verification, ...@@ -215,6 +241,8 @@ int run_conv_fwd(bool do_verification,
copy(in_n_c_wis_desc.GetStrides(), a_n_c_wis_strides); copy(in_n_c_wis_desc.GetStrides(), a_n_c_wis_strides);
copy(wei_k_c_xs_desc.GetLengths(), b_k_c_xs_lengths); copy(wei_k_c_xs_desc.GetLengths(), b_k_c_xs_lengths);
copy(wei_k_c_xs_desc.GetStrides(), b_k_c_xs_strides); copy(wei_k_c_xs_desc.GetStrides(), b_k_c_xs_strides);
copy(bias_n_k_wos_desc.GetLengths(), d_n_k_wos_lengths);
copy(bias_n_k_wos_desc.GetStrides(), d_n_k_wos_strides);
copy(out_n_k_wos_desc.GetLengths(), e_n_k_wos_lengths); copy(out_n_k_wos_desc.GetLengths(), e_n_k_wos_lengths);
copy(out_n_k_wos_desc.GetStrides(), e_n_k_wos_strides); copy(out_n_k_wos_desc.GetStrides(), e_n_k_wos_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides); copy(conv_param.conv_filter_strides_, conv_filter_strides);
...@@ -225,25 +253,26 @@ int run_conv_fwd(bool do_verification, ...@@ -225,25 +253,26 @@ int run_conv_fwd(bool do_verification,
// do GEMM // do GEMM
auto conv = DeviceConvNDFwdInstance{}; auto conv = DeviceConvNDFwdInstance{};
auto invoker = conv.MakeInvoker(); auto invoker = conv.MakeInvoker();
auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), auto argument = conv.MakeArgument(
wei_device_buf.GetDeviceBuffer(), in_device_buf.GetDeviceBuffer(),
std::array<const void*, 0>{}, wei_device_buf.GetDeviceBuffer(),
out_device_buf.GetDeviceBuffer(), std::array<const void*, 1>{bias_device_buf.GetDeviceBuffer()},
a_n_c_wis_lengths, out_device_buf.GetDeviceBuffer(),
a_n_c_wis_strides, a_n_c_wis_lengths,
b_k_c_xs_lengths, a_n_c_wis_strides,
b_k_c_xs_strides, b_k_c_xs_lengths,
std::array<std::array<ck::index_t, NDimSpatial + 2>, 0>{{}}, b_k_c_xs_strides,
std::array<std::array<ck::index_t, NDimSpatial + 2>, 0>{{}}, std::array<std::array<ck::index_t, NDimSpatial + 2>, 1>{{d_n_k_wos_lengths}},
e_n_k_wos_lengths, std::array<std::array<ck::index_t, NDimSpatial + 2>, 1>{{d_n_k_wos_strides}},
e_n_k_wos_strides, e_n_k_wos_lengths,
conv_filter_strides, e_n_k_wos_strides,
conv_filter_dilations, conv_filter_strides,
input_left_pads, conv_filter_dilations,
input_right_pads, input_left_pads,
in_element_op, input_right_pads,
wei_element_op, in_element_op,
out_element_op); wei_element_op,
out_element_op);
if(!conv.IsSupportedArgument(argument)) if(!conv.IsSupportedArgument(argument))
{ {
...@@ -264,6 +293,10 @@ int run_conv_fwd(bool do_verification, ...@@ -264,6 +293,10 @@ int run_conv_fwd(bool do_verification,
if(do_verification) if(do_verification)
{ {
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
Tensor<OutDataType> c_host(out_desc);
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial, auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -273,26 +306,41 @@ int run_conv_fwd(bool do_verification, ...@@ -273,26 +306,41 @@ int run_conv_fwd(bool do_verification,
OutDataType, OutDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp>(); PassThrough>();
auto ref_invoker = ref_conv.MakeInvoker(); auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in, auto ref_argument = ref_conv.MakeArgument(in,
wei, wei,
out_host, c_host,
conv_param.conv_filter_strides_, conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_, conv_param.conv_filter_dilations_,
conv_param.input_left_pads_, conv_param.input_left_pads_,
conv_param.input_right_pads_, conv_param.input_right_pads_,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op); PassThrough{});
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
for(int n = 0; n < out_host.mDesc.GetLengths()[0]; n++)
{
for(int ho = 0; ho < out_host.mDesc.GetLengths()[1]; ho++)
{
for(int wo = 0; wo < out_host.mDesc.GetLengths()[2]; wo++)
{
for(int k = 0; k < out_host.mDesc.GetLengths()[3]; k++)
{
out_element_op(
out_host(n, ho, wo, k), c_host(n, ho, wo, k), bias(n, ho, wo, k));
}
}
}
}
out_device_buf.FromDevice(out_device.mData.data()); out_device_buf.FromDevice(out_device.mData.data());
return ck::utils::check_err( return ck::utils::check_err(
out_host.mData, out_device.mData, "Error: incorrect results!", 1e-5f, 1e-4f) out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f)
? 0 ? 0
: 1; : 1;
} }
......
...@@ -16,7 +16,8 @@ using S = ck::Sequence<Is...>; ...@@ -16,7 +16,8 @@ using S = ck::Sequence<Is...>;
using InElementOp = ck::tensor_operation::element_wise::PassThrough; using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::UnaryConvert; // using OutElementOp = ck::tensor_operation::element_wise::UnaryConvert;
using OutElementOp = ck::tensor_operation::element_wise::AddRelu;
#if 0 #if 0
static constexpr auto ConvFwdDefault = static constexpr auto ConvFwdDefault =
...@@ -60,6 +61,7 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc ...@@ -60,6 +61,7 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc
1>; // CThreadTransferDstScalarPerVector 1>; // CThreadTransferDstScalarPerVector
#else #else
using CShuffleDataType = ck::half_t; using CShuffleDataType = ck::half_t;
using DDataType = ck::half_t;
static constexpr auto ConvSpec = static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
...@@ -77,7 +79,10 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvFwdMulti ...@@ -77,7 +79,10 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvFwdMulti
ck::Tuple<ck::tensor_layout::convolution::KXC, ck::Tuple<ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::KYXC, ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::KZYXC>>, ck::tensor_layout::convolution::KZYXC>>,
ck::Tuple<>, ck::Tuple<ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NW_K,
ck::tensor_layout::convolution::NHW_K,
ck::tensor_layout::convolution::NDHW_K>>>,
ck::tuple_element_t<NDimSpatial - 1, ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWK, ck::Tuple<ck::tensor_layout::convolution::NWK,
ck::tensor_layout::convolution::NHWK, ck::tensor_layout::convolution::NHWK,
...@@ -86,7 +91,7 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvFwdMulti ...@@ -86,7 +91,7 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvFwdMulti
WeiDataType, WeiDataType,
AccDataType, AccDataType,
CShuffleDataType, CShuffleDataType,
ck::Tuple<>, ck::Tuple<DDataType>,
OutDataType, OutDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
......
...@@ -565,10 +565,6 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -565,10 +565,6 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
} }
} }
// supported layout:
// KXC, K_XC
// KYXC, K_YXC
// KZYXC, K_ZYXC
template <typename BLay, template <typename BLay,
typename std::enable_if<is_same_v<BLay, tensor_layout::convolution::KXC> || typename std::enable_if<is_same_v<BLay, tensor_layout::convolution::KXC> ||
is_same_v<BLay, tensor_layout::convolution::KYXC> || is_same_v<BLay, tensor_layout::convolution::KYXC> ||
...@@ -625,10 +621,57 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -625,10 +621,57 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return out_gemmm_gemmn_grid_desc; return out_gemmm_gemmn_grid_desc;
} }
using AGridDesc_M_K = remove_cvref_t<decltype( template <typename ELay,
typename std::enable_if<is_same_v<ELay, tensor_layout::convolution::NW_K> ||
is_same_v<ELay, tensor_layout::convolution::NHW_K> ||
is_same_v<ELay, tensor_layout::convolution::NDHW_K>,
bool>::type = false>
static auto
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_strides)
{
// Overload selected only for the NW_K / NHW_K / NDHW_K layouts. It builds a 2-D
// [GemmM, GemmN] descriptor from lengths/strides given in (N, K, spatial...) order and
// returns it after padding through matrix_padder.
// NOTE(review): the `ctc` alias is not referenced anywhere in the visible body.
namespace ctc = ck::tensor_layout::convolution;
const index_t N = e_n_k_wos_lengths[0];
const index_t K = e_n_k_wos_lengths[1];
// Stride of the last entry of the stride array (the innermost spatial dimension in
// n/k/wos order); it is used below as the stride of the whole flattened GemmM axis.
// NOTE(review): presumably relies on N and all spatial dims being describable by this
// single stride (e.g. a broadcast tensor) - confirm against callers.
const index_t WoStride = e_n_k_wos_strides[NDimSpatial + 1];
// GemmM = N * product of all NDimSpatial spatial lengths.
const index_t GemmMRaw = N * std::accumulate(e_n_k_wos_lengths.begin() + 2,
e_n_k_wos_lengths.begin() + 2 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
// GemmN = K, addressed with unit stride (I1) in the descriptor below.
const index_t GemmNRaw = K;
const auto out_gemmmraw_gemmnraw_grid_desc =
make_naive_tensor_descriptor(make_tuple(GemmMRaw, GemmNRaw), make_tuple(WoStride, I1));
// Apply the device op's M/N padding to the raw descriptor.
const auto out_gemmm_gemmn_grid_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_grid_desc);
return out_gemmm_gemmn_grid_desc;
}
// Builds a tuple holding one [M, N] grid descriptor per D tensor.
// For D tensor number d, the layout is taken from DsLayout and the descriptor is produced by
// MakeEGridDescriptor_M_N from that tensor's lengths and strides.
static auto MakeDsGridDescriptor_M_N(
    const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_lengths,
    const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_strides)
{
    // Descriptor factory for a single D tensor; d_idx is a compile-time index.
    const auto make_one_desc = [&](auto d_idx) {
        using DLayout = remove_cvref_t<tuple_element_t<d_idx.value, DsLayout>>;

        const auto& d_lengths = ds_n_k_wos_lengths[d_idx];
        const auto& d_strides = ds_n_k_wos_strides[d_idx];

        return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(d_lengths, d_strides);
    };

    return generate_tuple(make_one_desc, Number<NumDTensor>{});
}
using AGridDesc_M_K = remove_cvref_t<decltype(
MakeAGridDescriptor_M_K<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; MakeAGridDescriptor_M_K<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>; using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>; using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
...@@ -643,7 +686,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -643,7 +686,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_M_K, AGridDesc_M_K,
BGridDesc_N_K, BGridDesc_N_K,
StaticallyIndexedArray<EGridDesc_M_N, NumDTensor>, DsGridDesc_M_N,
EGridDesc_M_N, EGridDesc_M_N,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
...@@ -762,6 +805,18 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -762,6 +805,18 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
block_2_etile_map_ = Block2ETileMap{e_grid_desc_m_n_}; block_2_etile_map_ = Block2ETileMap{e_grid_desc_m_n_};
// populate pointer and desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
ds_n_k_wos_lengths[i], ds_n_k_wos_strides[i]);
});
// populate desc for Ds/E
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_, b_grid_desc_n_k_,
ds_grid_desc_m_n_, ds_grid_desc_m_n_,
...@@ -772,22 +827,21 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -772,22 +827,21 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_); e_grid_desc_m_n_);
// populate pointer and desc for Ds ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
static_for<0, NumDTensor, 1>{}([&](auto i) { GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>; ds_grid_desc_m_n_);
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
ds_grid_desc_m_n_[i] = DeviceOp::MakeEGridDescriptor_M_N<ELayout>(
ds_n_k_wos_lengths[i], ds_n_k_wos_strides[i]);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_[i]);
});
} }
} }
// Writes the A, B, each D, and E [M, N]/[M, K]/[N, K] grid descriptors stored in this object
// to std::cout, one per line, in that order.
void Print() const
{
    std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
    std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;

    // One line per D tensor, visited in index order.
    const auto print_d_desc = [&](auto d_idx) {
        std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[d_idx] << std::endl;
    };
    static_for<0, NumDTensor, 1>{}(print_d_desc);

    std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
}
// private: // private:
// pointers // pointers
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
...@@ -798,14 +852,12 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -798,14 +852,12 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
// tensor descriptors // tensor descriptors
AGridDesc_M_K a_grid_desc_m_k_; AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_; BGridDesc_N_K b_grid_desc_n_k_;
StaticallyIndexedArray<EGridDesc_M_N, NumDTensor> ds_grid_desc_m_n_; DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
StaticallyIndexedArray< typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>
ds_grid_desc_mblock_mperblock_nblock_nperblock_; ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_; e_grid_desc_mblock_mperblock_nblock_nperblock_;
...@@ -841,11 +893,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -841,11 +893,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if 1 #if 1
{ arg.Print();
std::cout << "A[M, K]: " << arg.a_grid_desc_m_k_ << std::endl;
std::cout << "B[N, K]: " << arg.b_grid_desc_n_k_ << std::endl;
std::cout << "E[M, N]: " << arg.e_grid_desc_m_n_ << std::endl;
}
#endif #endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_, arg.b_grid_desc_n_k_,
...@@ -876,9 +924,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -876,9 +924,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
CDEElementwiseOperation, CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
StaticallyIndexedArray< typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
Block2ETileMap, Block2ETileMap,
has_main_loop>; has_main_loop>;
...@@ -921,8 +967,15 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -921,8 +967,15 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
#if 1
arg.Print();
#endif
namespace ctc = tensor_layout::convolution; namespace ctc = tensor_layout::convolution;
int itmp = 0;
printf("itmp %d\n", itmp++);
// check device // check device
if(get_device_name() == "gfx908") if(get_device_name() == "gfx908")
{ {
...@@ -945,6 +998,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -945,6 +998,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return false; return false;
} }
printf("itmp %d\n", itmp++);
// check ConvolutionForwardSpecialization // check ConvolutionForwardSpecialization
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
...@@ -980,6 +1035,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -980,6 +1035,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
} }
} }
printf("itmp %d\n", itmp++);
// check vector access of A // check vector access of A
if constexpr(is_same_v<ALayout, ctc::NWC> || is_same_v<ALayout, ctc::NHWC> || if constexpr(is_same_v<ALayout, ctc::NWC> || is_same_v<ALayout, ctc::NHWC> ||
is_same_v<ALayout, ctc::NDHWC>) is_same_v<ALayout, ctc::NDHWC>)
...@@ -996,6 +1053,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -996,6 +1053,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return false; return false;
} }
printf("itmp %d\n", itmp++);
// check vector access of B // check vector access of B
if constexpr(is_same_v<BLayout, ctc::KXC> || is_same_v<BLayout, ctc::KYXC> || if constexpr(is_same_v<BLayout, ctc::KXC> || is_same_v<BLayout, ctc::KYXC> ||
is_same_v<BLayout, ctc::KZYXC>) is_same_v<BLayout, ctc::KZYXC>)
...@@ -1012,7 +1071,37 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -1012,7 +1071,37 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return false; return false;
} }
// FIXME: check vector access of Ds printf("itmp %d\n", itmp++);
// check vector access of Ds
bool valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
if constexpr(is_same_v<DLayout, ctc::NWK> || is_same_v<DLayout, ctc::NHWK> ||
is_same_v<DLayout, ctc::NDHWK> || is_same_v<DLayout, ctc::NW_K> ||
is_same_v<DLayout, ctc::NHW_K> || is_same_v<DLayout, ctc::NDHW_K>)
{
const index_t K = arg.ds_n_k_wos_lengths_[i][1];
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{
valid = false;
}
}
else
{
valid = false;
}
});
if(!valid)
{
return false;
}
printf("itmp %d\n", itmp++);
// check vector access of E // check vector access of E
if constexpr(is_same_v<ELayout, ctc::NWK> || is_same_v<ELayout, ctc::NHWK> || if constexpr(is_same_v<ELayout, ctc::NWK> || is_same_v<ELayout, ctc::NHWK> ||
...@@ -1030,6 +1119,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -1030,6 +1119,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return false; return false;
} }
printf("itmp %d\n", itmp++);
// check Gridwise GEMM // check Gridwise GEMM
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_, arg.b_grid_desc_n_k_,
......
...@@ -25,41 +25,45 @@ struct ColumnMajor : public BaseTensorLayout ...@@ -25,41 +25,45 @@ struct ColumnMajor : public BaseTensorLayout
namespace convolution { namespace convolution {
// 1D Conv // input tensor
// packed NWC/NHWC/NDHWC
struct NWC : public BaseTensorLayout struct NWC : public BaseTensorLayout
{ {
static constexpr const char* name = "NWC"; static constexpr const char* name = "NWC";
}; };
struct KXC : public BaseTensorLayout struct NHWC : public BaseTensorLayout
{ {
static constexpr const char* name = "KXC"; static constexpr const char* name = "NHWC";
}; };
struct NWK : public BaseTensorLayout struct NDHWC : public BaseTensorLayout
{ {
static constexpr const char* name = "NWK"; static constexpr const char* name = "NDHWC";
}; };
// input tensor
// packed NCW/NCHW/NCDHW
struct NCW : public BaseTensorLayout struct NCW : public BaseTensorLayout
{ {
static constexpr const char* name = "NCW"; static constexpr const char* name = "NCW";
}; };
struct KCX : public BaseTensorLayout struct NCHW : public BaseTensorLayout
{ {
static constexpr const char* name = "KCX"; static constexpr const char* name = "NCHW";
}; };
struct NKW : public BaseTensorLayout struct NCDHW : public BaseTensorLayout
{ {
static constexpr const char* name = "NKW"; static constexpr const char* name = "NCDHW";
}; };
// 2D Conv // weight tensor
struct NHWC : public BaseTensorLayout // packed KXC/KYXC/KZYXC
struct KXC : public BaseTensorLayout
{ {
static constexpr const char* name = "NHWC"; static constexpr const char* name = "KXC";
}; };
struct KYXC : public BaseTensorLayout struct KYXC : public BaseTensorLayout
...@@ -67,14 +71,16 @@ struct KYXC : public BaseTensorLayout ...@@ -67,14 +71,16 @@ struct KYXC : public BaseTensorLayout
static constexpr const char* name = "KYXC"; static constexpr const char* name = "KYXC";
}; };
struct NHWK : public BaseTensorLayout struct KZYXC : public BaseTensorLayout
{ {
static constexpr const char* name = "NHWK"; static constexpr const char* name = "KZYXC";
}; };
struct NCHW : public BaseTensorLayout // weight tensor
// packed KCX/KCYX/KCZYX
struct KCX : public BaseTensorLayout
{ {
static constexpr const char* name = "NCHW"; static constexpr const char* name = "KCX";
}; };
struct KCYX : public BaseTensorLayout struct KCYX : public BaseTensorLayout
...@@ -82,34 +88,38 @@ struct KCYX : public BaseTensorLayout ...@@ -82,34 +88,38 @@ struct KCYX : public BaseTensorLayout
static constexpr const char* name = "KCYX"; static constexpr const char* name = "KCYX";
}; };
struct NKHW : public BaseTensorLayout struct KCZYX : public BaseTensorLayout
{ {
static constexpr const char* name = "NKHW"; static constexpr const char* name = "KCZYX";
}; };
// 3D Conv // output tensor
struct NDHWC : public BaseTensorLayout // packed NWK/NHWK/NDHWK
struct NWK : public BaseTensorLayout
{ {
static constexpr const char* name = "NDHWC"; static constexpr const char* name = "NWK";
}; };
struct KZYXC : public BaseTensorLayout struct NHWK : public BaseTensorLayout
{ {
static constexpr const char* name = "KZYXC"; static constexpr const char* name = "NHWK";
}; };
struct NDHWK : public BaseTensorLayout struct NDHWK : public BaseTensorLayout
{ {
static constexpr const char* name = "NDHWK"; static constexpr const char* name = "NDHWK";
}; };
struct NCDHW : public BaseTensorLayout
// output tensor
// packed NKW/NKHW/NKDHW
struct NKW : public BaseTensorLayout
{ {
static constexpr const char* name = "NCDHW"; static constexpr const char* name = "NKW";
}; };
struct KCZYX : public BaseTensorLayout struct NKHW : public BaseTensorLayout
{ {
static constexpr const char* name = "KCZYX"; static constexpr const char* name = "NKHW";
}; };
struct NKDHW : public BaseTensorLayout struct NKDHW : public BaseTensorLayout
...@@ -117,6 +127,23 @@ struct NKDHW : public BaseTensorLayout ...@@ -117,6 +127,23 @@ struct NKDHW : public BaseTensorLayout
static constexpr const char* name = "NKDHW"; static constexpr const char* name = "NKDHW";
}; };
// output tensor
// strided layout
// Tag type for the strided "NW_K" output-tensor layout; `name` is its printable identifier.
struct NW_K : BaseTensorLayout
{
    static constexpr const char* name = "NW_K";
};
// Tag type for the strided "NHW_K" output-tensor layout; `name` is its printable identifier.
struct NHW_K : BaseTensorLayout
{
    static constexpr const char* name = "NHW_K";
};
// Tag type for the strided "NDHW_K" output-tensor layout; `name` is its printable identifier.
struct NDHW_K : BaseTensorLayout
{
    static constexpr const char* name = "NDHW_K";
};
} // namespace convolution } // namespace convolution
template < template <
......
...@@ -165,6 +165,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -165,6 +165,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
c_block_size * sizeof(CShuffleDataType)); c_block_size * sizeof(CShuffleDataType));
} }
// A desc for source in blockwise copy
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
{ {
...@@ -180,6 +181,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -180,6 +181,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
} }
// B desc for source in blockwise copy
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
{ {
...@@ -195,8 +197,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -195,8 +197,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
} }
__host__ __device__ static constexpr auto // E desc for destination in blockwise copy
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) template <typename EGridDescriptor_M_N>
__host__ __device__ static constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const EGridDescriptor_M_N& e_grid_desc_m_n)
{ {
const auto M = e_grid_desc_m_n.GetLength(I0); const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1); const auto N = e_grid_desc_m_n.GetLength(I1);
...@@ -214,6 +218,19 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -214,6 +218,19 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
return e_grid_desc_mblock_mperblock_nblock_nperblock; return e_grid_desc_mblock_mperblock_nblock_nperblock;
} }
// Ds descriptors used as the source side of the blockwise copy.
//
// For every index i in [0, NumDTensor) this applies
// MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock to ds_grid_desc_m_n[i]
// and gathers the per-tensor results through generate_tuple.
//
// ds_grid_desc_m_n: indexable collection of per-D-tensor (M, N) descriptors.
template <typename DsGridDescriptor_M_N>
__host__ __device__ static constexpr auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDescriptor_M_N& ds_grid_desc_m_n)
{
    // Each D tensor is converted with the same transform that is used for E.
    const auto make_single_d_desc = [&](auto d_idx) {
        return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[d_idx]);
    };

    return generate_tuple(make_single_d_desc, Number<NumDTensor>{});
}
// return block_id to E matrix tile idx (m0, n0) mapping // return block_id to E matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
...@@ -301,8 +318,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -301,8 +318,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>; remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
// Descriptor / tile-map types, each deduced from the corresponding Make*
// helper applied to a default-constructed grid descriptor.
// (Reconstructed: the side-by-side diff rendering had repeated every
// unchanged line twice, which made these declarations ill-formed.)

// Type returned by MakeDefaultBGridDescriptor_BK0_N_BK1 for BGridDesc_N_K.
using DefaultBGridDesc_BK0_N_BK1 =
    remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;

// Type returned by MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock for EGridDesc_M_N.
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
    MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;

// Type returned by MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock for DsGridDesc_M_N.
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
    MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;

// Type returned by MakeDefaultBlock2ETileMap for EGridDesc_M_N.
using DefaultBlock2ETileMap =
    remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
...@@ -313,24 +332,21 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -313,24 +332,21 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename Block2ETileMap> typename Block2ETileMap>
__device__ static void __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
Run(const ABDataType* __restrict__ p_a_grid, const ABDataType* __restrict__ p_b_grid,
const ABDataType* __restrict__ p_b_grid, DsGridPointer p_ds_grid,
DsGridPointer p_ds_grid, EDataType* __restrict__ p_e_grid,
EDataType* __restrict__ p_e_grid, void* __restrict__ p_shared,
void* __restrict__ p_shared, const AElementwiseOperation& a_element_op,
const AElementwiseOperation& a_element_op, const BElementwiseOperation& b_element_op,
const BElementwiseOperation& b_element_op, const CDEElementwiseOperation& cde_element_op,
const CDEElementwiseOperation& cde_element_op, const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
const StaticallyIndexedArray<EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
NumDTensor>& const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock, // FIXME: Ds desc may be of different e_grid_desc_mblock_mperblock_nblock_nperblock,
// type from E const Block2ETileMap& block_2_etile_map)
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap& block_2_etile_map)
{ {
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment