Commit dd6a8de4 authored by Jehandad Khan's avatar Jehandad Khan
Browse files

Merge branch 'develop' into jd/dev_pkg

parents 0aa899aa abf4bdb9
...@@ -14,7 +14,7 @@ namespace ck { ...@@ -14,7 +14,7 @@ namespace ck {
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize, template <index_t BlockSize,
InMemoryDataOperationEnum_t DstInMemOp, InMemoryDataOperationEnum DstInMemOp,
typename BlockSliceLengths, typename BlockSliceLengths,
typename ThreadSliceLengths, typename ThreadSliceLengths,
typename ThreadClusterLengths, typename ThreadClusterLengths,
......
...@@ -15,7 +15,7 @@ namespace ck { ...@@ -15,7 +15,7 @@ namespace ck {
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize, template <index_t BlockSize,
typename ElementwiseOperation, typename ElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp, InMemoryDataOperationEnum DstInMemOp,
typename BlockSliceLengths, typename BlockSliceLengths,
typename ThreadClusterLengths, typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder, typename ThreadClusterArrangeOrder,
......
...@@ -15,7 +15,7 @@ namespace ck { ...@@ -15,7 +15,7 @@ namespace ck {
// 3. Run() does not construct new tensor coordinate // 3. Run() does not construct new tensor coordinate
template <index_t BlockSize, template <index_t BlockSize,
typename ElementwiseOperation, typename ElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp, InMemoryDataOperationEnum DstInMemOp,
typename BlockSliceLengths, typename BlockSliceLengths,
typename ThreadClusterLengths, typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder, typename ThreadClusterArrangeOrder,
......
...@@ -15,7 +15,7 @@ namespace ck { ...@@ -15,7 +15,7 @@ namespace ck {
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize, template <index_t BlockSize,
typename ElementwiseOperation, typename ElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp, InMemoryDataOperationEnum DstInMemOp,
typename BlockSliceLengths, typename BlockSliceLengths,
typename ThreadClusterLengths, typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder, typename ThreadClusterArrangeOrder,
......
...@@ -5,7 +5,7 @@ namespace ck { ...@@ -5,7 +5,7 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
enum ConvolutionBackwardDataSpecialization_t enum struct ConvolutionBackwardDataSpecialization
{ {
Default, Default,
Filter1x1Stride1Pad0, Filter1x1Stride1Pad0,
......
#ifndef CONVOLUTION_FORWARD_SPECIALIZATION #ifndef CONVOLUTION_FORWARD_SPECIALIZATION
#define CONVOLUTION_FORWARD_SPECIALIZATION #define CONVOLUTION_FORWARD_SPECIALIZATION
#include <string>
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
enum ConvolutionForwardSpecialization_t enum struct ConvolutionForwardSpecialization
{ {
Default, Default,
Filter1x1Pad0, Filter1x1Pad0,
...@@ -13,6 +15,18 @@ enum ConvolutionForwardSpecialization_t ...@@ -13,6 +15,18 @@ enum ConvolutionForwardSpecialization_t
OddC, OddC,
}; };
inline std::string getConvFwdSpecializationStr(const ConvolutionForwardSpecialization& s)
{
switch(s)
{
case ConvolutionForwardSpecialization::Default: return "Default";
case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0";
case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
case ConvolutionForwardSpecialization::OddC: return "OddC";
default: return "Unrecognized specialization!";
}
}
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
......
...@@ -11,7 +11,7 @@ namespace device { ...@@ -11,7 +11,7 @@ namespace device {
template <typename InElementwiseOperation, template <typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation> typename OutElementwiseOperation>
struct DeviceConvWrw : public BaseOperator struct DeviceConvBwdWeight : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in, MakeArgumentPointer(const void* p_in,
...@@ -38,8 +38,8 @@ struct DeviceConvWrw : public BaseOperator ...@@ -38,8 +38,8 @@ struct DeviceConvWrw : public BaseOperator
template <typename InElementwiseOperation, template <typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation> typename OutElementwiseOperation>
using DeviceConvWrwPtr = std::unique_ptr< using DeviceConvBwdWeightPtr = std::unique_ptr<
DeviceConvWrw<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>; DeviceConvBwdWeight<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment