Commit 8a8dca0a authored by Bartlomiej Kocot
Browse files

Fixes for examples

parent 83360328
......@@ -32,17 +32,14 @@ struct SimpleDeviceMem
};
template <ck::index_t NumDimSpatial>
std::size_t GetFlops(ck::index_t G,
ck::index_t N,
ck::index_t K,
ck::index_t C,
const std::array<ck::index_t, NumDimSpatial>& output_lengths,
std::size_t GetFlops(const std::array<ck::index_t, NumDimSpatial>& output_lengths,
const std::array<ck::index_t, NumDimSpatial>& filter_lengths)
{
constexpr index_t spatial_offset = 3;
constexpr ck::index_t spatial_offset = 3;
const auto C = filter_lengths[2];
// 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
return static_cast<std::size_t>(2) * G * N * K * C *
std::accumulate(std::begin(output_lengths) + spatial_offset,
return static_cast<std::size_t>(2) * C *
std::accumulate(std::begin(output_lengths),
std::end(output_lengths),
static_cast<std::size_t>(1),
std::multiplies<>()) *
......@@ -53,45 +50,30 @@ std::size_t GetFlops(ck::index_t G,
}
template <typename InDataType, ck::index_t NumDimSpatial>
std::size_t GetInputByte(ck::index_t G,
ck::index_t N,
ck::index_t C,
const std::array<ck::index_t, NumDimSpatial>& input_lengths)
std::size_t GetInputByte(const std::array<ck::index_t, NumDimSpatial>& input_lengths)
{
constexpr index_t spatial_offset = 3;
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
return sizeof(InDataType) * (G * N * C *
std::accumulate(std::begin(input_lengths) + spatial_offset,
return sizeof(InDataType) * (std::accumulate(std::begin(input_lengths),
std::end(input_lengths),
static_cast<std::size_t>(1),
std::multiplies<>()));
}
template <typename WeiDataType, ck::index_t NumDimSpatial>
std::size_t GetWeightByte(ck::index_t G,
ck::index_t K,
ck::index_t C,
const std::array<ck::index_t, NumDimSpatial>& filter_lengths)
std::size_t GetWeightByte(const std::array<ck::index_t, NumDimSpatial>& filter_lengths)
{
constexpr index_t spatial_offset = 3;
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
return sizeof(WeiDataType) * (G * K * C *
std::accumulate(std::begin(filter_lengths) + spatial_offset,
return sizeof(WeiDataType) * (std::accumulate(std::begin(filter_lengths),
std::end(filter_lengths),
static_cast<std::size_t>(1),
std::multiplies<>()));
}
template <typename OutDataType, ck::index_t NumDimSpatial>
std::size_t GetOutputByte(ck::index_t G,
ck::index_t N,
ck::index_t K,
const std::array<ck::index_t, NumDimSpatial>& output_lengths)
std::size_t GetOutputByte(const std::array<ck::index_t, NumDimSpatial>& output_lengths)
{
constexpr index_t spatial_offset = 3;
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return sizeof(OutDataType) * (G * N * K *
std::accumulate(std::begin(output_lengths) + spatial_offset,
return sizeof(OutDataType) * (std::accumulate(std::begin(output_lengths),
std::end(output_lengths),
static_cast<std::size_t>(1),
std::multiplies<std::size_t>()));
......@@ -105,15 +87,11 @@ template <ck::index_t NumDimSpatial,
typename WeiLayout,
typename OutLayout>
bool run_grouped_conv_bwd_weight(
const ck::index_t G,
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NumDimSpatial + 3>& input_lengths,
const std::array<ck::index_t, NumDimSpatial + 3>& filter_lengths,
const std::array<ck::index_t, NumDimSpatial + 3>& output_lengths,
const std::array<ck::index_t, NumDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NumDimSpatial + 3>& filter_lengths,
const std::array<ck::index_t, NumDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NumDimSpatial + 3>& output_lengths,
const std::array<ck::index_t, NumDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NumDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NumDimSpatial>& conv_filter_dilations,
......@@ -122,9 +100,9 @@ bool run_grouped_conv_bwd_weight(
{
ck::index_t split_k = 2;
SimpleDeviceMem in(GetInputByte<InDataType, NumDimSpatial>(input_lengths));
SimpleDeviceMem wei(GetWeightByte<WeiDataType, NumDimSpatial>(filter_lengths));
SimpleDeviceMem out(GetOutputByte<OutDataType, NumDimSpatial>(output_lengths));
SimpleDeviceMem in(GetInputByte<InDataType, NumDimSpatial + 3>(input_lengths));
SimpleDeviceMem wei(GetWeightByte<WeiDataType, NumDimSpatial + 3>(filter_lengths));
SimpleDeviceMem out(GetOutputByte<OutDataType, NumDimSpatial + 3>(output_lengths));
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<NumDimSpatial,
InLayout,
......@@ -148,9 +126,9 @@ bool run_grouped_conv_bwd_weight(
float best_gb_per_sec = 0;
float best_tflops = 0;
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
std::array<ck::index_t, NumDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<ck::index_t, NumDimSpatial + 3> a_g_n_c_wis_strides{};
std::array<ck::index_t, NumDimSpatial + 3> b_g_k_c_xs_lengths{};
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
......@@ -182,11 +160,10 @@ bool run_grouped_conv_bwd_weight(
{
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = GetFlops<NumDimSpatial>(G, N, K, C, output_lengths, filter_lengths);
std::size_t num_bytes =
GetInputByte<InDataType, NumDimSpatial>(G, N, C, input_lengths) +
GetWeightByte<WeiDataType, NumDimSpatial>(G, K, C, filter_lengths) +
GetOutputByte<OutDataType, NumDimSpatial>(G, N, K, output_lengths);
std::size_t flop = GetFlops<NumDimSpatial + 3>(output_lengths, filter_lengths);
std::size_t num_bytes = GetInputByte<InDataType, NumDimSpatial + 3>(input_lengths) +
GetWeightByte<WeiDataType, NumDimSpatial + 3>(filter_lengths) +
GetOutputByte<OutDataType, NumDimSpatial + 3>(output_lengths);
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_bytes / 1.E6 / avg_time;
......
......@@ -22,9 +22,9 @@ static constexpr ck::index_t C = 192;
static constexpr ck::index_t X = 3;
static constexpr ck::index_t Wi = 28;
static constexpr ck::index_t Wo = 28;
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Wi};
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_lengths{G, N, C, Wi};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> filter_lengths{G, K, C, X};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_lengths{G, N, K, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{N * Wi * C, Wi* C, 1, C};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{K * X * C, X* C, 1, C};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{N * Wo * K, Wo* K, 1, K};
......@@ -41,15 +41,11 @@ int main()
OutDataType,
InLayout,
WeiLayout,
OutLayout>(G,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
OutLayout>(input_lengths,
input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides,
conv_filter_strides,
conv_filter_dilations,
......
......@@ -25,9 +25,9 @@ static constexpr ck::index_t Hi = 28;
static constexpr ck::index_t Wi = 28;
static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 28;
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Hi, Wi};
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Y, X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_lengths{G, N, C, Hi, Wi};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> filter_lengths{G, K, C, Y, X};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_lengths{G, N, K, Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
N * Hi * Wi * C, Hi* Wi* C, 1, Wi* C, C};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{
......@@ -47,15 +47,11 @@ int main()
OutDataType,
InLayout,
WeiLayout,
OutLayout>(G,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
OutLayout>(input_lengths,
input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides,
conv_filter_strides,
conv_filter_dilations,
......
......@@ -28,9 +28,9 @@ static constexpr ck::index_t Wi = 3;
static constexpr ck::index_t Do = 28;
static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 3;
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Di, Hi, Wi};
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z, Y, X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_lengths{G, N, C, Di, Hi, Wi};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> filter_lengths{G, K, C, Z, Y, X};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_lengths{G, N, K, Do, Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
N * Di * Hi * Wi * C, Di* Hi* Wi* C, 1, Hi* Wi* C, Wi* C, C};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{
......@@ -50,15 +50,11 @@ int main()
OutDataType,
InLayout,
WeiLayout,
OutLayout>(G,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
OutLayout>(input_lengths,
input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides,
conv_filter_strides,
conv_filter_dilations,
......
......@@ -28,9 +28,9 @@ static constexpr ck::index_t Wi = 3;
static constexpr ck::index_t Do = 28;
static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 3;
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Di, Hi, Wi};
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z, Y, X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_lengths{G, N, C, Di, Hi, Wi};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> filter_lengths{G, K, C, Z, Y, X};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_lengths{G, N, K, Do, Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
N * Di * Hi * Wi * C, Di* Hi* Wi* C, 1, Hi* Wi* C, Wi* C, C};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{
......@@ -50,15 +50,11 @@ int main()
OutDataType,
InLayout,
WeiLayout,
OutLayout>(G,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
OutLayout>(input_lengths,
input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides,
conv_filter_strides,
conv_filter_dilations,
......
......@@ -865,20 +865,20 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
// A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ =
Conv_N_ * Conv_K_ *
std::accumulate(begin(output_spatial_lengths),
end(output_spatial_lengths),
std::accumulate(begin(output_spatial_lengths_),
end(output_spatial_lengths_),
index_t{1},
std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideB_ =
Conv_N_ * Conv_C_ *
std::accumulate(begin(input_spatial_lengths),
end(input_spatial_lengths),
std::accumulate(begin(input_spatial_lengths_),
end(input_spatial_lengths_),
index_t{1},
std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideC_ =
Conv_K_ * Conv_C_ *
std::accumulate(begin(filter_spatial_lengths),
end(filter_spatial_lengths),
std::accumulate(begin(filter_spatial_lengths_),
end(filter_spatial_lengths_),
index_t{1},
std::multiplies<>{});
}
......
......@@ -70,9 +70,9 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
conv_param);
std::array<ck::index_t, NDimSpatial> input_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
std::array<ck::index_t, NDimSpatial + 3> input_lengths{};
std::array<ck::index_t, NDimSpatial + 3> filter_lengths{};
std::array<ck::index_t, NDimSpatial + 3> output_lengths{};
std::array<ck::index_t, NDimSpatial + 3> input_strides{};
std::array<ck::index_t, NDimSpatial + 3> weights_strides{};
std::array<ck::index_t, NDimSpatial + 3> output_strides{};
......@@ -83,11 +83,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); };
range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths));
range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths));
range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths));
range_copy(in_g_n_c_wis_desc.GetLengths(), begin(input_lengths));
range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides));
range_copy(wei_g_k_c_xs_desc.GetLengths(), begin(filter_lengths));
range_copy(wei_g_k_c_xs_desc.GetStrides(), begin(weights_strides));
range_copy(out_g_n_k_wos_desc.GetLengths(), begin(output_lengths));
range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides));
range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
......@@ -99,15 +99,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
auto argument = conv.MakeArgument(nullptr,
nullptr,
nullptr,
conv_param.G_,
conv_param.N_,
conv_param.K_,
conv_param.C_,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_lengths,
input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides,
conv_filter_strides,
conv_filter_dilations,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment