Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
1a66e35b
Unverified
Commit
1a66e35b
authored
Feb 17, 2020
by
Chao Liu
Committed by
GitHub
Feb 17, 2020
Browse files
MIopen integration (#13)
* update for miopen integration: cosmetic refactor
parent
3406a114
Changes
29
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
202 additions
and
168 deletions
+202
-168
composable_kernel/include/utility/config.amd.hpp.in
composable_kernel/include/utility/config.amd.hpp.in
+8
-6
composable_kernel/include/utility/in_memory_operation.amd.hpp.in
...ble_kernel/include/utility/in_memory_operation.amd.hpp.in
+19
-18
composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in
..._kernel/include/utility/in_memory_operation.nvidia.hpp.in
+11
-12
composable_kernel/include/utility/math.hpp
composable_kernel/include/utility/math.hpp
+9
-14
driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp
+37
-7
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
+32
-32
driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
+25
-25
driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
+50
-42
driver/src/conv_bwd_data_driver.cpp
driver/src/conv_bwd_data_driver.cpp
+11
-12
No files found.
composable_kernel/include/utility/config.amd.hpp.in
View file @
1a66e35b
...
...
@@ -54,21 +54,23 @@
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
#define CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK 0
#define CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK 0
namespace ck {
enum AddressSpace
{
g
eneric,
g
lobal,
l
ds,
v
gpr
G
eneric,
G
lobal,
L
ds,
V
gpr
};
enum InMemoryDataOperation
{
none
,
a
tomic
_a
dd
Set
,
A
tomic
A
dd
};
#if CK_UNSIGNED_INDEX_TYPE
...
...
composable_kernel/include/utility/in_memory_operation.amd.hpp.in
View file @
1a66e35b
...
...
@@ -10,13 +10,14 @@ template <typename T,
index_t DataPerAccess,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace>
__device__ void
copy
_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
__device__ void
set
_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
{
using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;
#if CK_USE_AMD_BUFFER_ADDRESSING
// TODO: use static_if::ElseIf, instead of nested static_if
static_if<SrcAddressSpace == AddressSpace::global && DstAddressSpace == vgpr>{}([&](auto) {
static_if<SrcAddressSpace == AddressSpace::Global &&
DstAddressSpace == AddressSpace::Vgpr>{}([&](auto) {
// buffer_load requires:
// 1) p_src must be in global memory space, p_dst must be in vgpr
// 2) p_src to be a block-invariant pointer.
...
...
@@ -24,7 +25,8 @@ __device__ void copy_data(const T* p_src, index_t src_offset, T* p_dst, index_t
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
amd_intrinsic_buffer_load<T, DataPerAccess>(p_src, src_offset, 0);
}).Else([&](auto) {
static_if<SrcAddressSpace == AddressSpace::vgpr && DstAddressSpace == global>{}([&](auto) {
static_if<SrcAddressSpace == AddressSpace::Vgpr &&
DstAddressSpace == AddressSpace::Global>{}([&](auto) {
// buffer_store requires:
// 1) p_src must be in vgpr space, p_dst must be in global memory
// 2) p_dst to be a block-invariant pointer.
...
...
@@ -50,8 +52,8 @@ __device__ void atomic_add_data(const T* p_src, index_t src_offset, T* p_dst, in
{
using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;
static_if<SrcAddressSpace == AddressSpace::
v
gpr &&
DstAddressSpace == AddressSpace::global>{}(
[&](auto) {
static_if<SrcAddressSpace == AddressSpace::
V
gpr &&
DstAddressSpace == AddressSpace::Global>{}(
[&](auto) {
#if CK_USE_AMD_BUFFER_ATOMIC_ADD
amd_intrinsic_buffer_atomic_add<T, DataPerAccess>(
*reinterpret_cast<const vector_t*>(&p_src[src_offset]), p_dst, dst_offset, 0);
...
...
@@ -59,8 +61,7 @@ __device__ void atomic_add_data(const T* p_src, index_t src_offset, T* p_dst, in
atomicAdd(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
*reinterpret_cast<const vector_t*>(&p_src[src_offset]));
#endif
})
.Else([&](auto fwd) {
}).Else([&](auto fwd) {
static_assert(fwd(false), "atomic_add doesn't support this memory space");
});
}
...
...
@@ -72,17 +73,17 @@ template <typename T,
InMemoryDataOperation DstInMemOp>
__device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
{
static_assert(DstInMemOp == InMemoryDataOperation::
none
||
DstInMemOp == InMemoryDataOperation::
a
tomic
_a
dd,
static_assert(DstInMemOp == InMemoryDataOperation::
Set
||
DstInMemOp == InMemoryDataOperation::
A
tomic
A
dd,
"wrong! InMemoryDataOperation not supported!");
// TODO: use static_if::ElseIf
static_if<DstInMemOp == InMemoryDataOperation::
none
>{}([&](auto) {
copy
_data<T, DataPerAccess, SrcAddressSpace, DstAddressSpace>(
static_if<DstInMemOp == InMemoryDataOperation::
Set
>{}([&](auto) {
set
_data<T, DataPerAccess, SrcAddressSpace, DstAddressSpace>(
p_src, src_offset, p_dst, dst_offset);
});
static_if<DstInMemOp == InMemoryDataOperation::
a
tomic
_a
dd>{}([&](auto) {
static_if<DstInMemOp == InMemoryDataOperation::
A
tomic
A
dd>{}([&](auto) {
atomic_add_data<T, DataPerAccess, SrcAddressSpace, DstAddressSpace>(
p_src, src_offset, p_dst, dst_offset);
});
...
...
composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in
View file @
1a66e35b
...
...
@@ -23,12 +23,11 @@ __device__ void atomic_add_data(const T* p_src, index_t src_offset, T* p_dst, in
{
using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;
static_if<SrcAddressSpace == AddressSpace::
v
gpr &&
DstAddressSpace == AddressSpace::global>{}(
[&](auto) {
static_if<SrcAddressSpace == AddressSpace::
V
gpr &&
DstAddressSpace == AddressSpace::Global>{}(
[&](auto) {
atomicAdd(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
*reinterpret_cast<const vector_t*>(&p_src[src_offset]));
})
.Else([&](auto fwd) {
}).Else([&](auto fwd) {
static_assert(fwd(false), "atomic_add doesn't support this memory space");
});
}
...
...
@@ -40,17 +39,17 @@ template <typename T,
InMemoryDataOperation DstInMemOp>
__device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
{
static_assert(DstInMemOp == InMemoryDataOperation::
none
||
DstInMemOp == InMemoryDataOperation::
a
tomic
_a
dd,
static_assert(DstInMemOp == InMemoryDataOperation::
Set
||
DstInMemOp == InMemoryDataOperation::
A
tomic
A
dd,
"wrong! InMemoryDataOperation not supported!");
// TODO: use static_if::ElseIf
static_if<DstInMemOp == InMemoryDataOperation::
none
>{}([&](auto) {
static_if<DstInMemOp == InMemoryDataOperation::
Set
>{}([&](auto) {
copy_data<T, DataPerAccess, SrcAddressSpace, DstAddressSpace>(
p_src, src_offset, p_dst, dst_offset);
});
static_if<DstInMemOp == InMemoryDataOperation::
a
tomic
_a
dd>{}([&](auto) {
static_if<DstInMemOp == InMemoryDataOperation::
A
tomic
A
dd>{}([&](auto) {
atomic_add_data<T, DataPerAccess, SrcAddressSpace, DstAddressSpace>(
p_src, src_offset, p_dst, dst_offset);
});
...
...
composable_kernel/include/utility/math.hpp
View file @
1a66e35b
...
...
@@ -107,27 +107,22 @@ __host__ __device__ constexpr T min(T x, Ts... xs)
template
<
typename
T
>
__host__
__device__
constexpr
T
gcd
(
T
x
,
T
y
)
{
if
(
x
==
0
)
if
(
x
==
y
||
x
==
0
)
{
return
y
;
}
if
(
y
==
0
)
{
return
x
;
}
if
(
x
==
y
)
else
if
(
y
==
0
)
{
return
x
;
}
if
(
x
>
y
)
else
if
(
x
>
y
)
{
return
gcd
(
x
-
y
,
y
);
}
else
{
return
gcd
(
x
,
y
-
x
);
}
}
template
<
index_t
X
,
index_t
Y
>
...
...
@@ -150,10 +145,10 @@ __host__ __device__ constexpr T lcm(T x, T y)
return
(
x
*
y
)
/
gcd
(
x
,
y
);
}
template
<
typename
X
,
typename
Y
,
typename
...
Z
s
>
__host__
__device__
constexpr
auto
lcm
(
X
x
,
Y
y
,
Z
s
...
z
s
)
template
<
typename
X
,
typename
...
Y
s
>
__host__
__device__
constexpr
auto
lcm
(
X
x
,
Ys
...
y
s
)
{
return
lcm
(
x
,
lcm
(
y
,
z
s
...));
return
lcm
(
x
,
lcm
(
ys
...));
}
template
<
class
T
>
...
...
driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp
View file @
1a66e35b
...
...
@@ -49,20 +49,20 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
#if
1
#if
0
// BlockSize = 256, each thread holds 64 data
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
...
...
@@ -79,6 +79,36 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif
1
// BlockSize = 256, each thread holds 64 data
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
16
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmThreadGemmDataPerReadM
=
4
;
constexpr
index_t
GemmThreadGemmDataPerReadN
=
4
;
using
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
=
Sequence
<
2
,
4
>
;
using
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
=
Sequence
<
8
,
32
>
;
constexpr
index_t
GemmABlockCopySrcDataPerRead_GemmM
=
4
;
constexpr
index_t
GemmABlockCopyDstDataPerWrite_GemmM
=
4
;
using
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
=
Sequence
<
2
,
4
>
;
using
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
=
Sequence
<
8
,
32
>
;
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
4
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
4
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
4
;
#endif
constexpr
index_t
GemmM
=
C
*
Y
*
X
;
...
...
@@ -104,13 +134,13 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMPerThread
,
GemmNPerThread
,
GemmKPerThread
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
GemmThreadGemmDataPerReadM
,
GemmThreadGemmDataPerReadN
,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
,
...
...
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
View file @
1a66e35b
...
...
@@ -66,13 +66,13 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmThreadGemmDataPerReadM
=
4
;
constexpr
index_t
GemmThreadGemmDataPerReadN
=
4
;
...
...
@@ -96,13 +96,13 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmThreadGemmDataPerReadM
=
4
;
constexpr
index_t
GemmThreadGemmDataPerReadN
=
4
;
...
...
@@ -127,13 +127,13 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmThreadGemmDataPerReadM
=
4
;
constexpr
index_t
GemmThreadGemmDataPerReadN
=
4
;
...
...
@@ -152,33 +152,33 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
4
;
#endif
constexpr
index_t
g
cd
_s
tride
_d
ilation
_h
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
constexpr
index_t
g
cd
_s
tride
_d
ilation
_w
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
constexpr
index_t
G
cd
S
tride
D
ilation
H
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
constexpr
index_t
G
cd
S
tride
D
ilation
W
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
constexpr
index_t
Y
t
ilda
=
ConvStrideH
/
g
cd
_s
tride
_d
ilation
_h
;
constexpr
index_t
X
t
ilda
=
ConvStrideW
/
g
cd
_s
tride
_d
ilation
_w
;
constexpr
index_t
Y
T
ilda
=
ConvStrideH
/
G
cd
S
tride
D
ilation
H
;
constexpr
index_t
X
T
ilda
=
ConvStrideW
/
G
cd
S
tride
D
ilation
W
;
constexpr
index_t
Y
d
ot
=
math
::
integer_divide_ceil
(
Y
,
Y
t
ilda
);
constexpr
index_t
X
d
ot
=
math
::
integer_divide_ceil
(
X
,
X
t
ilda
);
constexpr
index_t
Y
D
ot
=
math
::
integer_divide_ceil
(
Y
,
Y
T
ilda
);
constexpr
index_t
X
D
ot
=
math
::
integer_divide_ceil
(
X
,
X
T
ilda
);
constexpr
index_t
H
t
ilda
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
1
),
ConvStrideH
);
constexpr
index_t
W
t
ilda
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
1
),
ConvStrideW
);
constexpr
index_t
H
T
ilda
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
1
),
ConvStrideH
);
constexpr
index_t
W
T
ilda
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
1
),
ConvStrideW
);
constexpr
index_t
H
t
ildaLeft
=
math
::
integer_divide_floor
(
math
::
max
(
0
,
InLeftPads
{}[
0
]
-
ConvDilationH
*
(
Y
t
ilda
-
1
)),
ConvStrides
{}[
0
]);
constexpr
index_t
W
t
ildaLeft
=
math
::
integer_divide_floor
(
math
::
max
(
0
,
InLeftPads
{}[
1
]
-
ConvDilationW
*
(
X
t
ilda
-
1
)),
ConvStrides
{}[
1
]);
constexpr
index_t
H
T
ildaLeft
=
math
::
integer_divide_floor
(
math
::
max
(
0
,
InLeftPads
{}[
0
]
-
ConvDilationH
*
(
Y
T
ilda
-
1
)),
ConvStrides
{}[
0
]);
constexpr
index_t
W
T
ildaLeft
=
math
::
integer_divide_floor
(
math
::
max
(
0
,
InLeftPads
{}[
1
]
-
ConvDilationW
*
(
X
T
ilda
-
1
)),
ConvStrides
{}[
1
]);
constexpr
index_t
H
t
ildaRight
=
math
::
min
(
H
t
ilda
,
math
::
integer_divide_ceil
(
InLeftPads
{}[
0
]
+
Hi
-
1
,
ConvStrides
{}[
0
])
+
1
);
constexpr
index_t
W
t
ildaRight
=
math
::
min
(
W
t
ilda
,
math
::
integer_divide_ceil
(
InLeftPads
{}[
1
]
+
Wi
-
1
,
ConvStrides
{}[
1
])
+
1
);
constexpr
index_t
H
T
ildaRight
=
math
::
min
(
H
T
ilda
,
math
::
integer_divide_ceil
(
InLeftPads
{}[
0
]
+
Hi
-
1
,
ConvStrides
{}[
0
])
+
1
);
constexpr
index_t
W
T
ildaRight
=
math
::
min
(
W
T
ilda
,
math
::
integer_divide_ceil
(
InLeftPads
{}[
1
]
+
Wi
-
1
,
ConvStrides
{}[
1
])
+
1
);
constexpr
index_t
H
t
ilda
Trim
=
H
t
ildaRight
-
H
t
ildaLeft
;
constexpr
index_t
W
t
ilda
Trim
=
W
t
ildaRight
-
W
t
ildaLeft
;
constexpr
index_t
H
T
ilda
Slice
=
H
T
ildaRight
-
H
T
ildaLeft
;
constexpr
index_t
W
T
ilda
Slice
=
W
T
ildaRight
-
W
T
ildaLeft
;
constexpr
index_t
GemmM
=
C
*
Y
t
ilda
*
X
t
ilda
;
constexpr
index_t
GemmN
=
N
*
H
t
ilda
Trim
*
W
t
ilda
Trim
;
constexpr
index_t
GemmM
=
C
*
Y
T
ilda
*
X
T
ilda
;
constexpr
index_t
GemmN
=
N
*
H
T
ilda
Slice
*
W
T
ilda
Slice
;
constexpr
index_t
GridSize
=
math
::
integer_divide_ceil
(
GemmM
,
GemmMPerBlock
)
*
math
::
integer_divide_ceil
(
GemmN
,
GemmNPerBlock
);
...
...
@@ -200,13 +200,13 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMPerThread
,
GemmNPerThread
,
GemmKPerThread
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
GemmThreadGemmDataPerReadM
,
GemmThreadGemmDataPerReadN
,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
,
...
...
driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
View file @
1a66e35b
...
...
@@ -66,13 +66,13 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmThreadGemmDataPerReadM
=
4
;
constexpr
index_t
GemmThreadGemmDataPerReadN
=
4
;
...
...
@@ -91,33 +91,33 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
#endif
constexpr
index_t
g
cd
_s
tride
_d
ilation
_h
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
constexpr
index_t
g
cd
_s
tride
_d
ilation
_w
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
constexpr
index_t
G
cd
S
tride
D
ilation
H
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
constexpr
index_t
G
cd
S
tride
D
ilation
W
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
constexpr
index_t
Y
t
ilda
=
ConvStrideH
/
g
cd
_s
tride
_d
ilation
_h
;
constexpr
index_t
X
t
ilda
=
ConvStrideW
/
g
cd
_s
tride
_d
ilation
_w
;
constexpr
index_t
Y
T
ilda
=
ConvStrideH
/
G
cd
S
tride
D
ilation
H
;
constexpr
index_t
X
T
ilda
=
ConvStrideW
/
G
cd
S
tride
D
ilation
W
;
constexpr
index_t
Y
d
ot
=
math
::
integer_divide_ceil
(
Y
,
Y
t
ilda
);
constexpr
index_t
X
d
ot
=
math
::
integer_divide_ceil
(
X
,
X
t
ilda
);
constexpr
index_t
Y
D
ot
=
math
::
integer_divide_ceil
(
Y
,
Y
T
ilda
);
constexpr
index_t
X
D
ot
=
math
::
integer_divide_ceil
(
X
,
X
T
ilda
);
constexpr
index_t
H
t
ilda
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
1
),
ConvStrideH
);
constexpr
index_t
W
t
ilda
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
1
),
ConvStrideW
);
constexpr
index_t
H
T
ilda
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
1
),
ConvStrideH
);
constexpr
index_t
W
T
ilda
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
1
),
ConvStrideW
);
constexpr
index_t
H
t
ildaLeft
=
math
::
integer_divide_floor
(
math
::
max
(
0
,
InLeftPads
{}[
0
]
-
ConvDilationH
*
(
Y
t
ilda
-
1
)),
ConvStrides
{}[
0
]);
constexpr
index_t
W
t
ildaLeft
=
math
::
integer_divide_floor
(
math
::
max
(
0
,
InLeftPads
{}[
1
]
-
ConvDilationW
*
(
X
t
ilda
-
1
)),
ConvStrides
{}[
1
]);
constexpr
index_t
H
T
ildaLeft
=
math
::
integer_divide_floor
(
math
::
max
(
0
,
InLeftPads
{}[
0
]
-
ConvDilationH
*
(
Y
T
ilda
-
1
)),
ConvStrides
{}[
0
]);
constexpr
index_t
W
T
ildaLeft
=
math
::
integer_divide_floor
(
math
::
max
(
0
,
InLeftPads
{}[
1
]
-
ConvDilationW
*
(
X
T
ilda
-
1
)),
ConvStrides
{}[
1
]);
constexpr
index_t
H
t
ildaRight
=
math
::
min
(
H
t
ilda
,
math
::
integer_divide_ceil
(
InLeftPads
{}[
0
]
+
Hi
-
1
,
ConvStrides
{}[
0
])
+
1
);
constexpr
index_t
W
t
ildaRight
=
math
::
min
(
W
t
ilda
,
math
::
integer_divide_ceil
(
InLeftPads
{}[
1
]
+
Wi
-
1
,
ConvStrides
{}[
1
])
+
1
);
constexpr
index_t
H
T
ildaRight
=
math
::
min
(
H
T
ilda
,
math
::
integer_divide_ceil
(
InLeftPads
{}[
0
]
+
Hi
-
1
,
ConvStrides
{}[
0
])
+
1
);
constexpr
index_t
W
T
ildaRight
=
math
::
min
(
W
T
ilda
,
math
::
integer_divide_ceil
(
InLeftPads
{}[
1
]
+
Wi
-
1
,
ConvStrides
{}[
1
])
+
1
);
constexpr
index_t
H
t
ilda
Trim
=
H
t
ildaRight
-
H
t
ildaLeft
;
constexpr
index_t
W
t
ilda
Trim
=
W
t
ildaRight
-
W
t
ildaLeft
;
constexpr
index_t
H
T
ilda
Slice
=
H
T
ildaRight
-
H
T
ildaLeft
;
constexpr
index_t
W
T
ilda
Slice
=
W
T
ildaRight
-
W
T
ildaLeft
;
constexpr
index_t
GemmM
=
C
;
constexpr
index_t
GemmN
=
N
*
H
t
ilda
Trim
*
W
t
ilda
Trim
;
constexpr
index_t
GemmN
=
N
*
H
T
ilda
Slice
*
W
T
ilda
Slice
;
constexpr
index_t
GridSize
=
math
::
integer_divide_ceil
(
GemmM
,
GemmMPerBlock
)
*
math
::
integer_divide_ceil
(
GemmN
,
GemmNPerBlock
);
...
...
@@ -139,13 +139,13 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMPerThread
,
GemmNPerThread
,
GemmKPerThread
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
GemmThreadGemmDataPerReadM
,
GemmThreadGemmDataPerReadN
,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
,
...
...
driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
View file @
1a66e35b
...
...
@@ -69,13 +69,13 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmThreadGemmDataPerReadM
=
4
;
constexpr
index_t
GemmThreadGemmDataPerReadN
=
4
;
...
...
@@ -99,13 +99,13 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
16
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmThreadGemmDataPerReadM
=
4
;
constexpr
index_t
GemmThreadGemmDataPerReadN
=
4
;
...
...
@@ -124,33 +124,33 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
#endif
constexpr
index_t
g
cd
_s
tride
_d
ilation
_h
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
constexpr
index_t
g
cd
_s
tride
_d
ilation
_w
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
constexpr
index_t
G
cd
S
tride
D
ilation
H
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
constexpr
index_t
G
cd
S
tride
D
ilation
W
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
constexpr
index_t
Y
t
ilda
=
ConvStrideH
/
g
cd
_s
tride
_d
ilation
_h
;
constexpr
index_t
X
t
ilda
=
ConvStrideW
/
g
cd
_s
tride
_d
ilation
_w
;
constexpr
index_t
Y
T
ilda
=
ConvStrideH
/
G
cd
S
tride
D
ilation
H
;
constexpr
index_t
X
T
ilda
=
ConvStrideW
/
G
cd
S
tride
D
ilation
W
;
constexpr
index_t
Y
d
ot
=
math
::
integer_divide_ceil
(
Y
,
Y
t
ilda
);
constexpr
index_t
X
d
ot
=
math
::
integer_divide_ceil
(
X
,
X
t
ilda
);
constexpr
index_t
Y
D
ot
=
math
::
integer_divide_ceil
(
Y
,
Y
T
ilda
);
constexpr
index_t
X
D
ot
=
math
::
integer_divide_ceil
(
X
,
X
T
ilda
);
constexpr
index_t
H
t
ilda
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
1
),
ConvStrideH
);
constexpr
index_t
W
t
ilda
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
1
),
ConvStrideW
);
constexpr
index_t
H
T
ilda
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
1
),
ConvStrideH
);
constexpr
index_t
W
T
ilda
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
1
),
ConvStrideW
);
constexpr
index_t
H
t
ildaLeft
=
math
::
integer_divide_floor
(
math
::
max
(
0
,
InLeftPads
{}[
0
]
-
ConvDilationH
*
(
Y
t
ilda
-
1
)),
ConvStrides
{}[
0
]);
constexpr
index_t
W
t
ildaLeft
=
math
::
integer_divide_floor
(
math
::
max
(
0
,
InLeftPads
{}[
1
]
-
ConvDilationW
*
(
X
t
ilda
-
1
)),
ConvStrides
{}[
1
]);
constexpr
index_t
H
T
ildaLeft
=
math
::
integer_divide_floor
(
math
::
max
(
0
,
InLeftPads
{}[
0
]
-
ConvDilationH
*
(
Y
T
ilda
-
1
)),
ConvStrides
{}[
0
]);
constexpr
index_t
W
T
ildaLeft
=
math
::
integer_divide_floor
(
math
::
max
(
0
,
InLeftPads
{}[
1
]
-
ConvDilationW
*
(
X
T
ilda
-
1
)),
ConvStrides
{}[
1
]);
constexpr
index_t
H
t
ildaRight
=
math
::
min
(
H
t
ilda
,
math
::
integer_divide_ceil
(
InLeftPads
{}[
0
]
+
Hi
-
1
,
ConvStrides
{}[
0
])
+
1
);
constexpr
index_t
W
t
ildaRight
=
math
::
min
(
W
t
ilda
,
math
::
integer_divide_ceil
(
InLeftPads
{}[
1
]
+
Wi
-
1
,
ConvStrides
{}[
1
])
+
1
);
constexpr
index_t
H
T
ildaRight
=
math
::
min
(
H
T
ilda
,
math
::
integer_divide_ceil
(
InLeftPads
{}[
0
]
+
Hi
-
1
,
ConvStrides
{}[
0
])
+
1
);
constexpr
index_t
W
T
ildaRight
=
math
::
min
(
W
T
ilda
,
math
::
integer_divide_ceil
(
InLeftPads
{}[
1
]
+
Wi
-
1
,
ConvStrides
{}[
1
])
+
1
);
constexpr
index_t
H
t
ilda
Trim
=
H
t
ildaRight
-
H
t
ildaLeft
;
constexpr
index_t
W
t
ilda
Trim
=
W
t
ildaRight
-
W
t
ildaLeft
;
constexpr
index_t
H
T
ilda
Slice
=
H
T
ildaRight
-
H
T
ildaLeft
;
constexpr
index_t
W
T
ilda
Slice
=
W
T
ildaRight
-
W
T
ildaLeft
;
constexpr
index_t
GemmM
=
C
;
constexpr
index_t
GemmN
=
N
*
H
t
ilda
Trim
*
W
t
ilda
Trim
;
constexpr
index_t
GemmN
=
N
*
H
T
ilda
Slice
*
W
T
ilda
Slice
;
constexpr
index_t
GridSize
=
math
::
integer_divide_ceil
(
GemmM
,
GemmMPerBlock
)
*
math
::
integer_divide_ceil
(
GemmN
,
GemmNPerBlock
);
...
...
@@ -159,7 +159,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
using
GridwiseConv
=
GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
<
using
GridwiseConv
BwdData
=
GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
<
GridSize
,
BlockSize
,
T
,
...
...
@@ -174,13 +174,13 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMPerThread
,
GemmNPerThread
,
GemmKPerThread
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
GemmThreadGemmDataPerReadM
,
GemmThreadGemmDataPerReadN
,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
,
...
...
@@ -196,11 +196,18 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
KernelTimer
timer
;
timer
.
Start
();
static_for
<
0
,
GridwiseConv
::
GetNumberOfGemm
(),
1
>
{}([
&
](
auto
gemm_id_
)
{
static_for
<
0
,
GridwiseConv
BwdData
::
GetNumberOfGemm
(),
1
>
{}([
&
](
auto
gemm_id_
)
{
constexpr
index_t
gemm_id
=
decltype
(
gemm_id_
){};
launch_kernel
(
run_gridwise_convolution_backward_data_v4r1
<
GridwiseConv
,
gemm_id
,
constexpr
auto
gemm_sizes
=
GridwiseConvBwdData
::
GetGemmSize
(
gemm_id
);
constexpr
index_t
gemm_k
=
gemm_sizes
.
At
(
2
);
constexpr
bool
is_gemm_not_empty
=
gemm_k
>
0
;
// only compile and run if GEMM is not empty
static_if
<
is_gemm_not_empty
>
{}([
&
](
auto
fwd
)
{
launch_kernel
(
run_gridwise_convolution_backward_data_v4r1
<
GridwiseConvBwdData
,
fwd
(
gemm_id
),
T
*
const
__restrict__
,
const
T
*
const
__restrict__
,
const
T
*
const
__restrict__
>
,
...
...
@@ -212,6 +219,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
});
});
timer
.
End
();
float
time
=
timer
.
GetElapsedTime
();
...
...
driver/src/conv_bwd_data_driver.cpp
View file @
1a66e35b
...
...
@@ -23,17 +23,16 @@ int main(int argc, char* argv[])
{
using
namespace
launcher
;
#if 0
// 3x3 filter, 2x2 stride, 35x35 input
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 1024;
constexpr index_t Y = 3;
constexpr index_t X = 3;
#if 1
constexpr
index_t
N
=
64
;
constexpr
index_t
C
=
256
;
constexpr
index_t
HI
=
56
;
constexpr
index_t
WI
=
56
;
constexpr
index_t
K
=
256
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using ConvStrides = Sequence<
2
,
2
>;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
...
...
@@ -158,7 +157,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
2
,
2
>
;
using
RightPads
=
Sequence
<
2
,
2
>
;
#elif
1
#elif
0
// 1x7 filter, 0x3 pad, 17x17 input
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
128
;
...
...
@@ -188,7 +187,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
3
,
0
>
;
using
RightPads
=
Sequence
<
3
,
0
>
;
#elif
0
#elif
1
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1024
;
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment