"profiler/vscode:/vscode.git/clone" did not exist on "aaa44490900ea830cc5bee5c891e0810ae2df675"
Commit 02153e24 authored by Astha Rai's avatar Astha Rai
Browse files

resolved issues with standard headers in device files: device_base.hpp,...

resolved issues with standard headers in device files: device_base.hpp, device_grouped_conv_fwd_multiple_abd.hpp, device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
parent 23b959b2
...@@ -3,15 +3,17 @@ ...@@ -3,15 +3,17 @@
#pragma once #pragma once
#ifndef CK_CODE_GEN_RTC
#include <string> #include <string>
#include <sstream> #include <sstream>
#include "ck/stream_config.hpp" #include "ck/stream_config.hpp"
#endif
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
#ifndef CK_CODE_GEN_RTC
struct BaseArgument struct BaseArgument
{ {
BaseArgument() = default; BaseArgument() = default;
...@@ -36,13 +38,14 @@ struct BaseInvoker ...@@ -36,13 +38,14 @@ struct BaseInvoker
virtual ~BaseInvoker() {} virtual ~BaseInvoker() {}
}; };
#endif
struct BaseOperator struct BaseOperator
{ {
BaseOperator() = default; BaseOperator() = default;
BaseOperator(const BaseOperator&) = default; BaseOperator(const BaseOperator&) = default;
BaseOperator& operator=(const BaseOperator&) = default; BaseOperator& operator=(const BaseOperator&) = default;
#ifndef CK_CODE_GEN_RTC
virtual bool IsSupportedArgument(const BaseArgument*) { return false; } virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
virtual std::string GetTypeString() const { return ""; } virtual std::string GetTypeString() const { return ""; }
...@@ -66,7 +69,7 @@ struct BaseOperator ...@@ -66,7 +69,7 @@ struct BaseOperator
assert(p_arg); assert(p_arg);
p_arg->p_workspace_ = p_workspace; p_arg->p_workspace_ = p_workspace;
} }
#endif
virtual ~BaseOperator() {} virtual ~BaseOperator() {}
}; };
......
...@@ -3,7 +3,9 @@ ...@@ -3,7 +3,9 @@
#pragma once #pragma once
#ifndef CK_CODE_GEN_RTC
#include <array> #include <array>
#endif
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
...@@ -13,8 +15,13 @@ namespace ck { ...@@ -13,8 +15,13 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
#ifdef CK_CODE_GEN_RTC
template <typename T>
using is_tuple = decltype(ck::declval<T&>().IsTuple());
#else
template <typename T> template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple()); using is_tuple = decltype(std::declval<T&>().IsTuple());
#endif
/** /**
* \brief Grouped Convolution Forward * \brief Grouped Convolution Forward
...@@ -72,12 +79,18 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator ...@@ -72,12 +79,18 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor"); static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor");
#ifdef CK_CODE_GEN_RTC
using APointers = ck::conditional_t<isMultiA, ck::Array<const void*, NumATensor>&, const void*>;
using BPointers = ck::conditional_t<isMultiB, ck::Array<const void*, NumBTensor>&, const void*>;
#else
// If DataType is tuple, user has to pass std::array with pointers. // If DataType is tuple, user has to pass std::array with pointers.
using APointers = using APointers =
std::conditional_t<isMultiA, std::array<const void*, NumATensor>&, const void*>; ck::conditional_t<isMultiA, std::array<const void*, NumATensor>&, const void*>;
using BPointers = using BPointers =
std::conditional_t<isMultiB, std::array<const void*, NumBTensor>&, const void*>; ck::conditional_t<isMultiB, std::array<const void*, NumBTensor>&, const void*>;
#endif
#ifndef CK_CODE_GEN_RTC
/** /**
* \brief Make argument pointer for grouped conv fwd. * \brief Make argument pointer for grouped conv fwd.
...@@ -151,6 +164,7 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator ...@@ -151,6 +164,7 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
#endif
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -212,9 +212,13 @@ __global__ void ...@@ -212,9 +212,13 @@ __global__ void
} }
} // namespace } // namespace
#ifdef CK_CODE_GEN_RTC
template <typename T>
using is_tuple = decltype(ck::declval<T&>().IsTuple());
#else
template <typename T> template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple()); using is_tuple = decltype(std::declval<T&>().IsTuple());
#endif
// //
// @brief Device Convolution operation. // @brief Device Convolution operation.
......
...@@ -116,9 +116,13 @@ constexpr T&& forward(typename remove_reference<T>::type&& t_) noexcept ...@@ -116,9 +116,13 @@ constexpr T&& forward(typename remove_reference<T>::type&& t_) noexcept
{ {
return static_cast<T&&>(t_); return static_cast<T&&>(t_);
} }
template <typename T>
T&& declval() noexcept;
#else #else
#include <utility> #include <utility>
#include <type_traits> #include <type_traits>
using std::declval;
using std::forward; using std::forward;
using std::is_base_of; using std::is_base_of;
using std::is_class; using std::is_class;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment