gaoqiong / composable_kernel · Commits

Commit aa5859e4, authored Aug 13, 2022 by Chao Liu

Merge remote-tracking branch 'origin/develop' into wavelet_model

Parents: 9bd6cc0e, 5ee30459
Changes: 278 files. Showing the first 20 changed files below, with 498 additions and 331 deletions (+498 −331).
include/ck/utility/math.hpp (+8 −0)
include/ck/utility/reduction_functions_accumulate.hpp (+1 −1)
include/ck/utility/reduction_operator.hpp (+27 −0)
include/ck/utility/sequence.hpp (+4 −4)
include/ck/utility/sequence_helper.hpp (+2 −4)
include/ck/utility/static_buffer.hpp (+24 −5)
include/ck/utility/statically_indexed_array_multi_index.hpp (+55 −11)
include/ck/utility/synchronization.hpp (+5 −4)
include/ck/utility/tuple.hpp (+25 −11)
include/ck/utility/type.hpp (+2 −2)
library/CMakeLists.txt (+0 −1)
library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp (+7 −6)
library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp (+4 −3)
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp (+128 −107)
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp (+97 −85)
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp (+105 −83)
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp (+1 −1)
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp (+1 −1)
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp (+1 −1)
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp (+1 −1)
include/ck/utility/math.hpp

@@ -144,10 +144,18 @@ __host__ __device__ constexpr auto min(X x, Ys... ys)
     return min(x, min(ys...));
 }
 
+template <typename T>
+__host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T& upperbound)
+{
+    return min(max(x, lowerbound), upperbound);
+}
+
 // disallow implicit type casting
 template <typename T>
 __device__ T exp(T x);
 
 // TODO: add f16 support using v_exp_f16
 template <>
 __device__ float exp<float>(float x)
 {
...
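The new clamp composes the existing min/max overloads. A standalone host-side sketch of the semantics it implements (illustrative, not the ck implementation):

    // Equivalent behavior to min(max(x, lowerbound), upperbound).
    template <typename T>
    constexpr T clamp(const T& x, const T& lo, const T& hi)
    {
        return (x < lo) ? lo : ((x > hi) ? hi : x);
    }

    static_assert(clamp(-2, 0, 3) == 0, "below range clips to lower bound");
    static_assert(clamp(2, 0, 3) == 2, "inside range is unchanged");
    static_assert(clamp(5, 0, 3) == 3, "above range clips to upper bound");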
include/ck/utility/reduction_functions_accumulate.hpp

@@ -17,7 +17,7 @@ struct AccumulateWithNanIgnore
 {
     __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
     {
-        if(!isnan(currVal))
+        if(!ck::math::isnan(currVal))
         {
             ReduceOperation{}(accuVal, currVal);
         }
...
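The one-line change qualifies isnan so the call resolves to ck::math::isnan rather than whichever ::isnan or std::isnan happens to be visible. A plain host-side sketch of the ambiguity this avoids (the isnan body here is an illustrative stand-in, not the ck definition):

    #include <cmath>

    namespace ck { namespace math {
    // Illustrative stand-in; the real ck::math::isnan also covers types like half_t.
    template <typename T>
    bool isnan(T x) { return x != x; }
    }} // namespace ck::math

    template <typename AccDataType, typename ReduceOperation>
    struct AccumulateWithNanIgnore
    {
        static void Calculate(AccDataType& accuVal, AccDataType currVal)
        {
            // An unqualified isnan(currVal) could resolve to ::isnan/std::isnan from
            // <cmath>, which may not exist for custom types; the explicit
            // qualification pins the intended overload.
            if(!ck::math::isnan(currVal))
            {
                ReduceOperation{}(accuVal, currVal);
            }
        }
    };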
include/ck/utility/reduction_operator.hpp

@@ -58,6 +58,33 @@ struct Add
 }
 };
 
+struct SquaredAdd
+{
+    template <class T>
+    __host__ __device__ static constexpr T GetIdentityValue()
+    {
+        return type_convert<T>(0.0f);
+    };
+
+    __host__ __device__ static constexpr bool
+    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
+    {
+        return operation == InMemoryDataOperationEnum::AtomicAdd ||
+               operation == InMemoryDataOperationEnum::Set;
+    };
+
+    template <class T>
+    __host__ __device__ inline constexpr void operator()(T& a, T b) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "The data type is not supported by the Max accumulator!");
+
+        a = a + b * b;
+    }
+};
+
 struct Mul
 {
     template <typename T>
...
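SquaredAdd accumulates a += b * b, the inner step of a sum-of-squares reduction (e.g. an L2 norm, or the E[x²] term of a variance). A host-side sketch of the same accumulation pattern, with plain float in place of the ck type machinery:

    #include <cmath>
    #include <vector>

    // Illustrative host-side analogue of the SquaredAdd reduction operator.
    struct SquaredAddSketch
    {
        static constexpr float GetIdentityValue() { return 0.0f; }
        void operator()(float& a, float b) const { a = a + b * b; }
    };

    float l2_norm(const std::vector<float>& v)
    {
        float acc = SquaredAddSketch::GetIdentityValue();
        for(float x : v)
            SquaredAddSketch{}(acc, x); // acc += x * x
        return std::sqrt(acc);
    }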
include/ck/utility/sequence.hpp

@@ -3,10 +3,10 @@
 #pragma once
 
-#include "integral_constant.hpp"
-#include "type.hpp"
-#include "functional.hpp"
-#include "math.hpp"
+#include "ck/utility/integral_constant.hpp"
+#include "ck/utility/type.hpp"
+#include "ck/utility/functional.hpp"
+#include "ck/utility/math.hpp"
 
 namespace ck {
...
include/ck/utility/sequence_helper.hpp

 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
 
-#ifndef CK_SEQUENCE_HELPER_HPP
-#define CK_SEQUENCE_HELPER_HPP
 #pragma once
 
-#include "tuple.hpp"
+#include "ck/utility/tuple.hpp"
 
 namespace ck {
...
@@ -36,4 +35,3 @@ __host__ __device__ constexpr auto to_sequence(Tuple<Number<Is>...>)
 }
 } // namespace ck
-#endif
include/ck/utility/static_buffer.hpp

 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
 
-#ifndef CK_STATIC_BUFFER_HPP
-#define CK_STATIC_BUFFER_HPP
 #pragma once
 
 #include "statically_indexed_array.hpp"
...
@@ -20,6 +19,22 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N>
     __host__ __device__ constexpr StaticBuffer() : base{} {}
 
+    template <typename... Ys>
+    __host__ __device__ constexpr StaticBuffer& operator=(const Tuple<Ys...>& y)
+    {
+        static_assert(base::Size() == sizeof...(Ys), "wrong! size not the same");
+
+        StaticBuffer& x = *this;
+
+        static_for<0, base::Size(), 1>{}([&](auto i) { x(i) = y[i]; });
+
+        return x;
+    }
+
+    __host__ __device__ constexpr StaticBuffer& operator=(const T& y)
+    {
+        StaticBuffer& x = *this;
+
+        static_for<0, base::Size(), 1>{}([&](auto i) { x(i) = y; });
+
+        return x;
+    }
+
     __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; }
 
     __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
...
@@ -40,10 +55,12 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N>
         return base::operator()(i);
     }
 
-    __host__ __device__ void Clear()
+    __host__ __device__ void Set(T x)
     {
-        static_for<0, N, 1>{}([&](auto i) { operator()(i) = T{0}; });
+        static_for<0, N, 1>{}([&](auto i) { operator()(i) = T{x}; });
     }
+
+    __host__ __device__ void Clear() { Set(T{0}); }
 };
 
 // static buffer for vector
...
@@ -61,6 +78,7 @@ struct StaticBufferTupleOfVector
     static constexpr auto s_per_v   = Number<ScalarPerVector>{};
     static constexpr auto num_of_v_ = Number<NumOfVector>{};
+    static constexpr auto s_per_buf = s_per_v * num_of_v_;
 
     __host__ __device__ constexpr StaticBufferTupleOfVector() : base{} {}
...
@@ -70,6 +88,8 @@ struct StaticBufferTupleOfVector
     __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
 
+    __host__ __device__ static constexpr index_t Size() { return s_per_buf; };
+
     // Get S
     // i is offset of S
     template <index_t I>
...
@@ -173,4 +193,3 @@ __host__ __device__ constexpr auto make_static_buffer(LongNumber<N>)
 }
 } // namespace ck
-#endif
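Set(x) fills every element with T{x} via a compile-time unrolled static_for, and Clear() is now simply Set(T{0}). A simplified host-side sketch of that fill pattern, with std::array standing in for the ck static buffer:

    #include <array>
    #include <cstddef>

    // Simplified stand-in for StaticBuffer's Set/Clear; the real code unrolls
    // the loop at compile time with static_for<0, N, 1>.
    template <typename T, std::size_t N>
    struct StaticBufferSketch
    {
        std::array<T, N> data{};

        void Set(T x)
        {
            for(std::size_t i = 0; i < N; ++i)
                data[i] = T{x};
        }

        void Clear() { Set(T{0}); }
    };

    // Usage: StaticBufferSketch<float, 8> b; b.Set(1.5f); b.Clear();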
include/ck/utility/statically_indexed_array_multi_index.hpp

@@ -34,7 +34,10 @@ __host__ __device__ constexpr auto to_multi_index(const T& x)
 // is the alias of the latter. This is because compiler cannot infer the NSize if
 // using MultiIndex<NSize>
 // TODO: how to fix this?
-template <typename... Ys, typename X>
+template <typename... Ys,
+          typename X,
+          enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> = false>
 __host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x)
 {
     static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
...
@@ -43,7 +46,10 @@ __host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x)
     return y;
 }
 
-template <typename... Ys, typename X>
+template <typename... Ys,
+          typename X,
+          enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> = false>
 __host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x)
 {
     static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
...
@@ -52,7 +58,10 @@ __host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x)
     return y;
 }
 
-template <typename... Xs, typename Y>
+template <typename... Xs,
+          typename Y,
+          enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
 __host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y)
 {
     static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
...
@@ -63,7 +72,10 @@ __host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y)
     return r;
 }
 
-template <typename... Xs, typename Y>
+template <typename... Xs,
+          typename Y,
+          enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
 __host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y)
 {
     static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
...
@@ -74,7 +86,10 @@ __host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y)
     return r;
 }
 
-template <typename... Xs, typename Y>
+template <typename... Xs,
+          typename Y,
+          enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
 __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
 {
     static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
...
@@ -85,9 +100,11 @@ __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
     return r;
 }
 
-// MultiIndex = index_t * MultiIndex
-template <typename... Xs>
-__host__ __device__ constexpr auto operator*(index_t a, const Tuple<Xs...>& x)
+// MultiIndex = scalar * MultiIndex
+template <typename... Xs,
+          typename Y,
+          enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
+__host__ __device__ constexpr auto operator*(Y a, const Tuple<Xs...>& x)
 {
     constexpr index_t NSize = sizeof...(Xs);
...
@@ -96,13 +113,40 @@ __host__ __device__ constexpr auto operator*(index_t a, const Tuple<Xs...>& x)
     return r;
 }
 
-// MultiIndex = MultiIndex * index_t
-template <typename... Xs>
-__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, index_t a)
+// MultiIndex = MultiIndex * scalar
+template <typename... Xs,
+          typename Y,
+          enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
+__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, Y a)
 {
     return a * x;
 }
 
 namespace mathext {
 
+template <typename... Xs>
+__host__ __device__ constexpr auto exp(const Tuple<Xs...>& x)
+{
+    constexpr index_t NSize = sizeof...(Xs);
+
+    Tuple<Xs...> r;
+    static_for<0, NSize, 1>{}([&](auto i) { r(i) = math::exp(x[i]); });
+
+    return r;
+}
+
+template <typename... Xs, typename Y>
+__host__ __device__ constexpr auto max(const Tuple<Xs...>& x, const Y& y)
+{
+    static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
+
+    constexpr index_t NSize = sizeof...(Xs);
+
+    Tuple<Xs...> r;
+    static_for<0, NSize, 1>{}([&](auto i) { r(i) = math::max(x[i], y[i]); });
+
+    return r;
+}
+
 } // namespace mathext
 
 template <typename... Xs>
 __host__ __device__ void print_multi_index(const Tuple<Xs...>& x)
 {
...
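The enable_if constraints keep the tuple-with-tuple overloads from competing with the scalar ones, and operator* now accepts any integral or floating-point scalar instead of only index_t. A minimal host-side sketch of the dispatch rule the constraints encode (mul stands in for the constrained operators; return strings just label which overload fired):

    #include <type_traits>

    // Enabled only when Y is NOT a plain scalar: elementwise combine,
    // like operator*(const Tuple<Xs...>&, const Y&) above.
    template <typename Tup, typename Y,
              typename std::enable_if<!(std::is_integral<Y>::value ||
                                        std::is_floating_point<Y>::value),
                                      bool>::type = false>
    const char* mul(const Tup&, const Y&) { return "elementwise x[i] * y[i]"; }

    // Enabled only when Y IS a scalar: broadcast, like operator*(Y, const Tuple<Xs...>&).
    template <typename Tup, typename Y,
              typename std::enable_if<std::is_integral<Y>::value ||
                                          std::is_floating_point<Y>::value,
                                      bool>::type = false>
    const char* mul(const Tup&, Y) { return "broadcast a * x[i]"; }

    struct FakeTuple {};
    // mul(FakeTuple{}, FakeTuple{}) selects the elementwise overload;
    // mul(FakeTuple{}, 2.0f) selects the broadcast one. Without the constraints,
    // the pair of unconstrained templates would collide for scalar arguments.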
include/ck/utility/synchronization.hpp

@@ -18,14 +18,15 @@ __device__ void block_sync_lds()
     __syncthreads();
 #endif
 }
 
-__device__ void block_lds()
+__device__ void s_nop()
 {
-#if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
+#if 1
     asm volatile("\
-    s_waitcnt lgkmcnt(0) \n \
+    s_nop 0 \n \
     " ::);
 #else
-    __syncthreads();
+    __builtin_amdgcn_sched_barrier(0);
 #endif
 }
...
include/ck/utility/tuple.hpp

@@ -3,10 +3,10 @@
 #pragma once
 
-#include "integral_constant.hpp"
-#include "sequence.hpp"
-#include "type.hpp"
-#include "enable_if.hpp"
+#include "ck/utility/integral_constant.hpp"
+#include "ck/utility/sequence.hpp"
+#include "ck/utility/type.hpp"
+#include "ck/utility/enable_if.hpp"
 
 namespace ck {
...
@@ -21,6 +21,8 @@ struct TupleElementKey
 template <typename Key, typename Data>
 struct TupleElementKeyData
 {
+    using DataType = Data;
+
 #if 0 // workaround compiler complaint about implicitly-deleted default constructor
     __host__ __device__ constexpr TupleElementKeyData() = default;
 #else
...
@@ -34,29 +36,40 @@ struct TupleElementKeyData
     {
     }
 
-    Data mData;
+    DataType mData;
 };
 
 // for read access of tuple element
 template <typename Key, typename Data>
 __host__ __device__ constexpr const Data&
-get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
+get_tuple_element_data_reference(const TupleElementKeyData<Key, Data>& x)
 {
     return static_cast<const Data&>(x.mData);
 }
 
 // for write access of tuple element
 template <typename Key, typename Data>
-__host__ __device__ constexpr Data& get_tuple_element_data(TupleElementKeyData<Key, Data>& x)
+__host__ __device__ constexpr Data& get_tuple_element_data_reference(TupleElementKeyData<Key, Data>& x)
 {
     return x.mData;
 }
 
 // TODO: not sure the use of reference is correct
 template <typename Key, typename Data>
-__host__ __device__ constexpr Data&& get_tuple_element_data(TupleElementKeyData<Key, Data>&& x)
+__host__ __device__ constexpr Data&& get_tuple_element_data_reference(TupleElementKeyData<Key, Data>&& x)
 {
     return static_cast<Data&&>(x.mData);
 }
 
+// for inferring type of tuple element
+template <typename Key, typename Data>
+__host__ __device__ constexpr Data get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
+{
+    return std::forward(x.mData);
+}
+
 template <typename Indices, typename... Xs>
 struct TupleImpl;
...
@@ -87,13 +100,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<I
     template <index_t I>
     __host__ __device__ constexpr const auto& GetElementDataByKey(TupleElementKey<I>) const
     {
-        return get_tuple_element_data<TupleElementKey<I>>(*this);
+        return get_tuple_element_data_reference<TupleElementKey<I>>(*this);
     }
 
     template <index_t I>
     __host__ __device__ constexpr auto& GetElementDataByKey(TupleElementKey<I>)
     {
-        return get_tuple_element_data<TupleElementKey<I>>(*this);
+        return get_tuple_element_data_reference<TupleElementKey<I>>(*this);
     }
 };
...
@@ -185,7 +198,8 @@ struct Tuple<>
 template <index_t I, typename TTuple>
 struct tuple_element
 {
-    using type = decltype(TTuple{}.At(Number<I>{}));
+    // type should keep the cv/ref qualifier of original tuple element
+    using type = decltype(detail::get_tuple_element_data<detail::TupleElementKey<I>>(TTuple{}));
 };
 
 template <index_t I, typename TTuple>
...
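The reworked tuple_element derives the element type from the declared return type of get_tuple_element_data in an unevaluated context, so a Data that is itself a reference keeps its qualifier. A small illustration of that deduction trick (the names mirror the ck detail types, but this is not the ck code):

    #include <type_traits>
    #include <utility>

    template <typename Key, typename Data>
    struct TupleElementKeyData { Data mData; };

    // Declaration only; used inside decltype, never called at runtime.
    template <typename Key, typename Data>
    Data get_tuple_element_data(const TupleElementKeyData<Key, Data>&);

    struct K0;

    // Data = int& is deduced and returned as-is, so the reference survives,
    // which the old decltype(TTuple{}.At(Number<I>{})) form did not guarantee.
    static_assert(std::is_same<decltype(get_tuple_element_data(
                                   std::declval<TupleElementKeyData<K0, int&>>())),
                               int&>::value,
                  "reference qualifier of the element type is preserved");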
include/ck/utility/type.hpp

@@ -4,8 +4,8 @@
 #pragma once
 
 #include "ck/ck.hpp"
-#include "integral_constant.hpp"
-#include "enable_if.hpp"
+#include "ck/utility/integral_constant.hpp"
+#include "ck/utility/enable_if.hpp"
 
 namespace ck {
...
library/CMakeLists.txt

 add_subdirectory(src/tensor_operation_instance/gpu)
-add_subdirectory(src/host_tensor)
 add_subdirectory(src/utility)
library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp

@@ -7,7 +7,7 @@
 #include <sstream>
 
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/utility/host_tensor.hpp"
 
 namespace ck {
 namespace tensor_operation {
...
@@ -16,6 +16,7 @@ namespace host {
 template <typename ADataType,
           typename BDataType,
           typename CDataType,
+          typename AccDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation>
...
@@ -58,7 +59,7 @@ struct ReferenceBatchedGemm : public device::BaseOperator
         auto f_gmk_gkn_gmn = [&](auto g, auto m, auto n) {
             const int K = arg.a_g_m_k_.mDesc.GetLengths()[2];
 
-            float v_acc = 0;
+            AccDataType v_acc = 0;
 
             for(int k = 0; k < K; ++k)
             {
...
@@ -68,10 +69,11 @@ struct ReferenceBatchedGemm : public device::BaseOperator
                 arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k));
                 arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n));
 
-                v_acc += ck::type_convert<float>(v_a) * ck::type_convert<float>(v_b);
+                v_acc += ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
             }
 
-            float v_c;
+            AccDataType v_c;
 
             arg.c_element_op_(v_c, v_acc);
...
@@ -81,8 +83,7 @@ struct ReferenceBatchedGemm : public device::BaseOperator
             make_ParallelTensorFunctor(f_gmk_gkn_gmn,
                                        arg.c_g_m_n_.mDesc.GetLengths()[0],
                                        arg.c_g_m_n_.mDesc.GetLengths()[1],
-                                       arg.c_g_m_n_.mDesc.GetLengths()[2])(
-                std::thread::hardware_concurrency());
+                                       arg.c_g_m_n_.mDesc.GetLengths()[2])();
 
             return 0;
         }
...
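With the new AccDataType parameter the accumulator precision becomes a caller choice rather than hard-coded float. A hypothetical instantiation for a half-precision batched GEMM that still accumulates in float (PassThrough is assumed to be the usual ck elementwise no-op functor):

    #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
    #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    // Template order after this change:
    // ADataType, BDataType, CDataType, AccDataType, A/B/C elementwise operations.
    using RefBatchedGemmF16 =
        ck::tensor_operation::host::ReferenceBatchedGemm<ck::half_t,  // ADataType
                                                         ck::half_t,  // BDataType
                                                         ck::half_t,  // CDataType
                                                         float,       // AccDataType (new)
                                                         PassThrough,
                                                         PassThrough,
                                                         PassThrough>;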
library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp

@@ -6,8 +6,9 @@
 #include <iostream>
 #include <sstream>
 
+#include "ck/library/utility/host_tensor.hpp"
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 
 namespace ck {
 namespace tensor_operation {
...
@@ -91,7 +92,7 @@ struct ReferenceCGemm : public device::BaseOperator
                 v_c_real += v_a_real * v_b_real - v_a_imag * v_b_imag;
             }
 
-            arg.c_m_n_real_(m, n) = v_c_real;
+            arg.c_m_n_real_(m, n) = ck::type_convert<CDataType>(v_c_real);
         };
 
         auto f_mk_kn_mn_imag = [&](auto m, auto n) {
...
@@ -107,7 +108,7 @@ struct ReferenceCGemm : public device::BaseOperator
                 v_c_imag += v_a_real * v_b_imag + v_a_imag * v_b_real;
             }
 
-            arg.c_m_n_imag_(m, n) = v_c_imag;
+            arg.c_m_n_imag_(m, n) = ck::type_convert<CDataType>(v_c_imag);
         };
 
         make_ParallelTensorFunctor(f_mk_kn_mn_real,
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp

@@ -8,22 +8,24 @@
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/utility/host_tensor.hpp"
 
 namespace ck {
 namespace tensor_operation {
 namespace host {
 
-// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X]
-template <typename InDataType,
+// input descriptor in [G, N, C, Do, Ho, Wo] order
+// weight descriptor in [G, K, C, Z, Y, X] order
+// output descriptor in [G, N, K, Di, Hi, Wi] order
+// physical layout is irrelevant
+template <ck::index_t NDimSpatial,
+          typename InDataType,
           typename WeiDataType,
           typename OutDataType,
           typename AccDataType,
           typename InElementwiseOperation,
           typename WeiElementwiseOperation,
           typename OutElementwiseOperation,
-          ck::index_t NumDimSpatial = 2,
-          typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
+          typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
 struct ReferenceConvBwdData : public device::BaseOperator
 {
     // Argument
...
@@ -73,36 +75,45 @@ struct ReferenceConvBwdData : public device::BaseOperator
     float Run(const Argument& arg)
     {
-        if constexpr(NumDimSpatial == 1)
+        if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
+             arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
+             arg.output_.GetNumOfDimension() == NDimSpatial + 3))
+        {
+            throw std::runtime_error("wrong! inconsistent dimension");
+        }
+
+        if constexpr(NDimSpatial == 1)
         {
-            auto f_ncw = [&](auto n, auto c, auto wi) {
-                std::size_t K  = arg.weight_.mDesc.GetLengths()[0];
-                std::size_t X  = arg.weight_.mDesc.GetLengths()[2];
-                std::size_t Wo = arg.output_.mDesc.GetLengths()[2];
+            auto f_ncw = [&](auto g, auto n, auto c, auto wi) {
+                std::size_t K  = arg.weight_.GetLengths()[1];
+                std::size_t X  = arg.weight_.GetLengths()[3];
+                std::size_t Wo = arg.output_.GetLengths()[3];
 
-                AccDataType v_acc = 0;
+                float v_acc = 0;
 
                 for(std::size_t x = 0; x < X; ++x)
                 {
-                    auto w_tmp = ck::type_convert<ck::long_index_t>(wi) +
-                                 ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
-                                 ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]);
+                    auto w_tmp = static_cast<ck::long_index_t>(wi) +
+                                 static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
+                                 static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]);
 
                     if(w_tmp % arg.conv_strides_[0] == 0)
                     {
-                        auto wo = ck::type_convert<ck::long_index_t>(w_tmp) /
-                                  ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
+                        auto wo = static_cast<ck::long_index_t>(w_tmp) /
+                                  static_cast<ck::long_index_t>(arg.conv_strides_[0]);
 
                         if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
                         {
                             for(std::size_t k = 0; k < K; ++k)
                             {
-                                AccDataType v_out = 0;
-                                AccDataType v_wei = 0;
+                                float v_out = 0;
+                                float v_wei = 0;
 
                                 arg.out_element_op_(
-                                    v_out, ck::type_convert<AccDataType>(arg.output_(n, k, wo)));
+                                    v_out, ck::type_convert<float>(arg.output_(g, n, k, wo)));
                                 arg.wei_element_op_(
-                                    v_wei, ck::type_convert<AccDataType>(arg.weight_(k, c, x)));
+                                    v_wei, ck::type_convert<float>(arg.weight_(g, k, c, x)));
 
                                 v_acc += v_out * v_wei;
                             }
...
@@ -110,66 +121,72 @@ struct ReferenceConvBwdData : public device::BaseOperator
                     }
                 }
 
-                arg.in_element_op_(v_acc, v_acc);
-                arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_acc);
+                float v_in;
+                arg.in_element_op_(v_in, v_acc);
+                arg.input_(g, n, c, wi) = ck::type_convert<InDataType>(v_acc);
             };
 
             make_ParallelTensorFunctor(f_ncw,
-                                       arg.input_.mDesc.GetLengths()[0],
-                                       arg.input_.mDesc.GetLengths()[1],
-                                       arg.input_.mDesc.GetLengths()[2])(
+                                       arg.input_.GetLengths()[0],
+                                       arg.input_.GetLengths()[1],
+                                       arg.input_.GetLengths()[2],
+                                       arg.input_.GetLengths()[3])(
                 std::thread::hardware_concurrency());
 
             return 0;
         }
-        else if constexpr(NumDimSpatial == 2)
+        else if constexpr(NDimSpatial == 2)
         {
-            auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
-                std::size_t K = arg.weight_.mDesc.GetLengths()[0];
-                std::size_t Y = arg.weight_.mDesc.GetLengths()[2];
-                std::size_t X = arg.weight_.mDesc.GetLengths()[3];
+            auto f_nchw = [&](auto g, auto n, auto c, auto hi, auto wi) {
+                std::size_t K = arg.weight_.GetLengths()[1];
+                std::size_t Y = arg.weight_.GetLengths()[3];
+                std::size_t X = arg.weight_.GetLengths()[4];
 
-                std::size_t Ho = arg.output_.mDesc.GetLengths()[2];
-                std::size_t Wo = arg.output_.mDesc.GetLengths()[3];
+                std::size_t Ho = arg.output_.GetLengths()[3];
+                std::size_t Wo = arg.output_.GetLengths()[4];
 
-                AccDataType v_acc = 0;
+                float v_acc = 0;
...
(the y/x loops take the same edits as the 1D case: ck::type_convert<ck::long_index_t> becomes static_cast<ck::long_index_t> in the h_tmp/w_tmp arithmetic, AccDataType v_out/v_wei become float, and output_/weight_ gain the leading g index)
...
@@ -180,90 +197,91 @@ struct ReferenceConvBwdData : public device::BaseOperator
 
-                AccDataType v_in;
+                float v_in;
                 arg.in_element_op_(v_in, v_acc);
-                arg.input_(n, c, hi, wi) = ck::type_convert<InDataType>(v_in);
+                arg.input_(g, n, c, hi, wi) = ck::type_convert<InDataType>(v_acc);
             };
 
             make_ParallelTensorFunctor(f_nchw,
-                                       arg.input_.mDesc.GetLengths()[0],
-                                       arg.input_.mDesc.GetLengths()[1],
-                                       arg.input_.mDesc.GetLengths()[2],
-                                       arg.input_.mDesc.GetLengths()[3])(
+                                       arg.input_.GetLengths()[0],
+                                       arg.input_.GetLengths()[1],
+                                       arg.input_.GetLengths()[2],
+                                       arg.input_.GetLengths()[3],
+                                       arg.input_.GetLengths()[4])(
                 std::thread::hardware_concurrency());
 
             return 0;
         }
-        else if constexpr(NumDimSpatial == 3)
+        else if constexpr(NDimSpatial == 3)
         {
-            auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) {
-                std::size_t K = arg.weight_.mDesc.GetLengths()[0];
-                std::size_t Z = arg.weight_.mDesc.GetLengths()[2];
-                std::size_t Y = arg.weight_.mDesc.GetLengths()[3];
-                std::size_t X = arg.weight_.mDesc.GetLengths()[4];
+            auto f_ncdhw = [&](auto g, auto n, auto c, auto di, auto hi, auto wi) {
+                std::size_t K = arg.weight_.GetLengths()[1];
+                std::size_t Z = arg.weight_.GetLengths()[3];
+                std::size_t Y = arg.weight_.GetLengths()[4];
+                std::size_t X = arg.weight_.GetLengths()[5];
 
-                std::size_t Do = arg.output_.mDesc.GetLengths()[2];
-                std::size_t Ho = arg.output_.mDesc.GetLengths()[3];
-                std::size_t Wo = arg.output_.mDesc.GetLengths()[4];
+                std::size_t Do = arg.output_.GetLengths()[3];
+                std::size_t Ho = arg.output_.GetLengths()[4];
+                std::size_t Wo = arg.output_.GetLengths()[5];
 
-                AccDataType v_acc = 0;
+                float v_acc = 0;
...
(the z/y/x loops take the same static_cast, float, and g-index edits)
...
@@ -277,17 +295,20 @@ struct ReferenceConvBwdData : public device::BaseOperator
 
-                AccDataType v_in;
+                float v_in;
                 arg.in_element_op_(v_in, v_acc);
-                arg.input_(n, c, di, hi, wi) = ck::type_convert<InDataType>(v_in);
+                arg.input_(g, n, c, di, hi, wi) = ck::type_convert<InDataType>(v_acc);
             };
 
             make_ParallelTensorFunctor(f_ncdhw,
-                                       arg.input_.mDesc.GetLengths()[0],
-                                       ...
-                                       arg.input_.mDesc.GetLengths()[4])(
+                                       arg.input_.GetLengths()[0],
+                                       ...
+                                       arg.input_.GetLengths()[5])(
                 std::thread::hardware_concurrency());
 
             return 0;
...
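The core of the backward-data loop inverts the forward index map wi = wo*stride + x*dilation - pad: for each input position wi and filter tap x, the contributing output position is wo = (wi + pad - x*dilation) / stride, valid only when the division is exact and 0 <= wo < Wo, which is what the w_tmp % stride == 0 test and the bounds checks above verify. A small standalone sketch of that mapping:

    #include <cstdint>

    // For input position wi and filter tap x, return the contributing output
    // position wo, or -1 if this tap never writes to wi (mirrors the checks
    // in the reference loop above; names are illustrative).
    std::int64_t contributing_wo(std::int64_t wi, std::int64_t x,
                                 std::int64_t stride, std::int64_t dilation,
                                 std::int64_t left_pad, std::int64_t Wo)
    {
        const std::int64_t w_tmp = wi + left_pad - x * dilation;
        if(w_tmp % stride != 0)
            return -1; // stride skips this input position for this tap
        const std::int64_t wo = w_tmp / stride;
        return (wo >= 0 && wo < Wo) ? wo : -1;
    }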
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp → library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp

@@ -7,21 +7,25 @@
 #include <sstream>
 
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/utility/host_tensor.hpp"
 
 namespace ck {
 namespace tensor_operation {
 namespace host {
 
-// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X]
-template <typename InDataType,
+// input descriptor in [G, N, C, Do, Ho, Wo] order
+// weight descriptor in [G, K, C, Z, Y, X] order
+// output descriptor in [G, N, K, Di, Hi, Wi] order
+// physical layout is irrelevant
+template <ck::index_t NDimSpatial,
+          typename InDataType,
           typename WeiDataType,
           typename OutDataType,
           typename InElementwiseOperation,
           typename WeiElementwiseOperation,
           typename OutElementwiseOperation,
-          ck::index_t NumDimSpatial = 2,
-          typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
+          typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
 struct ReferenceConvBwdWeight : public device::BaseOperator
 {
     // Argument
...
@@ -71,156 +75,162 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
     float Run(const Argument& arg)
     {
-        if constexpr(NumDimSpatial == 1)
+        if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
+             arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
+             arg.output_.GetNumOfDimension() == NDimSpatial + 3))
+        {
+            throw std::runtime_error("wrong! inconsistent dimension");
+        }
+
+        if constexpr(NDimSpatial == 1)
         {
-            constexpr auto I0 = Number<0>{};
-
-            auto f_kcx = [&](auto k, auto c, auto x) {
+            auto f_kcx = [&](auto g, auto k, auto c, auto x) {
                 float v_acc = 0;
 
-                for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
+                for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
                 {
-                    for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[2]; ++wo)
+                    for(std::size_t wo = 0; wo < arg.output_.GetLengths()[3]; ++wo)
                     {
-                        auto wi = ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I0]) +
-                                  ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[I0]) -
-                                  ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
+                        auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
+                                  static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
+                                  static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
 
                         if(wi >= 0 &&
-                           ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2])
+                           ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
                         {
                             float v_out;
                             float v_in;
 
                             arg.out_element_op_(
-                                v_out, ck::type_convert<float>(arg.output_(n, k, wo)));
+                                v_out, ck::type_convert<float>(arg.output_(g, n, k, wo)));
                             arg.in_element_op_(
-                                v_in, ck::type_convert<float>(arg.input_(n, c, wi)));
+                                v_in, ck::type_convert<float>(arg.input_(g, n, c, wi)));
 
                             v_acc += v_out * v_in;
                         }
                     }
                 }
 
                 float v_wei;
 
                 arg.wei_element_op_(v_wei, v_acc);
 
-                arg.weight_(k, c, x) = ck::type_convert<WeiDataType>(v_wei);
+                arg.weight_(g, k, c, x) = ck::type_convert<WeiDataType>(v_wei);
             };
 
             make_ParallelTensorFunctor(f_kcx,
-                                       arg.weight_.mDesc.GetLengths()[0],
-                                       arg.weight_.mDesc.GetLengths()[1],
-                                       arg.weight_.mDesc.GetLengths()[2])(
+                                       arg.weight_.GetLengths()[0],
+                                       arg.weight_.GetLengths()[1],
+                                       arg.weight_.GetLengths()[2],
+                                       arg.weight_.GetLengths()[3])(
                 std::thread::hardware_concurrency());
 
             return 0;
         }
-        else if constexpr(NumDimSpatial == 2)
+        else if constexpr(NDimSpatial == 2)
        {
-            constexpr auto I0 = Number<0>{};
-            constexpr auto I1 = Number<1>{};
-
-            auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
+            auto f_kcyx = [&](auto g, auto k, auto c, auto y, auto x) {
...
(the 2D case, and the 3D case with f_kczyx and I0/I1/I2, take the same pattern of edits: the I0/I1/I2 Number constants are dropped in favor of plain 0/1/2 subscripts, ck::type_convert<ck::long_index_t> becomes static_cast<ck::long_index_t> in the hi/wi/di arithmetic, loop bounds and range checks move from mDesc.GetLengths()[i] to GetLengths()[i+1], and all output_/input_ accesses gain the leading g index)
...
@@ -228,19 +238,21 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
                 float v_wei;
 
                 arg.wei_element_op_(v_wei, v_acc);
 
-                arg.weight_(k, c, z, y, x) = ck::type_convert<WeiDataType>(v_wei);
+                arg.weight_(g, k, c, z, y, x) = ck::type_convert<WeiDataType>(v_wei);
             };
 
             make_ParallelTensorFunctor(f_kczyx,
-                                       arg.weight_.mDesc.GetLengths()[0],
-                                       ...
-                                       arg.weight_.mDesc.GetLengths()[4])(
+                                       arg.weight_.GetLengths()[0],
+                                       ...
+                                       arg.weight_.GetLengths()[5])(
                 std::thread::hardware_concurrency());
 
             return 0;
...
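The backward-weight reference walks output positions and applies the forward map directly, wi = wo*stride + x*dilation - pad, accumulating out*in into the weight gradient whenever wi lands in range. A 1D host-side sketch of that inner loop under those definitions (names are illustrative):

    #include <cstdint>
    #include <vector>

    // Gradient of one weight tap w[x] for a single (k, c) pair in 1D.
    // out has length Wo, in has length Wi.
    float weight_grad_1d(const std::vector<float>& out, const std::vector<float>& in,
                         std::int64_t x, std::int64_t stride, std::int64_t dilation,
                         std::int64_t left_pad)
    {
        float v_acc = 0;
        const std::int64_t Wi = static_cast<std::int64_t>(in.size());
        for(std::int64_t wo = 0; wo < static_cast<std::int64_t>(out.size()); ++wo)
        {
            // Forward index map; out-of-range wi corresponds to padding.
            const std::int64_t wi = wo * stride + x * dilation - left_pad;
            if(wi >= 0 && wi < Wi)
                v_acc += out[wo] * in[wi];
        }
        return v_acc;
    }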
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp

@@ -8,7 +8,7 @@
 #include <sstream>
 
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/utility/host_tensor.hpp"
 
 namespace ck {
 namespace tensor_operation {
...
@@ -17,9 +17,10 @@ namespace host {
 //
 // @brief Reference implementation for forward convolution.
 //
-// @paragraph Supports both NCHW as well as NHWC formats (and their respective
-//            counterparts for weight and output) as long as tensor descriptor
-//            lengths is in NCHW.
+// @paragraph
+// Tensor descriptor in GNCHW/GKCXY/GNKHW dimensional order
+// Supports both GNCHW/NGCHW as well as GNHWC/NHWGC physical layouts
+// as long as dimensions in the tensor descriptor are in GNCHW order
 //
 // @tparam InDataType Input tensor data type.
 // @tparam WeiDataType Weights tensor data type.
...
@@ -28,16 +29,20 @@ namespace host {
 //                                   operation.
 // @tparam WeiElementwiseOperation   Functor for weights tensor elementwise
 //                                   operation.
-// @tparam NumDimSpatial             Number of spatial dimensions.
+// @tparam NDimSpatial               Number of spatial dimensions.
 //
-template <typename InDataType,
+// input descriptor in [G, N, C, Do, Ho, Wo] order
+// weight descriptor in [G, K, C, Z, Y, X] order
+// output descriptor in [G, N, K, Di, Hi, Wi] order
+// physical layout is irrelevant
+template <ck::index_t NDimSpatial,
+          typename InDataType,
           typename WeiDataType,
           typename OutDataType,
           typename InElementwiseOperation,
           typename WeiElementwiseOperation,
           typename OutElementwiseOperation,
-          ck::index_t NumDimSpatial = 2,
-          typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
+          typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
 struct ReferenceConvFwd : public device::BaseOperator
 {
     // Argument
...
@@ -86,29 +91,37 @@ struct ReferenceConvFwd : public device::BaseOperator
     float Run(const Argument& arg)
     {
-        if constexpr(NumDimSpatial == 1)
+        if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
+             arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
+             arg.output_.GetNumOfDimension() == NDimSpatial + 3))
+        {
+            throw std::runtime_error("wrong! inconsistent dimension");
+        }
+
+        if constexpr(NDimSpatial == 1)
         {
-            auto f_ncw = [&](auto n, auto k, auto wo) {
+            auto func = [&](auto g, auto n, auto k, auto wo) {
                 float v_acc = 0;
 
-                for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
+                for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
                 {
-                    for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x)
+                    for(std::size_t x = 0; x < arg.weight_.GetLengths()[3]; ++x)
                     {
-                        auto wi = ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[0]) +
-                                  ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]) -
-                                  ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
+                        auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
+                                  static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
+                                  static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
 
                         if(wi >= 0 &&
-                           ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2])
+                           ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
                         {
                             float v_in;
                             float v_wei;
 
-                            arg.in_element_op_(v_in, ck::type_convert<float>(arg.input_(n, c, wi)));
-                            arg.wei_element_op_(v_wei, ck::type_convert<float>(arg.weight_(k, c, x)));
+                            arg.in_element_op_(v_in, ck::type_convert<float>(arg.input_(g, n, c, wi)));
+                            arg.wei_element_op_(v_wei, ck::type_convert<float>(arg.weight_(g, k, c, x)));
 
                             v_acc += v_in * v_wei;
                         }
...
@@ -118,50 +131,53 @@ struct ReferenceConvFwd : public device::BaseOperator
                 float v_out;
 
                 arg.out_element_op_(v_out, v_acc);
-                arg.output_(n, k, wo) = ck::type_convert<OutDataType>(v_out);
+                arg.output_(g, n, k, wo) = ck::type_convert<OutDataType>(v_out);
             };
 
-            make_ParallelTensorFunctor(f_ncw,
-                                       arg.output_.mDesc.GetLengths()[0],
-                                       arg.output_.mDesc.GetLengths()[1],
-                                       arg.output_.mDesc.GetLengths()[2])(
+            make_ParallelTensorFunctor(func,
+                                       arg.output_.GetLengths()[0],
+                                       arg.output_.GetLengths()[1],
+                                       arg.output_.GetLengths()[2],
+                                       arg.output_.GetLengths()[3])(
                 std::thread::hardware_concurrency());
 
             return 0;
         }
-        else if constexpr(NumDimSpatial == 2)
+        else if constexpr(NDimSpatial == 2)
         {
-            auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
+            auto func = [&](auto g, auto n, auto k, auto ho, auto wo) {
...
(the 2D case, and the 3D case over (g, n, k, d_o, ho, wo), take the same pattern of edits: the f_nchw lambdas are renamed func and gain the leading g argument, loop bounds and range checks move from mDesc.GetLengths()[i] to GetLengths()[i+1], ck::type_convert<ck::long_index_t> becomes static_cast<ck::long_index_t> in the di/hi/wi arithmetic, tensor accesses gain the g index, and make_ParallelTensorFunctor receives one more GetLengths() argument)
...
@@ -239,15 +256,17 @@ struct ReferenceConvFwd : public device::BaseOperator
                 float v_out;
 
                 arg.out_element_op_(v_out, v_acc);
-                arg.output_(n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
+                arg.output_(g, n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
             };
 
-            make_ParallelTensorFunctor(f_nchw,
-                                       arg.output_.mDesc.GetLengths()[0],
-                                       ...
-                                       arg.output_.mDesc.GetLengths()[4])(
+            make_ParallelTensorFunctor(func,
+                                       arg.output_.GetLengths()[0],
+                                       ...
+                                       arg.output_.GetLengths()[5])(
                 std::thread::hardware_concurrency());
 
             return 0;
...
@@ -267,7 +286,10 @@ struct ReferenceConvFwd : public device::BaseOperator
         return true;
     }
 
-    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
+    bool IsSupportedArgument(const device::BaseArgument*) override
+    {
+        return NDimSpatial >= 1 && NDimSpatial <= 3;
+    }
 
     static auto MakeArgument(const Tensor<InDataType>& input,
                              const Tensor<WeiDataType>& weight,
...
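After this change NDimSpatial leads the parameter list and the tensors carry an explicit group dimension. A hypothetical instantiation of a 2D grouped forward-convolution reference under the new signature (PassThrough is assumed to be the usual ck elementwise no-op functor):

    #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
    #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    // NDimSpatial now comes first; input/weight/output tensors are 5D
    // ([G, N, C, Hi, Wi] etc.) for the 2D case.
    using RefConvFwd2D =
        ck::tensor_operation::host::ReferenceConvFwd<2,            // NDimSpatial
                                                     float,        // InDataType
                                                     float,        // WeiDataType
                                                     float,        // OutDataType
                                                     PassThrough,  // InElementwiseOperation
                                                     PassThrough,  // WeiElementwiseOperation
                                                     PassThrough>; // OutElementwiseOperation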
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp

Each of these four headers takes the same one-line include change:

@@ -7,7 +7,7 @@
 #include <sstream>
 
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/utility/host_tensor.hpp"
 
 namespace ck {
 namespace tensor_operation {
...