gaoqiong / composable_kernel · Commit 0a265731
Authored May 30, 2019 by Chao Liu

adding implicit gemm v4 (nchw, kcyx)

Parent: 8c385cf5
Showing 16 changed files with 747 additions and 27 deletions.
driver/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp  +137 -0
driver/driver.hip.cpp  +3 -0
src/include/Array.hip.hpp  +13 -0
src/include/ConstantMergedTensorDescriptor.hip.hpp  +44 -3
src/include/ConstantTensorDescriptor.hip.hpp  +74 -0
src/include/blockwise_generic_tensor_slice_op.hip.hpp  +147 -13
src/include/functional.hip.hpp  +2 -2
src/include/gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hip.hpp  +1 -1
src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp  +1 -1
src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp  +1 -1
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp  +1 -1
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp  +1 -1
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp  +2 -2
src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_khwn.hip.hpp  +1 -1
src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hip.hpp  +1 -1
src/include/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hip.hpp  +318 -0
driver/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp (new file, mode 100644)
```cpp
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "gridwise_convolution_wrapper.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hip.hpp"

template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
                                                        const Tensor<T>& in_nchw,
                                                        WeiDesc,
                                                        const Tensor<T>& wei_kcyx,
                                                        OutDesc,
                                                        Tensor<T>& out_nkhw,
                                                        index_t nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_nchw_desc  = InDesc{};
    constexpr auto wei_kcyx_desc = WeiDesc{};
    constexpr auto out_nkhw_desc = OutDesc{};

    constexpr index_t Hi = in_nchw_desc.GetLength(I2);
    constexpr index_t Wi = in_nchw_desc.GetLength(I3);

    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

    constexpr index_t K = wei_kcyx_desc.GetLength(I0);
    constexpr index_t C = wei_kcyx_desc.GetLength(I1);
    constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
    constexpr index_t X = wei_kcyx_desc.GetLength(I3);

    std::size_t data_sz = sizeof(T);
    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
    DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
    DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());

    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
    wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

    constexpr index_t N1 = 2;
    constexpr index_t N2 = 4;

    constexpr index_t B = (N * Ho * Wo) / (N1 * N2);

#if 1
    constexpr index_t BlockSize = 256;

    constexpr index_t BPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;

    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA   = 4;
    constexpr index_t GemmDataPerReadB   = 4;

    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 1, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;

    constexpr index_t InBlockCopySrcDataPerRead_B   = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;

    using WeiBlockCopySubLengths_E_K     = Sequence<1, 4>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<8, 32>;

    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#endif

    constexpr index_t GridSize =
        ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);

    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

    for(index_t i = 0; i < nrepeat; ++i)
    {
        constexpr auto gridwise_conv =
#if 1
            GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
#endif
            <GridSize,
             BlockSize,
             T,
             decltype(in_nchw_desc),
             decltype(wei_kcyx_desc),
             decltype(out_nkhw_desc),
             BPerBlock,
             KPerBlock,
             CPerBlock,
             N1,
             N2,
             GemmMPerThreadSubC,
             GemmNPerThreadSubC,
             GemmMLevel0Cluster,
             GemmNLevel0Cluster,
             GemmMLevel1Cluster,
             GemmNLevel1Cluster,
             GemmKPerThreadLoop,
             GemmDataPerReadA,
             GemmDataPerReadB,
             InBlockCopySubLengths_E_N1_B_N2,
             InBlockCopyClusterLengths_E_N1_B_N2,
             InBlockCopySrcDataPerRead_B,
             InBlockCopyDstDataPerWrite_N2,
             WeiBlockCopySubLengths_E_K,
             WeiBlockCopyClusterLengths_E_K,
             WeiBlockCopySrcDataPerRead_E,
             WeiBlockCopyDstDataPerWrite_K>{};

        float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>,
                                   dim3(GridSize),
                                   dim3(BlockSize),
                                   0,
                                   static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                                   static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                                   static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));

        printf("Elapsed time : %f ms, %f TFlop/s \n",
               time,
               (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
                   (std::size_t(1000) * 1000 * 1000) / time);

        usleep(std::min(time * 1000, float(10000)));
    }

    out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
```
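As a quick sanity check of the launch arithmetic above, a hedged worked example with made-up tensor sizes (plain `int` rather than the repo's `index_t`), chosen so the divisibility asserts in the v4 kernel hold:

```cpp
// hypothetical problem: N = 64, Ho = Wo = 16, K = 128
constexpr int N = 64, Ho = 16, Wo = 16, K = 128;
constexpr int N1 = 2, N2 = 4, BPerBlock = 16, KPerBlock = 128;

constexpr int B = (N * Ho * Wo) / (N1 * N2); // 16384 / 8 = 2048 merged output positions
constexpr int GridSize =
    ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock); // 128 * 1

static_assert(B == 2048 && GridSize == 128, "128 blocks of 256 threads");
```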
driver/driver.hip.cpp

```diff
@@ -14,6 +14,7 @@
 #include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp"
 #include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp"
 #include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
+#include "device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp"
 
 struct GeneratorTensor_1
 {
@@ -629,6 +630,8 @@ int main(int argc, char* argv[])
     device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
 #elif 1
     device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
+#elif 1
+    device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw
 #endif
     (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
```
src/include/Array.hip.hpp

```diff
@@ -135,3 +135,16 @@ __host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is.
     return result;
 }
+
+template <class TData, index_t NSize, class F>
+__host__ __device__ constexpr TData reduce_on_array(Array<TData, NSize> a, F f)
+{
+    TData result = a[0];
+
+    static_for<1, NSize, 1>{}([&](auto I) {
+        constexpr index_t i = I.Get();
+        result              = f(result, a[i]);
+    });
+
+    return result;
+}
```
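A minimal usage sketch for the new helper, assuming `Array` supports the brace initialization shown; this is the same `reduce_on_array(..., std::plus<index_t>{})` combination the blockwise copy rewrite below uses to sum per-dimension partial offsets:

```cpp
// left fold over a compile-time-sized array: f(f(a[0], a[1]), a[2])
Array<index_t, 3> partial_offsets{{12, 40, 3}};

index_t total = reduce_on_array(partial_offsets, std::plus<index_t>{}); // (12 + 40) + 3 = 55
```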
src/include/ConstantMergedTensorDescriptor.hip.hpp

```diff
@@ -33,9 +33,10 @@ struct ConstantMergedTensorDescriptor
     __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }
 
-    __host__ __device__ static constexpr index_t GetNumOfOriginalDimension()
+    template <index_t IDim>
+    __host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
     {
-        return nOriginalDim;
+        return std::get<IDim>(mOriginalDimMergeSeqs);
     }
 
     template <index_t IDim>
@@ -98,7 +99,15 @@ struct ConstantMergedTensorDescriptor
         return original_multi_id;
     }
 
+#if 0 // not needed
+    __host__ __device__ static index_t
+    GetOffsetFromOriginalMultiIndex(Array<index_t, nOriginalDim> original_multi_id)
+    {
+        return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
+    }
+#endif
+
     __host__ __device__ static index_t GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id)
     {
         const auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);
@@ -117,6 +126,38 @@ struct ConstantMergedTensorDescriptor
         return dummy_desc.GetMultiIndexFrom1dIndex(id);
     }
+
+#if 0 // not needed
+    template <index_t IDim>
+    __host__ __device__ static index_t GetNewOriginalMultiIndexAfterMovingAlongOneDimension(
+        Array<index_t, nOriginalDim> old_original_multi_id, Number<IDim>, index_t step_size)
+    {
+        auto new_original_multi_id = old_original_multi_id;
+
+        // get partial-original-multi-id corresponding to this merged dimension
+        constexpr auto original_partial_dims = std::get<IDim>(mOriginalDimMergeSeqs);
+
+        constexpr auto original_partial_tensor_desc =
+            OriginalTensorDesc::Extract(original_partial_dims);
+
+        auto old_original_partial_multi_id =
+            extract_array(old_original_multi_id, original_partial_dims);
+
+        auto new_original_partial_multi_id =
+            original_partial_tensor_desc.GetNewMultiIndexGivenStepSizeOf1dIndex(
+                old_original_partial_multi_id, step_size);
+
+        // update original-multi-id
+        static_for<0, original_partial_dims.GetSize(), 1>{}([&](auto I_) {
+            constexpr auto I                = decltype(I_){};
+            constexpr index_t idim_original = original_partial_dims.Get(I);
+
+            new_original_multi_id[idim_original] = new_original_partial_multi_id[I.Get()];
+        });
+
+        return new_original_multi_id;
+    }
+#endif
 };
 
 template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
```
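To make the "merged dimension" idea behind `GetContainedOriginalDimensions` concrete, a hedged plain-integer sketch (made-up lengths and strides; the real code does this with compile-time descriptors):

```cpp
// One merged index b in [0, N0*Ho*Wo) stands for the triple (n0, ho, wo).
// Decomposing b and dotting with the original strides reproduces what
// GetOriginalMultiIndexFromMultiIndex + GetOffsetFromMultiIndex compute.
constexpr int N0 = 2, Ho = 4, Wo = 8;                          // made-up lengths
constexpr int stride_n0 = 1024, stride_ho = 32, stride_wo = 1; // made-up strides

int offset_of_merged_index(int b)
{
    const int n0 = b / (Ho * Wo);
    const int ho = (b / Wo) % Ho;
    const int wo = b % Wo;
    return n0 * stride_n0 + ho * stride_ho + wo * stride_wo;
}
```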
src/include/ConstantTensorDescriptor.hip.hpp

```diff
@@ -40,6 +40,14 @@ struct ConstantTensorDescriptor
 #endif
     }
 
+    __host__ __device__ static constexpr auto GetOriginalTensorDescriptor() { return Type{}; }
+
+    template <index_t IDim>
+    __host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
+    {
+        return Sequence<IDim>{};
+    }
+
     __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }
 
     __host__ __device__ static constexpr auto GetLengths() { return Lengths{}; }
@@ -66,6 +74,19 @@ struct ConstantTensorDescriptor
         return MemoryRanks{}.Get(Number<I>{});
     }
 
+    __host__ __device__ static constexpr bool AreStridesNonAscending()
+    {
+        bool flag = true;
+
+        static_for<0, nDim - 1, 1>{}([&](auto IDim) {
+            constexpr auto IDim_p1 = Number<IDim.Get() + 1>{};
+            flag                   = flag && (GetLength(IDim) >= GetLength(IDim_p1));
+        });
+
+        return flag;
+    }
+
     template <class T>
     __host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(T)
     {
@@ -167,6 +188,46 @@ struct ConstantTensorDescriptor
         return multi_id;
     }
 
+    __host__ __device__ static auto
+    GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
+    {
+        return multi_id;
+    }
+
+    // This function doesn't do a carry check on the highest dimension, for performance reasons.
+    // It is the user's responsibility to make sure the result "new_multi_id" is not out-of-bound
+    // on the highest dimension.
+    __host__ __device__ static Array<index_t, nDim>
+    UpdateMultiIndexGivenStepSizeOf1dIndex(Array<index_t, nDim> old_multi_id,
+                                           index_t step_size_of_1d_index)
+    {
+        auto new_multi_id = old_multi_id + GetMultiIndexFrom1dIndex(step_size_of_1d_index);
+
+        bool carry = false;
+
+        // do carry check in reversed order, starting from the lowest dimension;
+        // don't check the highest dimension
+        static_for<0, nDim - 1, 1>{}([&](auto IDimReverse) {
+            constexpr index_t idim = nDim - 1 - IDimReverse.Get();
+            constexpr auto IDim    = Number<idim>{};
+
+            if(carry)
+            {
+                ++new_multi_id[idim];
+            }
+
+            carry = false;
+
+            if(new_multi_id[idim] >= GetLength(IDim))
+            {
+                new_multi_id[idim] -= GetLength(IDim);
+                carry = true;
+            }
+        });
+
+        return new_multi_id;
+    }
+
     // WRONG! Ranks is broken
     template <index_t... IDims>
     __host__ __device__ static constexpr auto Extract(Number<IDims>... extract_dims)
@@ -193,6 +254,19 @@ struct ConstantTensorDescriptor
         return Extract(Number<IDims>{}...);
     }
 
+    template <class... Ts>
+    __host__ __device__ static constexpr auto Inject(ConstantTensorDescriptor<Ts...>)
+    {
+        using leaf_tensor = ConstantTensorDescriptor<Ts...>;
+
+        // memory rank is broken
+        // TODO: remove memory rank info from tensor descriptor
+        return ConstantTensorDescriptor<
+            decltype(GetLengths().Append(leaf_tensor::GetLengths())),
+            decltype(GetStrides().Append(leaf_tensor::GetStrides())),
+            decltype(GetMemoryRanks().Append(leaf_tensor::GetMemoryRanks()))>{};
+    }
+
     template <index_t IDim, index_t SliceLen>
     __host__ __device__ static constexpr auto Slice(Number<IDim>, Number<SliceLen>)
     {
```
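A self-contained plain-integer model of the carry propagation in `UpdateMultiIndexGivenStepSizeOf1dIndex` (lengths are made up; the highest dimension is deliberately left unchecked, exactly as the comment in the diff warns):

```cpp
#include <array>
#include <cstdio>

// Model for lengths {4, 3, 5}; the real version runs over compile-time
// descriptors and likewise skips the carry check on the highest dimension.
std::array<int, 3> update(std::array<int, 3> id, int step)
{
    constexpr std::array<int, 3> len{4, 3, 5};

    // decompose the 1d step into a multi-index increment (packed, row-major)
    std::array<int, 3> inc{};
    inc[2] = step % len[2];
    inc[1] = (step / len[2]) % len[1];
    inc[0] = step / (len[2] * len[1]);

    bool carry = false;
    for(int d = 2; d >= 1; --d) // lowest dimension first; dim 0 is not checked
    {
        id[d] += inc[d] + (carry ? 1 : 0);
        carry = id[d] >= len[d];
        if(carry)
            id[d] -= len[d];
    }
    id[0] += inc[0] + (carry ? 1 : 0); // caller must keep this in bounds
    return id;
}

int main()
{
    auto r = update({1, 2, 4}, 1); // {1,2,4} + 1 -> {2, 0, 0}
    printf("%d %d %d\n", r[0], r[1], r[2]);
}
```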
src/include/blockwise_generic_tensor_slice_op.hip.hpp

```diff
 #pragma once
 
 #include "threadwise_tensor_slice_op.hip.hpp"
 
-// slice a (normal or merged) tensor, reorder and copy it into another (normal or merged) tensor
+// slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
+// memory layout (ordering of dimensions) can be different between src and dst
 template <index_t BlockSize,
           class Float,
           class SrcDesc,
@@ -18,8 +19,29 @@ struct BlockwiseGenericTensorSliceCopy_v1
 {
     static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
 
-    index_t mSrcMyThreadOffset;
-    index_t mDstMyThreadOffset;
+    static constexpr index_t nOriginalDimSrc =
+        SrcDesc::GetOriginalTensorDescriptor().GetNumOfDimension();
+    static constexpr index_t nOriginalDimDst =
+        DstDesc::GetOriginalTensorDescriptor().GetNumOfDimension();
+
+    // per-thread offset
+    index_t mThreadSrcOffset;
+    index_t mThreadDstOffset;
+
+    // "mThreadSrcOriginalMultiId", "mThreadSrcPartialOffsets", "mThreadDstOriginalMultiId",
+    // "mThreadDstPartialOffsets" are always calculated inside the constructor, and would be
+    // updated if the slicing window is moved. However, they will not be used if you always
+    // move the slicing window along a non-merged dimension. In that case, the compiler
+    // should be able to remove these calculations.
+    // TODO: make sure the compiler actually removes them in that case
+
+    // partial offset in each (merged) dimension
+    Array<index_t, nDim> mThreadSrcPartialOffsets;
+    Array<index_t, nDim> mThreadDstPartialOffsets;
+
+    // multi-id of original tensor
+    Array<index_t, nOriginalDimSrc> mThreadSrcOriginalMultiId;
+    Array<index_t, nOriginalDimDst> mThreadDstOriginalMultiId;
 
     __device__
     BlockwiseGenericTensorSliceCopy_v1(Array<index_t, nDim> src_block_data_multi_id_begin,
@@ -72,7 +94,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
                           "wrong! only support Sub-Length == 1 on a merged dimension");
         });
 
-        // calculate mSrcMyThreadOffset, mDstMyThreadOffset
+        // calculate mThreadSrcOffset, mThreadDstOffset
         const auto thread_cluster_multi_id =
             thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
@@ -81,11 +103,46 @@ struct BlockwiseGenericTensorSliceCopy_v1
         const auto thread_data_multi_id_begin = data_cluster_multi_id * SubLengths{};
 
-        mSrcMyThreadOffset = SrcDesc::GetOffsetFromMultiIndex(src_block_data_multi_id_begin +
-                                                              thread_data_multi_id_begin);
-
-        mDstMyThreadOffset = DstDesc::GetOffsetFromMultiIndex(dst_block_data_multi_id_begin +
-                                                              thread_data_multi_id_begin);
+        // original multi-id
+        mThreadSrcOriginalMultiId = SrcDesc::GetOriginalMultiIndexFromMultiIndex(
+            src_block_data_multi_id_begin + thread_data_multi_id_begin);
+
+        mThreadDstOriginalMultiId = DstDesc::GetOriginalMultiIndexFromMultiIndex(
+            dst_block_data_multi_id_begin + thread_data_multi_id_begin);
+
+        // partial offset on each dimension
+        static_for<0, nDim, 1>{}([&](auto IDim_) {
+            constexpr auto IDim    = decltype(IDim_){};
+            constexpr index_t idim = IDim.Get();
+
+            constexpr auto src_partial_original_dims =
+                SrcDesc::GetContainedOriginalDimensions(IDim);
+
+            constexpr auto src_partial_original_desc =
+                SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims);
+
+            mThreadSrcPartialOffsets[idim] = src_partial_original_desc.GetOffsetFromMultiIndex(
+                extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims));
+        });
+
+        static_for<0, nDim, 1>{}([&](auto IDim_) {
+            constexpr auto IDim    = decltype(IDim_){};
+            constexpr index_t idim = IDim.Get();
+
+            constexpr auto dst_partial_original_dims =
+                DstDesc::GetContainedOriginalDimensions(IDim);
+
+            constexpr auto dst_partial_original_desc =
+                DstDesc::GetOriginalTensorDescriptor().Extract(dst_partial_original_dims);
+
+            mThreadDstPartialOffsets[idim] = dst_partial_original_desc.GetOffsetFromMultiIndex(
+                extract_array(mThreadDstOriginalMultiId, dst_partial_original_dims));
+        });
+
+        // complete offset
+        mThreadSrcOffset = reduce_on_array(mThreadSrcPartialOffsets, std::plus<index_t>{});
+        mThreadDstOffset = reduce_on_array(mThreadDstPartialOffsets, std::plus<index_t>{});
 
 #if 0
         {
             printf("id %5u %5u: "
@@ -93,7 +150,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
                    "thread_cluster_multi_id: %u %u %u %u, "
                    "data_cluster_multi_id: %u %u %u %u, "
                    "thread_data_multi_id_begin: %u %u %u %u, "
-                   "mSrcMyThreadOffset %u, mDstMyThreadOffset %u \n",
+                   "mThreadSrcOffset %u, mThreadDstOffset %u \n",
                    get_block_1d_id(),
                    get_thread_local_1d_id(),
                    src_block_data_multi_id_begin[0],
@@ -112,8 +169,8 @@ struct BlockwiseGenericTensorSliceCopy_v1
                    thread_data_multi_id_begin[1],
                    thread_data_multi_id_begin[2],
                    thread_data_multi_id_begin[3],
-                   mSrcMyThreadOffset,
-                   mDstMyThreadOffset);
+                   mThreadSrcOffset,
+                   mThreadDstOffset);
         }
 #endif
     }
@@ -156,7 +213,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
             clipboard_data_multi_id_begin); // cannot be constexpr, why?
 
         threadwise_generic_tensor_slice_copy(SrcDesc{},
-                                             p_src + src_offset + mSrcMyThreadOffset,
+                                             p_src + src_offset + mThreadSrcOffset,
                                              make_zero_array<index_t, nDim>(),
                                              thread_tensor_desc,
                                              p_clipboard + clipboard_offset,
@@ -197,7 +254,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
                                              p_clipboard + clipboard_offset,
                                              make_zero_array<index_t, nDim>(),
                                              DstDesc{},
-                                             p_dst + dst_offset + mDstMyThreadOffset,
+                                             p_dst + dst_offset + mThreadDstOffset,
                                              make_zero_array<index_t, nDim>(),
                                              thread_sub_tensor_lengths,
                                              DstAccessOrder{});
@@ -211,4 +268,81 @@ struct BlockwiseGenericTensorSliceCopy_v1
         RunLoadRegisterClipboard(p_src, p_clipboard);
         RunStoreRegisterClipboard(p_clipboard, p_dst);
     }
+
+    // When moving the slicing window along a merged dimension, if the strides of the
+    // original dimensions contained by the merged dimension are in descending order,
+    // then there is no guarantee that the new offset will be larger than the old offset
+    // for movement in the positive direction (and vice versa for movement in the negative
+    // direction). As a result, the offset calculation may result in unsigned integer
+    // underflow (due to the "-" operation). However, this hazard should not happen as long
+    // as the user makes sure the slicing window is not moved out of the boundary of the
+    // tensor being sliced. This function doesn't do a runtime sanity check for an
+    // out-of-bound slicing window, for performance reasons.
+    template <index_t IDim_, index_t StepSize, bool PositiveDirection>
+    __device__ void MoveSlicingWindowOnSourceTensor(
+        Number<IDim_>, Number<StepSize>, integral_constant<bool, PositiveDirection>)
+    {
+        static_assert(PositiveDirection,
+                      "wrong! only support movement in positive direction for now");
+
+        constexpr auto IDim    = Number<IDim_>{};
+        constexpr index_t idim = IDim.Get();
+
+        static_if<SrcDesc::ContainMultipleOriginalDimensions(IDim)>{}([&](auto fwd) {
+            // logic for a merged dimension; it also works for a non-merged dimension, but
+            // its logic may be unnecessarily complicated for the compiler to remove
+            // useless calculations
+
+            // extract partial original dimensions
+            constexpr auto src_partial_original_dims =
+                SrcDesc::GetContainedOriginalDimensions(IDim);
+
+            constexpr auto src_partial_original_desc =
+                SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims);
+
+            // calculate new partial original multi-id
+            auto old_src_partial_original_multi_id =
+                extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims);
+
+            auto new_src_partial_original_multi_id =
+                src_partial_original_desc.UpdateMultiIndexGivenStepSizeOf1dIndex(
+                    old_src_partial_original_multi_id, StepSize);
+
+            // update "mThreadSrcOriginalMultiId"
+            static_for<0, src_partial_original_dims.GetSize(), 1>{}([&](auto I_) {
+                constexpr auto I = decltype(I_){};
+
+                constexpr index_t idim_original = src_partial_original_dims.Get(I);
+
+                mThreadSrcOriginalMultiId[idim_original] =
+                    new_src_partial_original_multi_id[I.Get()];
+            });
+
+            // calculate new partial offset on this merged dimension
+            const index_t old_src_partial_offset = mThreadSrcPartialOffsets[idim];
+
+            const index_t new_src_partial_offset =
+                src_partial_original_desc.GetOffsetFromMultiIndex(
+                    new_src_partial_original_multi_id);
+
+            // update "mThreadSrcPartialOffsets"
+            mThreadSrcPartialOffsets[idim] = new_src_partial_offset;
+
+            // update "mThreadSrcOffset"; do "+" before "-" to avoid underflow
+            mThreadSrcOffset = mThreadSrcOffset + new_src_partial_offset - old_src_partial_offset;
+        }).Else([&](auto fwd) {
+            // Logic for a non-merged dimension. If the slicing window is never moved along
+            // a merged dimension, then "mThreadSrcOriginalMultiId" and
+            // "mThreadSrcPartialOffsets", which are calculated here, will never be used
+            // later. In that case, the compiler should be able to remove these calculations.
+            // TODO: make sure the compiler actually removes them in that case.
+            constexpr index_t idim_original =
+                SrcDesc::GetContainedOriginalDimensions(IDim).Front();
+
+            mThreadSrcOffset += StepSize * SrcDesc::GetStride(IDim);
+
+            mThreadSrcOriginalMultiId[idim_original] += StepSize;
+
+            mThreadSrcPartialOffsets[idim] += StepSize * SrcDesc::GetStride(IDim);
+        });
+    }
 };
```
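The new members implement a simple invariant; a hedged plain-integer sketch with made-up offsets: the thread's full offset is kept as the sum of one partial offset per (possibly merged) dimension, so moving the window along one dimension only patches that dimension's term.

```cpp
#include <cstdio>

int main()
{
    // three (possibly merged) dimensions; made-up values
    int partial[3] = {12, 40, 3};                      // mThreadSrcPartialOffsets
    int offset = partial[0] + partial[1] + partial[2]; // mThreadSrcOffset == 55

    // moving the slicing window along dimension 1 recomputes only partial[1];
    // the full offset is patched by the delta, as in MoveSlicingWindowOnSourceTensor
    // ("+" before "-", matching the underflow note in the diff)
    int new_partial1 = 64;
    offset     = offset + new_partial1 - partial[1]; // 55 + 64 - 40 = 79
    partial[1] = new_partial1;

    printf("%d\n", offset); // 79
    return 0;
}
```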
src/include/functional.hip.hpp

```diff
@@ -44,7 +44,7 @@ struct static_if<true>
     }
 
     template <class F>
-    __host__ __device__ static constexpr auto else_(F)
+    __host__ __device__ static constexpr auto Else(F)
     {
         return Type{};
     }
@@ -62,7 +62,7 @@ struct static_if<false>
     }
 
     template <class F>
-    __host__ __device__ static constexpr auto else_(F f)
+    __host__ __device__ static constexpr auto Else(F f)
     {
         // This is a trick for the compiler:
         // Pass a forwarder to lambda "f" as an "auto" argument, and make sure "f" uses it,
```
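For context, a self-contained model of the pattern this rename touches; this is not the library's implementation, just a sketch of the `static_if<Cond>{}(then).Else(otherwise)` shape used at the call sites below:

```cpp
#include <cstdio>

template <bool>
struct static_if_model;

template <>
struct static_if_model<true>
{
    // runs the "then" lambda, passing an identity forwarder (as the real static_if does)
    template <class F>
    auto operator()(F f) const { f([](auto x) { return x; }); return *this; }

    template <class F>
    auto Else(F) const { return *this; } // "else" branch is skipped
};

template <>
struct static_if_model<false>
{
    template <class F>
    auto operator()(F) const { return *this; } // "then" branch is skipped

    template <class F>
    auto Else(F f) const { f([](auto x) { return x; }); return *this; }
};

int main()
{
    // exactly one lambda body is instantiated, selected at compile time
    static_if_model<true>{}([&](auto fwd) { printf("then-branch\n"); })
        .Else([&](auto fwd) { printf("else-branch\n"); }); // prints "then-branch"
}
```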
src/include/gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hip.hpp

```diff
@@ -337,7 +337,7 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
                     n_block_data_begin + n_thread_data_begin),
                 out_10d_thread_desc.GetLengths(),
                 Number<OutThreadCopyDataPerWrite_N>{});
-        }).else_([&](auto f_dummy) {
+        }).Else([&](auto f_dummy) {
             static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                               GemmNPerThreadSubC % NPerThread == 0,
                           "wrong!");
```
src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp

```diff
@@ -373,7 +373,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
                     n_block_data_begin + n_thread_data_begin),
                 out_10d_thread_desc.GetLengths(),
                 Number<OutThreadCopyDataPerWrite_N>{});
-        }).else_([&](auto f_dummy) {
+        }).Else([&](auto f_dummy) {
             static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                               GemmNPerThreadSubC % NPerThread == 0,
                           "wrong!");
```
src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp

```diff
@@ -363,7 +363,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
                     n_block_data_begin + n_thread_data_begin),
                 out_10d_thread_desc.GetLengths(),
                 Number<OutThreadCopyDataPerWrite_N>{});
-        }).else_([&](auto f_dummy) {
+        }).Else([&](auto f_dummy) {
             static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                               GemmNPerThreadSubC % NPerThread == 0,
                           "wrong!");
```
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp

```diff
@@ -412,7 +412,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
                     n_block_data_begin + n_thread_data_begin),
                 out_10d_thread_desc.GetLengths(),
                 Number<OutThreadCopyDataPerWrite_N>{});
-        }).else_([&](auto fwd) {
+        }).Else([&](auto fwd) {
             static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                               GemmNPerThreadSubC % NPerThread == 0,
                           "wrong!");
```
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp

```diff
@@ -432,7 +432,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
                     n_block_data_begin + n_thread_data_begin),
                 out_10d_thread_desc.GetLengths(),
                 Number<OutThreadCopyDataPerWrite_N>{});
-        }).else_([&](auto fwd) {
+        }).Else([&](auto fwd) {
             static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                               GemmNPerThreadSubC % NPerThread == 0,
                           "wrong!");
```
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp

```diff
@@ -115,7 +115,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
             Number<InBlockReorderDataPerWrite_N>{});
 
         // this check is ad-hoc
-        // TODO: need to properly implement tensor descriptor with alignment
+        // TODO: need to properly implement tensor descriptor with multiple alignment requirements
         static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                       "GemmDataPerReadB alignment requirement is not met");
@@ -417,7 +417,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
             out_10d_thread_desc.GetLengths(),
             map_out_global2thread);
         // Number<OutThreadCopyDataPerWrite_W>{});
-        }).else_([&](auto fwd) {
+        }).Else([&](auto fwd) {
             static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                               GemmNPerThreadSubC % NPerThread == 0,
                           "wrong!");
```
src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_khwn.hip.hpp

```diff
@@ -407,7 +407,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
                     n_block_data_begin + n_thread_data_begin),
                 out_10d_thread_desc.GetLengths(),
                 Number<OutThreadCopyDataPerWrite_N>{});
-        }).else_([&](auto f_dummy) {
+        }).Else([&](auto f_dummy) {
             static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                               GemmNPerThreadSubC % NPerThread == 0,
                           "wrong!");
```
src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hip.hpp

```diff
@@ -366,7 +366,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
             out_10d_thread_desc.GetLengths(),
             map_out_global2thread);
         // Number<OutThreadCopyDataPerWrite_W>{});
-        }).else_([&](auto fwd) {
+        }).Else([&](auto fwd) {
             static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                               GemmNPerThreadSubC % NPerThread == 0,
                           "wrong!");
```
src/include/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hip.hpp (new file, mode 100644)

```cpp
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMergedTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_generic_tensor_slice_op.hip.hpp"
#include "blockwise_gemm.hip.hpp"
#include "threadwise_tensor_slice_op.hip.hpp"

// define B = merge(N, Ho, Wo)
template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          index_t BPerBlock,
          index_t KPerBlock,
          index_t EPerBlock,
          index_t N1,
          index_t N2,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
          class InBlockCopySubLengths_E_N1_B_N2,
          class InBlockCopyClusterLengths_E_N1_B_N2,
          index_t InBlockCopySrcDataPerRead_B,
          index_t InBlockCopyDstDataPerWrite_N2,
          class WeiBlockCopySubLengths_E_K,
          class WeiBlockCopyClusterLengths_E_K,
          index_t WeiBlockCopySrcDataPerRead_E,
          index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // this is a mess
        // TODO: find a more elegant way of specifying (or calculating) performance parameters
        static_assert(N2 == GemmNPerThreadSubC, "wrong!");
        static_assert((N1 * N2 * BPerBlock) %
                              (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
                          0,
                      "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};
        constexpr auto I4 = Number<4>{};
        constexpr auto I5 = Number<5>{};
        constexpr auto I6 = Number<6>{};
        constexpr auto I7 = Number<7>{};

        constexpr auto TRUE = integral_constant<bool, true>{};

        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
        constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_h_w_global_desc.GetLength(I0);
        constexpr index_t C  = in_n_c_h_w_global_desc.GetLength(I1);
        constexpr index_t Hi = in_n_c_h_w_global_desc.GetLength(I2);
        constexpr index_t Wi = in_n_c_h_w_global_desc.GetLength(I3);

        constexpr index_t K  = out_n_k_h_w_global_desc.GetLength(I1);
        constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
        constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

        static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among threads");

        constexpr index_t N0 = N / (N1 * N2);

        constexpr index_t B = N0 * Ho * Wo;

        constexpr index_t E = C * Y * X;

        // divide block work by [K, B]
        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
                      "wrong! cannot divide work evenly among blocks");

        constexpr index_t KBlockWork = K / KPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr auto block_work_desc =
            make_ConstantTensorDescriptor_default_rank_packed(Sequence<KBlockWork, BBlockWork>{});

        const auto block_work_multi_id =
            block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());

        const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
        const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;

        // input tensor
        //     tensor descriptor in device memory [N0, N1, N2, H, W]
        constexpr auto in_n0_n1_n2_h_w_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Hi>{})
                                                         .Slice(I3, Number<Wi>{})
                                                         .Fold(I0, Number<N1>{}, Number<N2>{})
                                                         .Extract(Sequence<0, 1, 2, 4, 5>{});

        //     batch descriptor for device memory
        constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Y>{})
                                                  .Slice(I3, Number<X>{})
                                                  .Extract(Sequence<1, 2, 3>{});

        //     merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
        constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
            in_c_y_x_global_desc.Inject(in_n0_n1_n2_h_w_global_desc),
            Sequence<0, 1, 2>{},
            Sequence<4>{},
            Sequence<3, 6, 7>{},
            Sequence<5>{});

        //     memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
        //     be careful of LDS alignment
        constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_default_rank_aligned(
            Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with multiple alignment
        // requirements
        static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                      "GemmDataPerReadB alignment requirement is not satisfied");

        //     input blockwise copy
        //     slice a merged tensor, reorder and copy to a normal tensor
        //     this copy operator already has blockwise offset built-in
        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<
            BlockSize,
            Float,
            decltype(in_e_n1_b_n2_global_merged_desc),
            decltype(in_e_n1_b_n2_block_desc),
            decltype(in_e_n1_b_n2_block_desc.GetLengths()),
            InBlockCopySubLengths_E_N1_B_N2,
            InBlockCopyClusterLengths_E_N1_B_N2,
            Sequence<0, 1, 3, 2>, // thread_arrange_order [E, N1, N2, B]
            Sequence<0, 1, 3, 2>, // src_access_order [E, N1, N2, B]
            Sequence<0, 1, 2, 3>, // dst_access_order [E, N1, B, N2]
            InBlockCopySrcDataPerRead_B,
            InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});

        // weight tensor
        //     tensor descriptor in device memory, src of blockwise copy
        constexpr auto wei_e_k_global_desc =
            wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});

        //     tensor descriptor in LDS, dst of blockwise copy
        //     be careful of LDS alignment
        constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_default_rank_aligned(
            Sequence<EPerBlock, KPerBlock>{},
            Number<mod_conv::max(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});

        //     operator for blockwise copy of weight into LDS
        //     slice a tensor, and copy it into another tensor
        //     this copy operator already has blockwise offset built-in
        auto blockwise_wei_copy =
            BlockwiseGenericTensorSliceCopy_v1<BlockSize,
                                               Float,
                                               decltype(wei_e_k_global_desc),
                                               decltype(wei_e_k_block_desc),
                                               decltype(wei_e_k_block_desc.GetLengths()),
                                               WeiBlockCopySubLengths_E_K,
                                               WeiBlockCopyClusterLengths_E_K,
                                               Sequence<1, 0>, // thread_arrange_order [K, E]
                                               Sequence<1, 0>, // src_access_order [K, E]
                                               Sequence<0, 1>, // dst_access_order [E, K]
                                               WeiBlockCopySrcDataPerRead_E,
                                               WeiBlockCopyDstDataPerWrite_K>(
                {0, k_block_data_on_global}, {0, 0});

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[EPerBlock, KPerBlock] is in LDS
        //     b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
        //     c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
        //     register
        constexpr auto a_e_k_block_mtx_desc =
            make_ConstantMatrixDescriptor(Number<EPerBlock>{},
                                          Number<KPerBlock>{},
                                          Number<wei_e_k_block_desc.GetStride(I0)>{});

        constexpr auto b_e_n1bn2_block_mtx_desc =
            make_ConstantMatrixDescriptor(Number<EPerBlock>{},
                                          Number<N1 * BPerBlock * N2>{},
                                          Number<in_e_n1_b_n2_block_desc.GetStride(I0)>{});

        // sanity check
        static_assert(
            KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0,
            "wrong!");

        constexpr index_t GemmMRepeat =
            KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO: find a more elegant way of defining c_thread_mtx
        constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor(
            Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize,
            decltype(a_e_k_block_mtx_desc),
            decltype(b_e_n1bn2_block_mtx_desc),
            decltype(c_k0k2_n1n2_thread_mtx_desc),
            GemmMPerThreadSubC,
            GemmNPerThreadSubC,
            GemmMLevel0Cluster,
            GemmNLevel0Cluster,
            GemmMLevel1Cluster,
            GemmNLevel1Cluster,
            GemmKPerThreadLoop,
            GemmDataPerReadA,
            GemmDataPerReadB>{};

        // LDS allocation for input and weight: be careful of alignment
        constexpr index_t max_align = mod_conv::max(InBlockCopyDstDataPerWrite_N2,
                                                    WeiBlockCopyDstDataPerWrite_K,
                                                    GemmDataPerReadA,
                                                    GemmDataPerReadB);

        constexpr index_t in_block_space =
            in_e_n1_b_n2_block_desc.GetElementSpace(Number<max_align>{});

        constexpr index_t wei_block_space = wei_e_k_block_desc.GetElementSpace(Number<max_align>{});

        __shared__ Float p_in_block[in_block_space];
        __shared__ Float p_wei_block[wei_block_space];

        // register allocation for output
        Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);

        // do work
        for(index_t e = 0; e < E; e += EPerBlock)
        {
            // marching slicing window
            blockwise_in_copy.Run(p_in_global, p_in_block);
            blockwise_wei_copy.Run(p_wei_global, p_wei_block);

            __syncthreads();

            blockwise_gemm.Run(p_wei_block, p_in_block, p_out_thread);

            __syncthreads();

            blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, TRUE);
            blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, TRUE);
        }

        // copy output: register to global memory
        {
            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
            constexpr index_t K0 = K / (K1 * K2);

            // define tensor descriptor for threadwise copy
            //     output memory layout descriptor in register
            constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
                make_ConstantTensorDescriptor_default_rank_packed(
                    Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});

            //     output tensor descriptor in register, src of threadwise copy
            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
                out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
                    Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});

            //     output memory layout descriptor in device memory, dst of threadwise copy
            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
                out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
                    .Fold(I0, Number<N1>{}, Number<N2>{});

            // calculate origin of thread output tensor on global memory
            //     blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t k_thread_data_on_global =
                k_block_data_on_global + c_thread_mtx_on_block.row;

            const index_t b_thread_data_on_global =
                b_block_data_on_global + c_thread_mtx_on_block.col / N2;

            //     output merged global tensor descriptor, for calculating origin of thread
            //     tensor in global memory
            constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
                Sequence<3>{},
                Sequence<1>{},
                Sequence<0, 4, 5>{},
                Sequence<2>{});

            //     origin of dst in device memory
            Float* p_out_thread_on_global =
                p_out_global +
                out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
                    k_thread_data_on_global, 0, b_thread_data_on_global, 0);

            threadwise_generic_tensor_slice_copy(
                out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
                p_out_thread,
                {0, 0, 0, 0, 0, 0, 0, 0},
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc,
                p_out_thread_on_global,
                {0, 0, 0, 0, 0, 0, 0, 0},
                out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
                arithmetic_sequence_gen<0, 8, 1>::SeqType{});
        }
    }
};
```
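To summarize what the kernel computes, a hedged host-side reference of the implicit-GEMM view (stride-1, no padding, plain loops; names and layout here are illustrative, not the repo's API): the convolution is a GEMM whose M dimension is K, whose N dimension is the merged B = N·Ho·Wo, and whose reduction dimension is the merged E = C·Y·X (the kernel additionally folds N into N0·N1·N2).

```cpp
#include <vector>

// naive reference: out[n,k,ho,wo] = sum_{c,y,x} in[n,c,ho+y,wo+x] * wei[k,c,y,x],
// written so the merged GEMM coordinates b = (n,ho,wo) and e = (c,y,x) are explicit
void conv_as_implicit_gemm(const std::vector<float>& in,  // [N, C, Hi, Wi]
                           const std::vector<float>& wei, // [K, C, Y, X]
                           std::vector<float>& out,       // [N, K, Ho, Wo]
                           int N, int C, int Hi, int Wi, int K, int Y, int X)
{
    const int Ho = Hi - Y + 1, Wo = Wi - X + 1;

    for(int b = 0; b < N * Ho * Wo; ++b) // GEMM "N" dimension: merged B
    {
        const int n  = b / (Ho * Wo);
        const int ho = (b / Wo) % Ho;
        const int wo = b % Wo;

        for(int k = 0; k < K; ++k) // GEMM "M" dimension
        {
            float acc = 0;
            for(int e = 0; e < C * Y * X; ++e) // GEMM reduction dimension: merged E
            {
                const int c = e / (Y * X);
                const int y = (e / X) % Y;
                const int x = e % X;

                acc += in[((n * C + c) * Hi + (ho + y)) * Wi + (wo + x)] *
                       wei[((k * C + c) * Y + y) * X + x];
            }
            out[((n * K + k) * Ho + ho) * Wo + wo] = acc;
        }
    }
}
```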