Commit e28aa8dd authored by Bartlomiej Kocot

Comment fixes

parent dbfe0051
@@ -63,7 +63,7 @@ int execute_conv_fwd_scaleadd_scaleadd_relu()
         K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C};
     std::array<ck::index_t, 6> out_lengths{G, N, K, Do, Ho, Wo};
     std::array<ck::index_t, 6> out_strides{
-        C, Do * Ho * Wo * G * C, 1, Ho * Wo * G * C, Wo * G * C, G * C};
+        K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K};
     std::array<ck::index_t, NumDimSpatial> filter_strides{1, 1, 1};
     std::array<ck::index_t, NumDimSpatial> filter_dilations{1, 1, 1};
...
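Note on the stride fix: the corrected out_strides describe a packed NDHWGK layout of the logical (G, N, K, Do, Ho, Wo) output tensor, so every stride is built from K (output channels), not C (input channels); the same correction is applied in the next example file below. A minimal standalone sketch of how those six strides fall out of the packed dimension order (the helper name and free-function form are illustrative, not part of this commit):

#include <array>
#include <cstdint>

// Illustrative helper: strides of a logical (G, N, K, Do, Ho, Wo) tensor stored
// packed in N, Do, Ho, Wo, G, K memory order (NDHWGK).
std::array<std::int64_t, 6> make_ndhwgk_out_strides(
    std::int64_t G, std::int64_t K, std::int64_t Do, std::int64_t Ho, std::int64_t Wo)
{
    return {K,                    // G : neighbouring groups are K channels apart
            Do * Ho * Wo * G * K, // N : one whole image per batch entry
            1,                    // K : output channels are contiguous
            Ho * Wo * G * K,      // Do
            Wo * G * K,           // Ho
            G * K};               // Wo
}

With K innermost, element (g, n, k, d, h, w) sits at offset n*Do*Ho*Wo*G*K + d*Ho*Wo*G*K + h*Wo*G*K + w*G*K + g*K + k, which is exactly what the six corrected strides encode.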
@@ -68,7 +68,7 @@ int execute_conv_fwd_scaleadd_ab()
         K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C};
     std::array<ck::index_t, 6> out_lengths{G, N, K, Do, Ho, Wo};
     std::array<ck::index_t, 6> out_strides{
-        C, Do * Ho * Wo * G * C, 1, Ho * Wo * G * C, Wo * G * C, G * C};
+        K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K};
     std::array<ck::index_t, NumDimSpatial> filter_strides{1, 1, 1};
     std::array<ck::index_t, NumDimSpatial> filter_dilations{1, 1, 1};
@@ -147,11 +147,11 @@ int execute_conv_fwd_scaleadd_ab()
     {
         float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
 
-        std::size_t flop = std::size_t(2) * G * N * K * C * Ho * Wo * Y * X +
-                           N * Hi * Wi * G * C + G * K * Y * X * C;
-        std::size_t num_bytes = 2 * sizeof(InDataType) * N * Hi * Wi * G * C +
-                                2 * sizeof(WeiDataType) * G * K * Y * X * C +
-                                sizeof(OutDataType) * N * Ho * Wo * G * K;
+        std::size_t flop = std::size_t(2) * G * N * K * C * Do * Ho * Wo * Z * Y * X +
+                           N * Di * Hi * Wi * G * C + G * K * Z * Y * X * C;
+        std::size_t num_bytes = 2 * sizeof(InDataType) * N * Di * Hi * Wi * G * C +
+                                2 * sizeof(WeiDataType) * G * K * Z * Y * X * C +
+                                sizeof(OutDataType) * N * Do * Ho * Wo * G * K;
 
         float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
         float gb_per_sec = num_bytes / 1.E6 / avg_time;
...
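Note on the performance counters: the corrected formulas account for the third spatial dimension, counting C * Z * Y * X multiply-accumulates (2 FLOPs each) per point of the G * N * K * Do * Ho * Wo output, plus elementwise terms sized like the input and weight tensors for the ScaleAdd on the A and B operands. A rough standalone sanity check of the same arithmetic (the sizes and the 2-byte element width are assumptions for illustration, not values from this commit):

#include <cstddef>
#include <cstdio>

int main()
{
    // Hypothetical 3-D grouped-conv sizes, chosen only to exercise the formulas.
    const std::size_t G = 2, N = 4, K = 32, C = 16;
    const std::size_t Di = 8, Hi = 16, Wi = 16; // input spatial sizes
    const std::size_t Z = 3, Y = 3, X = 3;      // filter spatial sizes
    const std::size_t Do = 6, Ho = 14, Wo = 14; // output sizes for stride 1, no padding

    // GEMM part: 2 FLOPs per multiply-accumulate; extra terms: ScaleAdd on A and B.
    const std::size_t flop = std::size_t(2) * G * N * K * C * Do * Ho * Wo * Z * Y * X +
                             N * Di * Hi * Wi * G * C + G * K * Z * Y * X * C;

    // Input and weight tensors are each read twice (tensor plus its ScaleAdd source),
    // the output is written once; 2 bytes per element is assumed here.
    const std::size_t num_bytes = 2 * 2 * N * Di * Hi * Wi * G * C +
                                  2 * 2 * G * K * Z * Y * X * C +
                                  2 * N * Do * Ho * Wo * G * K;

    std::printf("flop = %zu, bytes = %zu\n", flop, num_bytes);
    return 0;
}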
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
 
+#include <tuple>
 #include "ck/utility/data_type.hpp"
 #include "ck/utility/tuple.hpp"
...
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
 
+#include <tuple>
 #include "ck/utility/data_type.hpp"
 #include "ck/utility/tuple.hpp"
...
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
 
+#include <tuple>
 #include "ck/utility/data_type.hpp"
 #include "ck/utility/tuple.hpp"
...
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
 
+#include <tuple>
 #include "ck/utility/data_type.hpp"
 #include "ck/utility/tuple.hpp"
...
@@ -9,101 +9,99 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
-template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
+template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void>
 struct ComputePtrOffsetOfStridedBatch
 {
-    static constexpr bool isMultiAB = NumATensor > 1 || NumBTensor > 1;
+};
 
+template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
+struct ComputePtrOffsetOfStridedBatch<NumATensor,
+                                      NumBTensor,
+                                      NumDTensor,
+                                      ck::enable_if_t<(NumATensor > 1 || NumBTensor > 1)>>
+{
     ComputePtrOffsetOfStridedBatch() = default;
 
-    ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
-                                   index_t BatchStrideB,
-                                   Array<ck::index_t, NumDTensor> BatchStrideDs,
+    ComputePtrOffsetOfStridedBatch(Array<ck::index_t, NumATensor>& BatchStrideAs,
+                                   Array<ck::index_t, NumBTensor>& BatchStrideBs,
+                                   Array<ck::index_t, NumDTensor>& BatchStrideDs,
                                    index_t BatchStrideE)
-        : BatchStrideA_(BatchStrideA),
-          BatchStrideB_(BatchStrideB),
+        : BatchStrideA_(BatchStrideAs),
+          BatchStrideB_(BatchStrideBs),
           BatchStrideDs_(BatchStrideDs),
           BatchStrideE_(BatchStrideE)
     {
-        if constexpr(isMultiAB)
-        {
-            static_assert("Invalid constructor for multiple A or B");
-        }
     }
 
-    ComputePtrOffsetOfStridedBatch(Array<ck::index_t, NumATensor> BatchStrideAs,
-                                   Array<ck::index_t, NumBTensor> BatchStrideBs,
-                                   Array<ck::index_t, NumDTensor> BatchStrideDs,
-                                   index_t BatchStrideE)
-        : BatchStrideA_(BatchStrideAs),
-          BatchStrideB_(BatchStrideBs),
-          BatchStrideDs_(BatchStrideDs),
-          BatchStrideE_(BatchStrideE)
-    {
-        if constexpr(!isMultiAB)
-        {
-            static_assert("Invalid constructor for single A and B");
-        }
-    }
-
-    __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
+    __host__ __device__ constexpr auto GetAsPtrOffset(index_t g_idx) const
     {
-        if constexpr(!isMultiAB)
-        {
-            return g_idx * static_cast<long_index_t>(BatchStrideA_);
-        }
-        else
-        {
-            static_assert("Invalid function for multiple A or B");
-            return 0;
-        }
+        Array<long_index_t, NumATensor> as_offset;
+        static_for<0, NumATensor, 1>{}(
+            [&](auto i) { as_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideA_[i]); });
+        return as_offset;
     }
 
-    __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
+    __host__ __device__ constexpr auto GetBsPtrOffset(index_t g_idx) const
    {
-        if constexpr(!isMultiAB)
-        {
-            return g_idx * static_cast<long_index_t>(BatchStrideB_);
-        }
-        else
-        {
-            static_assert("Invalid function for multiple A or B");
-            return 0;
-        }
+        Array<long_index_t, NumBTensor> bs_offset;
+        static_for<0, NumBTensor, 1>{}(
+            [&](auto i) { bs_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideB_[i]); });
+        return bs_offset;
     }
 
-    __host__ __device__ constexpr auto GetAsPtrOffset(index_t g_idx) const
+    __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
     {
-        if constexpr(isMultiAB)
-        {
-            Array<long_index_t, NumATensor> as_offset;
-            static_for<0, NumATensor, 1>{}([&](auto i) {
-                as_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideA_[i]);
-            });
-            return as_offset;
-        }
-        else
-        {
-            static_assert("Invalid function for single A and B");
-            return BatchStrideA_;
-        }
+        Array<long_index_t, NumDTensor> ds_offset;
+        static_for<0, NumDTensor, 1>{}(
+            [&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
+        return ds_offset;
     }
 
-    __host__ __device__ constexpr auto GetBsPtrOffset(index_t g_idx) const
+    [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
     {
-        if constexpr(isMultiAB)
-        {
-            Array<long_index_t, NumBTensor> bs_offset;
-            static_for<0, NumBTensor, 1>{}([&](auto i) {
-                bs_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideB_[i]);
-            });
-            return bs_offset;
-        }
-        else
-        {
-            static_assert("Invalid function for single A and B");
-            return BatchStrideB_;
-        }
+        return g_idx * static_cast<long_index_t>(BatchStrideE_);
+    }
+
+    // alias for kernels without multiple D
+    [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
+    {
+        return g_idx * static_cast<long_index_t>(BatchStrideE_);
+    }
+
+    Array<ck::index_t, NumATensor> BatchStrideA_;
+    Array<ck::index_t, NumBTensor> BatchStrideB_;
+    Array<ck::index_t, NumDTensor> BatchStrideDs_;
+    index_t BatchStrideE_;
+    index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
+};
+
+template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
+struct ComputePtrOffsetOfStridedBatch<NumATensor,
+                                      NumBTensor,
+                                      NumDTensor,
+                                      ck::enable_if_t<(NumATensor == 1 || NumBTensor == 1)>>
+{
+    ComputePtrOffsetOfStridedBatch() = default;
+
+    ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
+                                   index_t BatchStrideB,
+                                   Array<ck::index_t, NumDTensor> BatchStrideDs,
+                                   index_t BatchStrideE)
+        : BatchStrideA_(BatchStrideA),
+          BatchStrideB_(BatchStrideB),
+          BatchStrideDs_(BatchStrideDs),
+          BatchStrideE_(BatchStrideE)
+    {
+    }
+
+    __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
+    {
+        return g_idx * static_cast<long_index_t>(BatchStrideA_);
+    }
+
+    __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
+    {
+        return g_idx * static_cast<long_index_t>(BatchStrideB_);
     }
 
     __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
@@ -125,14 +123,8 @@ struct ComputePtrOffsetOfStridedBatch
         return g_idx * static_cast<long_index_t>(BatchStrideE_);
     }
 
-    // If multiAB use Array
-    using BatchStrideAType =
-        std::conditional_t<isMultiAB, Array<ck::index_t, NumATensor>, ck::index_t>;
-    using BatchStrideBType =
-        std::conditional_t<isMultiAB, Array<ck::index_t, NumBTensor>, ck::index_t>;
-
-    BatchStrideAType BatchStrideA_;
-    BatchStrideBType BatchStrideB_;
+    ck::index_t BatchStrideA_;
+    ck::index_t BatchStrideB_;
     Array<ck::index_t, NumDTensor> BatchStrideDs_;
     index_t BatchStrideE_;
     index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
...
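Note on the refactor: the old single struct carried both the scalar and the Array stride members and relied on static_assert inside if constexpr branches to flag misuse; the new version splits it into two partial specializations selected through the trailing typename = void parameter with ck::enable_if_t, so each variant only defines the constructors, accessors, and members that are valid for it. A stripped-down sketch of that selection mechanism, using std::enable_if_t instead of the CK utilities and with deliberately non-overlapping conditions (all names here are illustrative, not from this commit):

#include <array>
#include <cstdio>
#include <type_traits>

// Primary template: empty; the trailing `typename = void` parameter only exists
// so that partial specializations can be selected via enable_if.
template <int NumA = 1, int NumB = 1, typename = void>
struct PtrOffsetSketch
{
};

// Chosen when there are multiple A or B tensors: one stride per tensor.
template <int NumA, int NumB>
struct PtrOffsetSketch<NumA, NumB, std::enable_if_t<(NumA > 1 || NumB > 1)>>
{
    std::array<long, NumA> batch_stride_a{};

    long GetAOffset(int g_idx, int i) const { return g_idx * batch_stride_a[i]; }
};

// Chosen for the plain single-A, single-B case: a scalar stride is enough.
template <int NumA, int NumB>
struct PtrOffsetSketch<NumA, NumB, std::enable_if_t<(NumA == 1 && NumB == 1)>>
{
    long batch_stride_a = 0;

    long GetAOffset(int g_idx) const { return g_idx * batch_stride_a; }
};

int main()
{
    PtrOffsetSketch<1, 1> single{}; // resolves to the scalar specialization
    PtrOffsetSketch<2, 1> multi{};  // resolves to the per-tensor specialization
    std::printf("%ld %ld\n", single.GetAOffset(3), multi.GetAOffset(3, 0));
    return 0;
}

In the sketch, PtrOffsetSketch<2, 1> picks the Array-based variant and PtrOffsetSketch<1, 1> the scalar one at compile time, so no branching is needed inside the member functions, which is the effect the commit achieves for ComputePtrOffsetOfStridedBatch.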
@@ -12,6 +12,7 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_base.hpp"
 
 #include "ck/library/utility/algorithm.hpp"
 #include "ck/library/utility/check_err.hpp"
...
@@ -27,11 +27,8 @@ using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
 static constexpr auto ConvFwdDefault =
     ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
 
-static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0;
-
+static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0;
 static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
 
 static constexpr auto ConvFwdOddC =
     ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;
...