Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7bbcd0fe
Commit
7bbcd0fe
authored
Dec 12, 2020
by
Jing Zhang
Browse files
vector type demo for fp32
parent
bbdb77e8
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
68 additions
and
577 deletions
+68
-577
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy_v2.hpp
...sor_operation/threadwise_generic_tensor_slice_copy_v2.hpp
+33
-106
composable_kernel/include/utility/float_type.amd.hpp.in
composable_kernel/include/utility/float_type.amd.hpp.in
+31
-0
driver/include/device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+3
-3
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+1
-468
No files found.
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy_v2.hpp
View file @
7bbcd0fe
...
...
@@ -96,17 +96,23 @@ struct ThreadwiseGenericTensorSliceCopy_v5
// Loads a single float from p_src.
//
// p_src           - base pointer forwarded unchanged to load_data
// src_coord_begin - coordinate object; only its GetOffset() result is used here
//
// Returns the value produced by load_data<float, float> for that pointer/offset pair.
template <typename SrcCoord>
__device__ static float run(const float* p_src, const SrcCoord src_coord_begin)
{
    return load_data<float, float>(p_src, src_coord_begin.GetOffset());
}
};
scalar_id
(
vector_access_dim
)
=
0
;
auto
src_coord
=
src_coord_begin
+
scalar_id
;
r
=
load_data
<
float
,
float
>
(
p_src
,
src_coord
.
GetOffset
());
template
<
typename
DstData
,
index_t
DstDataPerAccess
,
index_t
VectorSize
>
struct
vector_data_store
;
return
r
;
// Explicit specialization of vector_data_store for DstData = float,
// DstDataPerAccess = 1 and VectorSize = 1.
template <>
struct vector_data_store<float, 1, 1>
{
    // Hands one float to store_data.
    //
    // p_dst           - destination base pointer, forwarded unchanged
    // src_data        - the value to be stored
    // dst_coord_begin - coordinate object; only its GetOffset() result is used
    template <typename DstCoord>
    __device__ static void run(float* p_dst, const float src_data, const DstCoord dst_coord_begin)
    {
        const auto dst_offset = dst_coord_begin.GetOffset();

        store_data<float, float>(src_data, p_dst, dst_offset);
    }
};
...
...
@@ -131,56 +137,20 @@ struct ThreadwiseGenericTensorSliceCopy_v5
Number
<
vector_access_dim
>
{},
Number
<
long_vector_size
*
long_vector_access_id
[
vector_access_dim
]
>
{});
// buffer to hold a src long-vector
SrcData
long_vector
[
long_vector_size
];
#if 1
// zero out buffer
static_for
<
0
,
long_vector_size
,
1
>
{}([
&
](
auto
i
)
{
long_vector
[
i
]
=
0
;
});
#endif
// load data from src to the long-vector buffer
static_for
<
0
,
long_vector_size
/
src_data_per_access
,
1
>
{}([
&
](
auto
i
)
{
auto
scalar_id
=
make_zero_multi_index
<
nDim
>
();
scalar_id
(
vector_access_dim
)
=
i
*
src_data_per_access
;
const
index_t
buffer_offset
=
i
*
src_data_per_access
;
const
auto
src_coord
=
mSrcSliceOrigin
+
(
to_multi_index
(
long_vector_data_begin_id
)
+
scalar_id
);
// Check whether the src data has a valid mapping. Only the first element of this
// src vector is checked; it is the user's responsibility to make sure that all
// data in the src vector has the same valid/invalid mapping situation.
transfer_data
<
SrcData
,
SrcDataPerRead
,
SrcAddressSpace
,
AddressSpace
::
Vgpr
,
InMemoryDataOperation
::
Set
,
SrcDataStride
,
1
>
(
p_src
,
src_coord
.
GetOffset
(),
src_coord
.
IsOffsetValidAssumingUpperIndexIsValid
(),
SrcDesc
::
GetElementSpace
(),
long_vector
,
buffer_offset
,
true
,
long_vector_size
);
});
const
auto
src_coord
=
mSrcSliceOrigin
+
to_multi_index
(
long_vector_data_begin_id
);
// store data from the long-vector buffer to dst
static_for
<
0
,
long_vector_size
/
dst_data_per_access
,
1
>
{}([
&
](
auto
i
)
{
auto
scalar_id
=
make_zero_multi_index
<
nDim
>
();
scalar_id
(
vector_access_dim
)
=
i
*
dst_data_per_access
;
auto
src_buff
=
vector_data_load
<
SrcData
,
SrcDataPerRead
,
long_vector_size
>::
run
(
p_src
,
src_coord
);
const
index_t
buffer_offset
=
i
*
dst_data_per_access
;
// store data from the long-vector buffer to dst
constexpr
auto
buff_off
=
ThreadBufferDesc
::
CalculateOffset
(
to_multi_index
(
long_vector_data_begin_id
));
constexpr
auto
buff_off
=
ThreadBufferDesc
::
CalculateOffset
(
to_multi_index
(
long_vector_data_begin_id
)
);
// static_assert(buff_off == 0 || buff_off == 1 || buff_off == 2 || buff_off == 3,
// ""
);
thread_buff
.
s1
(
Number
<
buff_off
>
{})
=
long_vector
[
buffer_offset
];
});
thread_buff
.
s1
(
Number
<
buff_off
>
{})
=
src_buff
;
});
}
...
...
@@ -201,62 +171,19 @@ struct ThreadwiseGenericTensorSliceCopy_v5
static_ford
<
decltype
(
long_vector_access_lengths
),
SrcDstDimAccessOrder
>
{}(
[
&
](
auto
long_vector_access_id
)
{
// data id w.r.t slicing-window
auto
long_vector_data_begin_id
=
to_multi_index
(
long_vector_access_id
);
long_vector_data_begin_id
(
vector_access_dim
)
=
long_vector_size
*
long_vector_access_id
[
vector_access_dim
];
// buffer to hold a src long-vector
DstData
long_vector
[
long_vector_size
];
#if 1
// zero out buffer
static_for
<
0
,
long_vector_size
,
1
>
{}([
&
](
auto
i
)
{
long_vector
[
i
]
=
0
;
});
#endif
// load data from src to the long-vector buffer
static_for
<
0
,
long_vector_size
/
src_data_per_access
,
1
>
{}([
&
](
auto
i
)
{
auto
scalar_id
=
make_zero_multi_index
<
nDim
>
();
scalar_id
(
vector_access_dim
)
=
i
*
src_data_per_access
;
constexpr
auto
long_vector_data_begin_id
=
long_vector_access_id
.
Modify
(
Number
<
vector_access_dim
>
{},
Number
<
long_vector_size
*
long_vector_access_id
[
vector_access_dim
]
>
{});
const
index_t
buffer_offset
=
i
*
src_data_per_access
;
constexpr
auto
buff_off
=
ThreadBufferDesc
::
CalculateOffset
(
to_multi_index
(
long_vector_data_begin_id
));
auto
buff_off
=
ThreadBufferDesc
::
CalculateOffset
(
long_vector_data_begin_id
+
scalar_id
);
auto
src_buff
=
thread_buff
.
s1
[
Number
<
buff_off
>
{}];
// long_vector[buffer_offset] = thread_buff.s1[Number<buff_off>{}];
long_vector
[
buffer_offset
]
=
thread_buff
.
n
[
buff_off
];
});
const
auto
dst_coord
=
mDstSliceOrigin
+
to_multi_index
(
long_vector_data_begin_id
);
// store data from the long-vector buffer to dst
static_for
<
0
,
long_vector_size
/
dst_data_per_access
,
1
>
{}([
&
](
auto
i
)
{
auto
scalar_id
=
make_zero_multi_index
<
nDim
>
();
scalar_id
(
vector_access_dim
)
=
i
*
dst_data_per_access
;
const
index_t
buffer_offset
=
i
*
dst_data_per_access
;
const
auto
dst_coord
=
mDstSliceOrigin
+
(
long_vector_data_begin_id
+
scalar_id
);
// Check whether the dst data has a valid mapping. Only the first element of this
// dst vector is checked; it is the user's responsibility to make sure that all
// data in the dst vector has the same valid/invalid mapping situation.
transfer_data
<
DstData
,
DstDataPerWrite
,
AddressSpace
::
Vgpr
,
DstAddressSpace
,
DstInMemOp
,
1
,
DstDataStride
>
(
long_vector
,
buffer_offset
,
true
,
long_vector_size
,
p_dst
,
dst_coord
.
GetOffset
(),
dst_coord
.
IsOffsetValidAssumingUpperIndexIsValid
(),
DstDesc
::
GetElementSpace
());
});
vector_data_store
<
DstData
,
DstDataPerWrite
,
long_vector_size
>::
run
(
p_dst
,
src_buff
,
dst_coord
);
});
}
...
...
@@ -280,7 +207,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
.
Else
([
&
](
auto
)
{
mDstSliceOrigin
-=
step_sizes
;
});
}
float_vec
8
_t
thread_buff
;
float_vec
4
_t
thread_buff
;
private:
SrcCoord
mSrcSliceOrigin
;
...
...
composable_kernel/include/utility/float_type.amd.hpp.in
View file @
7bbcd0fe
...
...
@@ -32,8 +32,39 @@ union float_vec2_t
// Union exposing several alternative members for a group of four floats:
// a Tuple of four floats, a struct of four named scalars, a float4_t value
// and a plain float[4] array.
// NOTE(review): the members are presumably meant to alias the same four floats;
// that depends on the layouts of Tuple and float4_t, which are not visible here - confirm.
union float_vec4_t
{
// Tuple view.
Tuple<float, float, float, float> s1;
// Named-scalar view (e0..e3); this is the member written by set<float, i>() below.
struct{
float e0, e1, e2, e3;
} ss1;
// float4_t view.
float4_t s4;
// Array view.
float n[4];
// Empty default constructor: it initializes none of the members.
__host__ __device__ constexpr float_vec4_t() {}
// Element setter selected at compile time by value type T and index i.
// Only a declaration of the primary template appears here; the
// specializations below cover T = float with i = 0..3.
// NOTE(review): these are explicit specializations written at class scope;
// not every compiler accepts that form - confirm against the toolchain in use.
template<typename T, index_t i>
__host__ __device__ void set(const T val);
// set<float, 0>: writes val to ss1.e0.
template<>
__host__ __device__ void set<float, 0>(const float val)
{
ss1.e0 = val;
}
// set<float, 1>: writes val to ss1.e1.
template<>
__host__ __device__ void set<float, 1>(const float val)
{
ss1.e1 = val;
}
// set<float, 2>: writes val to ss1.e2.
template<>
__host__ __device__ void set<float, 2>(const float val)
{
ss1.e2 = val;
}
// set<float, 3>: writes val to ss1.e3.
template<>
__host__ __device__ void set<float, 3>(const float val)
{
ss1.e3 = val;
}
};
union float_vec8_t
...
...
driver/include/device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
7bbcd0fe
...
...
@@ -120,7 +120,7 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
#elif
1
#elif
0
// cdata = 64, BlockSize = 256, 128x128x8
constexpr
index_t
BlockSize
=
256
;
...
...
@@ -183,8 +183,8 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
using
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
=
Sequence
<
1
,
4
>
;
using
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
=
Sequence
<
8
,
32
>
;
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
4
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
4
;
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
1
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
4
;
#elif 0
...
...
driver/src/conv_driver.cpp
View file @
7bbcd0fe
...
...
@@ -22,487 +22,20 @@ int main(int argc, char* argv[])
{
using
namespace
ck
;
#if 0
// 1x1, 8x8
constexpr index_t N = 2;
constexpr index_t C = 24;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif
0
// 3x3, 71x71
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
192
;
constexpr
index_t
HI
=
71
;
constexpr
index_t
WI
=
71
;
constexpr
index_t
K
=
128
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif 0
// 1x1, 8x8
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1536
;
constexpr
index_t
HI
=
8
;
constexpr
index_t
WI
=
8
;
constexpr
index_t
K
=
256
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x1, 73x73
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
160
;
constexpr
index_t
HI
=
73
;
constexpr
index_t
WI
=
73
;
constexpr
index_t
K
=
64
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 3x3, 35x35
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
96
;
constexpr
index_t
HI
=
35
;
constexpr
index_t
WI
=
35
;
constexpr
index_t
K
=
96
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif 0
// 3x3, 71x71
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
192
;
constexpr
index_t
HI
=
71
;
constexpr
index_t
WI
=
71
;
constexpr
index_t
K
=
192
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif 0
// 7x1, 17x17
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
128
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
K
=
128
;
constexpr
index_t
Y
=
7
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
3
,
0
>
;
using
RightPads
=
Sequence
<
3
,
0
>
;
#elif 1
// 1x7, 17x17
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
128
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
K
=
128
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
7
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
3
>
;
using
RightPads
=
Sequence
<
0
,
3
>
;
#elif 0
// 3x3, 299x299 stride=2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
3
;
constexpr
index_t
HI
=
299
;
constexpr
index_t
WI
=
299
;
constexpr
index_t
K
=
32
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 3x3, 147x147
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
32
;
constexpr
index_t
HI
=
147
;
constexpr
index_t
WI
=
147
;
constexpr
index_t
K
=
64
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif 0
// 3x3, 149x149
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
32
;
constexpr
index_t
HI
=
149
;
constexpr
index_t
WI
=
149
;
constexpr
index_t
K
=
32
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 3x3, 17x17, stride 2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
192
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
K
=
192
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x1, 35x35
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
384
;
constexpr
index_t
HI
=
35
;
constexpr
index_t
WI
=
35
;
constexpr
index_t
K
=
96
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 3x3, 35x35, stride 2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
288
;
constexpr
index_t
HI
=
35
;
constexpr
index_t
WI
=
35
;
constexpr
index_t
K
=
384
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x3, 8x8
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
384
;
constexpr
index_t
HI
=
8
;
constexpr
index_t
WI
=
8
;
constexpr
index_t
K
=
448
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
1
>
;
using
RightPads
=
Sequence
<
0
,
1
>
;
#elif 0
// 3x1, 8x8
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
448
;
constexpr
index_t
HI
=
8
;
constexpr
index_t
WI
=
8
;
constexpr
index_t
K
=
512
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
0
>
;
using
RightPads
=
Sequence
<
1
,
0
>
;
#elif 0
// 3x3, 147x147
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
64
;
constexpr
index_t
HI
=
147
;
constexpr
index_t
WI
=
147
;
constexpr
index_t
K
=
96
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 7x1, 73x73
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
64
;
constexpr
index_t
HI
=
73
;
constexpr
index_t
WI
=
73
;
constexpr
index_t
K
=
64
;
constexpr
index_t
Y
=
7
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
3
,
0
>
;
using
RightPads
=
Sequence
<
3
,
0
>
;
#elif 0
// 3x3, 73x73
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
64
;
constexpr
index_t
HI
=
73
;
constexpr
index_t
WI
=
73
;
constexpr
index_t
K
=
96
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x1, 14x14, stride 2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1024
;
constexpr
index_t
HI
=
14
;
constexpr
index_t
WI
=
14
;
constexpr
index_t
K
=
2048
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x1, 14x14
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1024
;
constexpr
index_t
HI
=
14
;
constexpr
index_t
WI
=
14
;
constexpr
index_t
K
=
256
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x1, 14x14, stride 2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1024
;
constexpr
index_t
HI
=
14
;
constexpr
index_t
WI
=
14
;
constexpr
index_t
K
=
512
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 1
// 3x3, 28x28
// 1x1, 56x56
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
128
;
constexpr
index_t
HI
=
28
;
constexpr
index_t
WI
=
28
;
constexpr
index_t
K
=
128
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif 0
// 3x3, 14x14
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
256
;
constexpr
index_t
HI
=
14
;
constexpr
index_t
WI
=
14
;
constexpr
index_t
K
=
256
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif 0
// 1x1, 56x56, stride 2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
256
;
constexpr
index_t
HI
=
56
;
constexpr
index_t
WI
=
56
;
constexpr
index_t
K
=
128
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 7x7, 230x230 stride=2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
3
;
constexpr
index_t
HI
=
230
;
constexpr
index_t
WI
=
230
;
constexpr
index_t
K
=
64
;
constexpr
index_t
Y
=
7
;
constexpr
index_t
X
=
7
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x1, 28x28, stride = 2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
512
;
constexpr
index_t
HI
=
28
;
constexpr
index_t
WI
=
28
;
constexpr
index_t
K
=
1024
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x1, 28x28, stride 2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
512
;
constexpr
index_t
HI
=
28
;
constexpr
index_t
WI
=
28
;
constexpr
index_t
K
=
256
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x1, 7x7
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
512
;
constexpr
index_t
HI
=
7
;
constexpr
index_t
WI
=
7
;
constexpr
index_t
K
=
2048
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 3x3, 7x7
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
512
;
constexpr
index_t
HI
=
7
;
constexpr
index_t
WI
=
7
;
constexpr
index_t
K
=
512
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif 0
// 1x1, 56x56
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
64
;
constexpr
index_t
HI
=
56
;
constexpr
index_t
WI
=
56
;
constexpr
index_t
K
=
64
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 3x3, 56x56
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
64
;
constexpr
index_t
HI
=
56
;
constexpr
index_t
WI
=
56
;
constexpr
index_t
K
=
64
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#endif
auto
in_nchw_desc
=
make_native_tensor_descriptor_packed
(
Sequence
<
N
,
C
,
HI
,
WI
>
{});
auto
wei_kcyx_desc
=
make_native_tensor_descriptor_packed
(
Sequence
<
K
,
C
,
Y
,
X
>
{});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment