yangql / composable_kernel-1
Commit 0a265731, authored May 30, 2019 by Chao Liu
Parent: 8c385cf5

adding implicit gemm v4 (nchw, kcyx)
Showing 16 changed files with 747 additions and 27 deletions
driver/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp   +137 -0
driver/driver.hip.cpp   +3 -0
src/include/Array.hip.hpp   +13 -0
src/include/ConstantMergedTensorDescriptor.hip.hpp   +44 -3
src/include/ConstantTensorDescriptor.hip.hpp   +74 -0
src/include/blockwise_generic_tensor_slice_op.hip.hpp   +147 -13
src/include/functional.hip.hpp   +2 -2
src/include/gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hip.hpp   +1 -1
src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp   +1 -1
src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp   +1 -1
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp   +1 -1
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp   +1 -1
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp   +2 -2
src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_khwn.hip.hpp   +1 -1
src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hip.hpp   +1 -1
src/include/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hip.hpp   +318 -0
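For orientation, the new v4 kernel views the convolution as a GEMM whose reduction dimension is E = C * Y * X and whose output columns come from merging (N, Ho, Wo) into B. The sketch below is a naive host-side reference of that view (not part of the commit; the function and variable names are illustrative), assuming stride 1 and no padding.

// Hypothetical illustration (not part of the commit): a naive host-side reference that
// evaluates the same convolution as a GEMM over E = C*Y*X, assuming stride 1 and no padding.
#include <cstddef>
#include <vector>

void reference_implicit_gemm_nchw_kcyx_nkhw(const std::vector<float>& in,  // [N, C, Hi, Wi]
                                            const std::vector<float>& wei, // [K, C, Y, X]
                                            std::vector<float>& out,       // [N, K, Ho, Wo]
                                            std::size_t N, std::size_t C,
                                            std::size_t Hi, std::size_t Wi,
                                            std::size_t K, std::size_t Y, std::size_t X)
{
    const std::size_t Ho = Hi - Y + 1;
    const std::size_t Wo = Wi - X + 1;

    // GEMM view: out[k][(n, ho, wo)] += sum over e = (c, y, x) of wei[k][e] * in[e][(n, ho, wo)]
    for(std::size_t n = 0; n < N; ++n)
        for(std::size_t k = 0; k < K; ++k)
            for(std::size_t ho = 0; ho < Ho; ++ho)
                for(std::size_t wo = 0; wo < Wo; ++wo)
                {
                    float acc = 0;
                    for(std::size_t c = 0; c < C; ++c)
                        for(std::size_t y = 0; y < Y; ++y)
                            for(std::size_t x = 0; x < X; ++x)
                            {
                                // "implicit" A/B matrix elements, gathered on the fly
                                const float a = wei[((k * C + c) * Y + y) * X + x];
                                const float b =
                                    in[((n * C + c) * Hi + (ho + y)) * Wi + (wo + x)];
                                acc += a * b;
                            }
                    out[((n * K + k) * Ho + ho) * Wo + wo] = acc;
                }
}

The device kernel never materializes these A[E, K] and B[E, N1 * B * N2] operands in memory; it stages them tile-by-tile in LDS, as the new gridwise file below shows.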
driver/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp (new file, 0 → 100644)

#pragma once
#include <unistd.h>
#include "device.hpp"
#include "gridwise_convolution_wrapper.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hip.hpp"

template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
                                                        const Tensor<T>& in_nchw,
                                                        WeiDesc,
                                                        const Tensor<T>& wei_kcyx,
                                                        OutDesc,
                                                        Tensor<T>& out_nkhw,
                                                        index_t nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_nchw_desc  = InDesc{};
    constexpr auto wei_kcyx_desc = WeiDesc{};
    constexpr auto out_nkhw_desc = OutDesc{};

    constexpr index_t Hi = in_nchw_desc.GetLength(I2);
    constexpr index_t Wi = in_nchw_desc.GetLength(I3);

    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

    constexpr index_t K = wei_kcyx_desc.GetLength(I0);
    constexpr index_t C = wei_kcyx_desc.GetLength(I1);
    constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
    constexpr index_t X = wei_kcyx_desc.GetLength(I3);

    std::size_t data_sz = sizeof(T);

    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
    DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
    DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());

    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
    wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

    constexpr index_t N1 = 2;
    constexpr index_t N2 = 4;

    constexpr index_t B = (N * Ho * Wo) / (N1 * N2);

#if 1
    constexpr index_t BlockSize = 256;

    constexpr index_t BPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;

    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA   = 4;
    constexpr index_t GemmDataPerReadB   = 4;

    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 1, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;

    constexpr index_t InBlockCopySrcDataPerRead_B   = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;

    using WeiBlockCopySubLengths_E_K     = Sequence<1, 4>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<8, 32>;

    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#endif

    constexpr index_t GridSize =
        ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);

    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

    for(index_t i = 0; i < nrepeat; ++i)
    {
        constexpr auto gridwise_conv =
#if 1
            GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
#endif
            <GridSize,
             BlockSize,
             T,
             decltype(in_nchw_desc),
             decltype(wei_kcyx_desc),
             decltype(out_nkhw_desc),
             BPerBlock,
             KPerBlock,
             CPerBlock,
             N1,
             N2,
             GemmMPerThreadSubC,
             GemmNPerThreadSubC,
             GemmMLevel0Cluster,
             GemmNLevel0Cluster,
             GemmMLevel1Cluster,
             GemmNLevel1Cluster,
             GemmKPerThreadLoop,
             GemmDataPerReadA,
             GemmDataPerReadB,
             InBlockCopySubLengths_E_N1_B_N2,
             InBlockCopyClusterLengths_E_N1_B_N2,
             InBlockCopySrcDataPerRead_B,
             InBlockCopyDstDataPerWrite_N2,
             WeiBlockCopySubLengths_E_K,
             WeiBlockCopyClusterLengths_E_K,
             WeiBlockCopySrcDataPerRead_E,
             WeiBlockCopyDstDataPerWrite_K>{};

        float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>,
                                   dim3(GridSize),
                                   dim3(BlockSize),
                                   0,
                                   static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                                   static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                                   static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));

        printf("Elapsed time : %f ms, %f TFlop/s \n",
               time,
               (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
                   (std::size_t(1000) * 1000 * 1000) / time);

        usleep(std::min(time * 1000, float(10000)));
    }

    out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
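As a hypothetical worked example of the GridSize formula above (the problem sizes are assumed, not taken from this commit): with N = 64, K = 256 and Ho = Wo = 28, the default tuning gives B = 6272 and GridSize = 784. A standalone check:

// Hypothetical numbers, not taken from the commit: with N = 64, K = 256, Ho = Wo = 28,
// and the default tuning above (N1 = 2, N2 = 4, BPerBlock = 16, KPerBlock = 128):
//   B        = (64 * 28 * 28) / (2 * 4)          = 6272
//   GridSize = ceil(6272 / 16) * ceil(256 / 128) = 392 * 2 = 784
#include <cstdio>

int main()
{
    const unsigned N = 64, K = 256, Ho = 28, Wo = 28;
    const unsigned N1 = 2, N2 = 4, BPerBlock = 16, KPerBlock = 128;

    const unsigned B = (N * Ho * Wo) / (N1 * N2);
    const unsigned GridSize =
        ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);

    std::printf("B %u, GridSize %u\n", B, GridSize); // prints "B 6272, GridSize 784"
    return 0;
}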
driver/driver.hip.cpp

...
@@ -14,6 +14,7 @@
 #include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp"
 #include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp"
 #include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
+#include "device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp"

 struct GeneratorTensor_1
 {
...
@@ -629,6 +630,8 @@ int main(int argc, char* argv[])
     device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
 #elif 1
     device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
+#elif 1
+    device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw
 #endif
         (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
...
src/include/Array.hip.hpp

...
@@ -135,3 +135,16 @@ __host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is...
     return result;
 }
+
+template <class TData, index_t NSize, class F>
+__host__ __device__ constexpr TData reduce_on_array(Array<TData, NSize> a, F f)
+{
+    TData result = a[0];
+
+    static_for<1, NSize, 1>{}([&](auto I) {
+        constexpr index_t i = I.Get();
+        result = f(result, a[i]);
+    });
+
+    return result;
+}
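The new reduce_on_array helper folds a binary functor over an Array; later in this commit it is used to sum the per-dimension partial offsets with std::plus<index_t>. A standard-C++ analogue (illustration only, using std::array and a plain loop instead of the project's Array and static_for):

// Standard-C++ analogue of reduce_on_array (illustration only; the in-tree version
// works on the project's Array<> type and unrolls the loop with static_for).
#include <array>
#include <cstdio>
#include <functional>

template <class TData, std::size_t NSize, class F>
constexpr TData reduce_on_std_array(const std::array<TData, NSize>& a, F f)
{
    TData result = a[0];
    for(std::size_t i = 1; i < NSize; ++i)
        result = f(result, a[i]);
    return result;
}

int main()
{
    // e.g. summing per-dimension partial offsets, as the blockwise copy below does
    constexpr std::array<int, 4> partial_offsets{128, 16, 4, 1};
    constexpr int total = reduce_on_std_array(partial_offsets, std::plus<int>{});
    std::printf("total offset %d\n", total); // prints "total offset 149"
    return 0;
}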
src/include/ConstantMergedTensorDescriptor.hip.hpp

...
@@ -33,9 +33,10 @@ struct ConstantMergedTensorDescriptor
     __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }

-    __host__ __device__ static constexpr index_t GetNumOfOriginalDimension()
+    template <index_t IDim>
+    __host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
     {
-        return nOriginalDim;
+        return std::get<IDim>(mOriginalDimMergeSeqs);
     }

     template <index_t IDim>
...
@@ -98,7 +99,15 @@ struct ConstantMergedTensorDescriptor
         return original_multi_id;
     }

-    __host__ __device__ static index_t GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id)
+#if 0 // not needed
+    __host__ __device__ static index_t
+    GetOffsetFromOriginalMultiIndex(Array<index_t, nOriginalDim> original_multi_id)
+    {
+        return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
+    }
+#endif
+
+    __host__ __device__ static index_t GetOffsetFromMultiIndexA(Array<index_t, nDim> multi_id)
     {
         const auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);
...
@@ -117,6 +126,38 @@ struct ConstantMergedTensorDescriptor
         return dummy_desc.GetMultiIndexFrom1dIndex(id);
     }

+#if 0 // not needed
+    template <index_t IDim>
+    __host__ __device__ static index_t GetNewOriginalMultiIndexAfterMovingAlongOneDimension(
+        Array<index_t, nOriginalDim> old_original_multi_id, Number<IDim>, index_t step_size)
+    {
+        auto new_original_multi_id = old_original_multi_id;
+
+        // get partial-original-multi-id corresponding to this merged dimension
+        constexpr auto original_partial_dims = std::get<IDim>(mOriginalDimMergeSeqs);
+
+        constexpr auto original_partial_tensor_desc =
+            OriginalTensorDesc::Extract(original_partial_dims);
+
+        auto old_original_partial_multi_id =
+            extract_array(old_original_multi_id, original_partial_dims);
+
+        auto new_original_partial_multi_id =
+            original_partial_tensor_desc.GetNewMultiIndexGivenStepSizeOf1dIndex(
+                old_original_partial_multi_id, step_size);
+
+        // update original-multi-id
+        static_for<0, original_partial_dims.GetSize(), 1>{}([&](auto I_) {
+            constexpr auto I = decltype(I_){};
+            constexpr index_t idim_original = original_partial_dims.Get(I);
+            new_original_multi_id[idim_original] = new_original_partial_multi_id[I.Get()];
+        });
+
+        return new_original_multi_id;
+    }
+#endif
 };

 template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
...
src/include/ConstantTensorDescriptor.hip.hpp

...
@@ -40,6 +40,14 @@ struct ConstantTensorDescriptor
 #endif
     }

+    __host__ __device__ static constexpr auto GetOriginalTensorDescriptor() { return Type{}; }
+
+    template <index_t IDim>
+    __host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
+    {
+        return Sequence<IDim>{};
+    }
+
     __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }

     __host__ __device__ static constexpr auto GetLengths() { return Lengths{}; }
...
@@ -66,6 +74,19 @@ struct ConstantTensorDescriptor
         return MemoryRanks{}.Get(Number<I>{});
     }

+    __host__ __device__ static constexpr bool AreStridesNonAscending()
+    {
+        bool flag = true;
+
+        static_for<0, nDim - 1, 1>{}([&](auto IDim) {
+            constexpr auto IDim_p1 = Number<IDim.Get() + 1>{};
+            flag = flag && (GetLength(IDim) >= GetLength(IDim_p1));
+        });
+
+        return flag;
+    }
+
     template <class T>
     __host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(T)
     {
...
@@ -167,6 +188,46 @@ struct ConstantTensorDescriptor
         return multi_id;
     }

+    __host__ __device__ static auto GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
+    {
+        return multi_id;
+    }
+
+    // This function doesn't do a carry check on the highest dimension, for performance reasons.
+    // It is the user's responsibility to make sure the result "new_multi_id" is not out-of-bound
+    // on the highest dimension.
+    __host__ __device__ static Array<index_t, nDim>
+    UpdateMultiIndexGivenStepSizeOf1dIndex(Array<index_t, nDim> old_multi_id,
+                                           index_t step_size_of_1d_index)
+    {
+        auto new_multi_id = old_multi_id + GetMultiIndexFrom1dIndex(step_size_of_1d_index);
+
+        bool carry = false;
+
+        // do carry check in reversed order, starting from the lowest dimension;
+        // don't check the highest dimension
+        static_for<0, nDim - 1, 1>{}([&](auto IDimReverse) {
+            constexpr index_t idim = nDim - 1 - IDimReverse.Get();
+            constexpr auto IDim    = Number<idim>{};
+
+            if(carry)
+            {
+                ++new_multi_id[idim];
+            }
+
+            carry = false;
+
+            if(new_multi_id[idim] >= GetLength(IDim))
+            {
+                new_multi_id[idim] -= GetLength(IDim);
+                carry = true;
+            }
+        });
+
+        return new_multi_id;
+    }
+
     // WRONG! Ranks is broken
     template <index_t... IDims>
     __host__ __device__ static constexpr auto Extract(Number<IDims>... extract_dims)
...
@@ -193,6 +254,19 @@ struct ConstantTensorDescriptor
         return Extract(Number<IDims>{}...);
     }

+    template <class... Ts>
+    __host__ __device__ static constexpr auto Inject(ConstantTensorDescriptor<Ts...>)
+    {
+        using leaf_tensor = ConstantTensorDescriptor<Ts...>;
+
+        // memory rank is broken
+        // TODO: remove memory rank info from tensor descriptor
+        return ConstantTensorDescriptor<
+            decltype(GetLengths().Append(leaf_tensor::GetLengths())),
+            decltype(GetStrides().Append(leaf_tensor::GetStrides())),
+            decltype(GetMemoryRanks().Append(leaf_tensor::GetMemoryRanks()))>{};
+    }
+
     template <index_t IDim, index_t SliceLen>
     __host__ __device__ static constexpr auto Slice(Number<IDim>, Number<SliceLen>)
     {
...
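UpdateMultiIndexGivenStepSizeOf1dIndex above adds the step's multi-index and then propagates carries from the lowest dimension upward, deliberately leaving the highest dimension unchecked. A host-side sketch of the same carry pass (illustration only; the names are mine, and the step is passed directly as a multi-index rather than converted from a 1d index):

// Host-side sketch of the carry propagation in UpdateMultiIndexGivenStepSizeOf1dIndex
// (illustration only; the in-tree version is unrolled with static_for over compile-time
// lengths). As in the in-tree code, the highest dimension is not carry-checked.
#include <array>
#include <cstdio>

template <std::size_t nDim>
std::array<unsigned, nDim> update_multi_index(std::array<unsigned, nDim> id,
                                              const std::array<unsigned, nDim>& step,
                                              const std::array<unsigned, nDim>& lengths)
{
    for(std::size_t i = 0; i < nDim; ++i)
        id[i] += step[i];

    bool carry = false;

    // carry check in reversed order, starting from the lowest dimension,
    // skipping the highest dimension
    for(std::size_t r = 0; r + 1 < nDim; ++r)
    {
        const std::size_t i = nDim - 1 - r;
        if(carry)
            ++id[i];
        carry = false;
        if(id[i] >= lengths[i])
        {
            id[i] -= lengths[i];
            carry = true;
        }
    }
    return id;
}

int main()
{
    // lengths {4, 3, 5}: index {1, 1, 4} advanced by {0, 0, 1} wraps to {1, 2, 0}
    const auto id = update_multi_index<3>({1, 1, 4}, {0, 0, 1}, {4, 3, 5});
    std::printf("%u %u %u\n", id[0], id[1], id[2]);
    return 0;
}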
src/include/blockwise_generic_tensor_slice_op.hip.hpp

 #pragma once

 #include "threadwise_tensor_slice_op.hip.hpp"

-// slice a (normal or merged) tensor, reorder and copy it into another (normal or merged) tensor
+// slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
+// memory layout (ordering of dimensions) can be different between src and dst
 template <index_t BlockSize,
           class Float,
           class SrcDesc,
...
@@ -18,8 +19,29 @@ struct BlockwiseGenericTensorSliceCopy_v1
 {
     static constexpr index_t nDim = SrcDesc::GetNumOfDimension();

-    index_t mSrcMyThreadOffset;
-    index_t mDstMyThreadOffset;
+    static constexpr index_t nOriginalDimSrc =
+        SrcDesc::GetOriginalTensorDescriptor().GetNumOfDimension();
+    static constexpr index_t nOriginalDimDst =
+        DstDesc::GetOriginalTensorDescriptor().GetNumOfDimension();
+
+    // per-thread offset
+    index_t mThreadSrcOffset;
+    index_t mThreadDstOffset;
+
+    // "mThreadSrcOriginalMultiId", "mThreadSrcPartialOffsets", "mThreadDstOriginalMultiId" and
+    // "mThreadDstPartialOffsets" are always calculated inside the constructor, and would be
+    // updated if the slicing-window is moved. However, they will not be used if you always move
+    // the slicing-window along a non-merged dimension. In that case, the compiler should be
+    // able to remove these calculations.
+    // TODO: make sure the compiler actually removes them in that case
+
+    // partial offset in each (merged) dimension
+    Array<index_t, nDim> mThreadSrcPartialOffsets;
+    Array<index_t, nDim> mThreadDstPartialOffsets;
+
+    // multi-id of original tensor
+    Array<index_t, nOriginalDimSrc> mThreadSrcOriginalMultiId;
+    Array<index_t, nOriginalDimDst> mThreadDstOriginalMultiId;

     __device__ BlockwiseGenericTensorSliceCopy_v1(Array<index_t, nDim> src_block_data_multi_id_begin,
...
@@ -72,7 +94,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
                           "wrong! only support Sub-Length == 1 on a merged dimension");
         });

-        // calculate mSrcMyThreadOffset, mDstMyThreadOffset
+        // calculate mThreadSrcOffset, mThreadDstOffset
         const auto thread_cluster_multi_id =
             thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
...
@@ -81,11 +103,46 @@ struct BlockwiseGenericTensorSliceCopy_v1
         const auto thread_data_multi_id_begin = data_cluster_multi_id * SubLengths{};

-        mSrcMyThreadOffset = SrcDesc::GetOffsetFromMultiIndex(src_block_data_multi_id_begin +
-                                                              thread_data_multi_id_begin);
-
-        mDstMyThreadOffset = DstDesc::GetOffsetFromMultiIndex(dst_block_data_multi_id_begin +
-                                                              thread_data_multi_id_begin);
+        // original multi-id
+        mThreadSrcOriginalMultiId = SrcDesc::GetOriginalMultiIndexFromMultiIndex(
+            src_block_data_multi_id_begin + thread_data_multi_id_begin);
+
+        mThreadDstOriginalMultiId = DstDesc::GetOriginalMultiIndexFromMultiIndex(
+            dst_block_data_multi_id_begin + thread_data_multi_id_begin);
+
+        // partial offset on each dimension
+        static_for<0, nDim, 1>{}([&](auto IDim_) {
+            constexpr auto IDim    = decltype(IDim_){};
+            constexpr index_t idim = IDim.Get();
+
+            constexpr auto src_partial_original_dims =
+                SrcDesc::GetContainedOriginalDimensions(IDim);
+
+            constexpr auto src_partial_original_desc =
+                SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims);
+
+            mThreadSrcPartialOffsets[idim] = src_partial_original_desc.GetOffsetFromMultiIndex(
+                extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims));
+        });
+
+        static_for<0, nDim, 1>{}([&](auto IDim_) {
+            constexpr auto IDim    = decltype(IDim_){};
+            constexpr index_t idim = IDim.Get();
+
+            constexpr auto dst_partial_original_dims =
+                DstDesc::GetContainedOriginalDimensions(IDim);
+
+            constexpr auto dst_partial_original_desc =
+                DstDesc::GetOriginalTensorDescriptor().Extract(dst_partial_original_dims);
+
+            mThreadDstPartialOffsets[idim] = dst_partial_original_desc.GetOffsetFromMultiIndex(
+                extract_array(mThreadDstOriginalMultiId, dst_partial_original_dims));
+        });
+
+        // complete offset
+        mThreadSrcOffset = reduce_on_array(mThreadSrcPartialOffsets, std::plus<index_t>{});
+        mThreadDstOffset = reduce_on_array(mThreadDstPartialOffsets, std::plus<index_t>{});

 #if 0
         {
             printf("id %5u %5u: "
...
@@ -93,7 +150,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
                    "thread_cluster_multi_id: %u %u %u %u, "
                    "data_cluster_multi_id: %u %u %u %u, "
                    "thread_data_multi_id_begin: %u %u %u %u, "
-                   "mSrcMyThreadOffset %u, mDstMyThreadOffset %u \n",
+                   "mThreadSrcOffset %u, mThreadDstOffset %u \n",
                    get_block_1d_id(),
                    get_thread_local_1d_id(),
                    src_block_data_multi_id_begin[0],
...
@@ -112,8 +169,8 @@ struct BlockwiseGenericTensorSliceCopy_v1
                    thread_data_multi_id_begin[1],
                    thread_data_multi_id_begin[2],
                    thread_data_multi_id_begin[3],
-                   mSrcMyThreadOffset,
-                   mDstMyThreadOffset);
+                   mThreadSrcOffset,
+                   mThreadDstOffset);
         }
 #endif
     }
...
@@ -156,7 +213,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
             clipboard_data_multi_id_begin);

         // cannot be constexpr, why?
         threadwise_generic_tensor_slice_copy(SrcDesc{},
-                                             p_src + src_offset + mSrcMyThreadOffset,
+                                             p_src + src_offset + mThreadSrcOffset,
                                              make_zero_array<index_t, nDim>(),
                                              thread_tensor_desc,
                                              p_clipboard + clipboard_offset,
...
@@ -197,7 +254,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
                                              p_clipboard + clipboard_offset,
                                              make_zero_array<index_t, nDim>(),
                                              DstDesc{},
-                                             p_dst + dst_offset + mDstMyThreadOffset,
+                                             p_dst + dst_offset + mThreadDstOffset,
                                              make_zero_array<index_t, nDim>(),
                                              thread_sub_tensor_lengths,
                                              DstAccessOrder{});
...
@@ -211,4 +268,81 @@ struct BlockwiseGenericTensorSliceCopy_v1
         RunLoadRegisterClipboard(p_src, p_clipboard);
         RunStoreRegisterClipboard(p_clipboard, p_dst);
     }
+
+    // When moving the slicing window along a merged dimension, if the strides of the original
+    // dimensions contained by the merged dimension are in descending order, then there is no
+    // guarantee that the new offset will be larger than the old offset for movement in the
+    // positive direction (and vice versa for movement in the negative direction). As a result,
+    // the offset calculation may result in unsigned integer underflow (due to the "-" operation).
+    // However, this hazard should not happen, as long as the user makes sure the slicing window
+    // is not moved out of the boundary of the tensor being sliced. This function doesn't do a
+    // runtime sanity check for an out-of-bound slicing window, for performance reasons.
+    template <index_t IDim_, index_t StepSize, bool PositiveDirection>
+    __device__ void MoveSlicingWindowOnSourceTensor(
+        Number<IDim_>, Number<StepSize>, integral_constant<bool, PositiveDirection>)
+    {
+        static_assert(PositiveDirection,
+                      "wrong! only support movement in positive direction for now");
+
+        constexpr auto IDim    = Number<IDim_>{};
+        constexpr index_t idim = IDim.Get();
+
+        static_if<SrcDesc::ContainMultipleOriginalDimensions(IDim)>{}([&](auto fwd) {
+            // logic for a merged dimension; it also works for a non-merged dimension, but its
+            // logic may be unnecessarily complicated for the compiler to remove useless
+            // calculations
+
+            // extract partial original dimensions
+            constexpr auto src_partial_original_dims =
+                SrcDesc::GetContainedOriginalDimensions(IDim);
+
+            constexpr auto src_partial_original_desc =
+                SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims);
+
+            // calculate new partial original multi-id
+            auto old_src_partial_original_multi_id =
+                extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims);
+
+            auto new_src_partial_original_multi_id =
+                src_partial_original_desc.UpdateMultiIndexGivenStepSizeOf1dIndex(
+                    old_src_partial_original_multi_id, StepSize);
+
+            // update "mThreadSrcOriginalMultiId"
+            static_for<0, src_partial_original_dims.GetSize(), 1>{}([&](auto I_) {
+                constexpr auto I = decltype(I_){};
+
+                constexpr index_t idim_original = src_partial_original_dims.Get(I);
+
+                mThreadSrcOriginalMultiId[idim_original] =
+                    new_src_partial_original_multi_id[I.Get()];
+            });
+
+            // calculate new partial offset on this merged dimension
+            const index_t old_src_partial_offset = mThreadSrcPartialOffsets[idim];
+
+            const index_t new_src_partial_offset =
+                src_partial_original_desc.GetOffsetFromMultiIndex(
+                    new_src_partial_original_multi_id);
+
+            // update "mThreadSrcPartialOffsets"
+            mThreadSrcPartialOffsets[idim] = new_src_partial_offset;
+
+            // update "mThreadSrcOffset": do "+" before "-" to avoid underflow
+            mThreadSrcOffset = mThreadSrcOffset + new_src_partial_offset - old_src_partial_offset;
+        }).Else([&](auto fwd) {
+            // Logic for a non-merged dimension. If you are never going to move the slicing
+            // window on a merged dimension, then "mThreadSrcOriginalMultiId" and
+            // "mThreadSrcPartialOffsets", which are being calculated here, will never be used
+            // later. In this case, the compiler should be able to remove these calculations.
+            // TODO: make sure the compiler actually removes them in this case.
+            constexpr index_t idim_original =
+                SrcDesc::GetContainedOriginalDimensions(IDim).Front();
+
+            mThreadSrcOffset += StepSize * SrcDesc::GetStride(IDim);
+
+            mThreadSrcOriginalMultiId[idim_original] += StepSize;
+
+            mThreadSrcPartialOffsets[idim] += StepSize * SrcDesc::GetStride(IDim);
+        });
+    }
 };
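The slicing-window move above relies on a simple invariant: mThreadSrcOffset is the sum of the per-dimension partial offsets, so moving along one (merged) dimension only recomputes that dimension's partial offset and applies the difference, adding before subtracting so an unsigned total cannot underflow. A toy illustration with made-up numbers:

// Illustration only (the numbers and names are mine): incremental offset update as done by
// MoveSlicingWindowOnSourceTensor. The total offset is the sum of per-dimension partial
// offsets; moving along one dimension recomputes only that partial offset and applies the
// difference, doing "+" before "-" so an unsigned total cannot underflow.
#include <array>
#include <cstdio>

int main()
{
    std::array<unsigned, 4> partial_offsets{1024, 48, 4, 1};
    unsigned total = 1024 + 48 + 4 + 1; // reduce_on_array(partial_offsets, plus) == 1077

    const std::size_t idim     = 2;  // dimension the window moves along
    const unsigned new_partial = 20; // recomputed from the new (merged) multi-index
    const unsigned old_partial = partial_offsets[idim];

    total                 = total + new_partial - old_partial; // "+" before "-"
    partial_offsets[idim] = new_partial;

    std::printf("new total offset %u\n", total); // prints "new total offset 1093"
    return 0;
}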
src/include/functional.hip.hpp

...
@@ -44,7 +44,7 @@ struct static_if<true>
     }

     template <class F>
-    __host__ __device__ static constexpr auto else_(F)
+    __host__ __device__ static constexpr auto Else(F)
     {
         return Type{};
     }
...
@@ -62,7 +62,7 @@ struct static_if<false>
     }

     template <class F>
-    __host__ __device__ static constexpr auto else_(F f)
+    __host__ __device__ static constexpr auto Else(F f)
     {
         // This is a trick for the compiler:
         // Pass a forwarder to lambda "f" as an "auto" argument, and make sure "f" will use it,
...
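The rename above turns static_if's else_ into Else. As a rough idea of the pattern (a standalone simplification, not the project's actual implementation, which also passes a forwarder argument to keep the lambdas dependent):

// Minimal standalone sketch of a static_if with an .Else chain (my own simplification;
// the in-tree version is __host__ __device__ and forwards an argument into the lambdas).
#include <cstdio>

template <bool Predicate>
struct static_if_sketch;

template <>
struct static_if_sketch<true>
{
    template <class F>
    static_if_sketch operator()(F f) const { f(); return {}; }

    template <class F>
    static_if_sketch Else(F) const { return {}; } // predicate true: skip the else-branch
};

template <>
struct static_if_sketch<false>
{
    template <class F>
    static_if_sketch operator()(F) const { return {}; } // predicate false: skip the then-branch

    template <class F>
    static_if_sketch Else(F f) const { f(); return {}; }
};

int main()
{
    constexpr bool use_fast_path = false;

    static_if_sketch<use_fast_path>{}([&] { std::printf("then-branch\n"); })
        .Else([&] { std::printf("else-branch\n"); }); // prints "else-branch"
    return 0;
}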
src/include/gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hip.hpp

...
@@ -337,7 +337,7 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
                     n_block_data_begin + n_thread_data_begin),
                 out_10d_thread_desc.GetLengths(),
                 Number<OutThreadCopyDataPerWrite_N>{});
-        }).else_([&](auto f_dummy) {
+        }).Else([&](auto f_dummy) {
             static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                               GemmNPerThreadSubC % NPerThread == 0,
                           "wrong!");
...
src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp

...
@@ -373,7 +373,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
                     n_block_data_begin + n_thread_data_begin),
                 out_10d_thread_desc.GetLengths(),
                 Number<OutThreadCopyDataPerWrite_N>{});
-        }).else_([&](auto f_dummy) {
+        }).Else([&](auto f_dummy) {
             static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                               GemmNPerThreadSubC % NPerThread == 0,
                           "wrong!");
...
src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp

...
@@ -363,7 +363,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
                     n_block_data_begin + n_thread_data_begin),
                 out_10d_thread_desc.GetLengths(),
                 Number<OutThreadCopyDataPerWrite_N>{});
-        }).else_([&](auto f_dummy) {
+        }).Else([&](auto f_dummy) {
             static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                               GemmNPerThreadSubC % NPerThread == 0,
                           "wrong!");
...
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp

...
@@ -412,7 +412,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
                     n_block_data_begin + n_thread_data_begin),
                 out_10d_thread_desc.GetLengths(),
                 Number<OutThreadCopyDataPerWrite_N>{});
-        }).else_([&](auto fwd) {
+        }).Else([&](auto fwd) {
             static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                               GemmNPerThreadSubC % NPerThread == 0,
                           "wrong!");
...
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp

...
@@ -432,7 +432,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
                     n_block_data_begin + n_thread_data_begin),
                 out_10d_thread_desc.GetLengths(),
                 Number<OutThreadCopyDataPerWrite_N>{});
-        }).else_([&](auto fwd) {
+        }).Else([&](auto fwd) {
             static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                               GemmNPerThreadSubC % NPerThread == 0,
                           "wrong!");
...
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp

...
@@ -115,7 +115,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
             Number<InBlockReorderDataPerWrite_N>{});

         // this check is ad-hoc
-        // TODO: need to properly implement tensor descriptor with alignment
+        // TODO: need to properly implement tensor descriptor with multiple alignment requirements
         static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                       "GemmDataPerReadB alignment requirement is not met");
...
@@ -417,7 +417,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
                 out_10d_thread_desc.GetLengths(),
                 map_out_global2thread);
             // Number<OutThreadCopyDataPerWrite_W>{});
-        }).else_([&](auto fwd) {
+        }).Else([&](auto fwd) {
             static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                               GemmNPerThreadSubC % NPerThread == 0,
                           "wrong!");
...
src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_khwn.hip.hpp

...
@@ -407,7 +407,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
                     n_block_data_begin + n_thread_data_begin),
                 out_10d_thread_desc.GetLengths(),
                 Number<OutThreadCopyDataPerWrite_N>{});
-        }).else_([&](auto f_dummy) {
+        }).Else([&](auto f_dummy) {
             static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                               GemmNPerThreadSubC % NPerThread == 0,
                           "wrong!");
...
src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hip.hpp

...
@@ -366,7 +366,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
                 out_10d_thread_desc.GetLengths(),
                 map_out_global2thread);
             // Number<OutThreadCopyDataPerWrite_W>{});
-        }).else_([&](auto fwd) {
+        }).Else([&](auto fwd) {
             static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                               GemmNPerThreadSubC % NPerThread == 0,
                           "wrong!");
...
src/include/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hip.hpp (new file, 0 → 100644)

#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMergedTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_generic_tensor_slice_op.hip.hpp"
#include "blockwise_gemm.hip.hpp"
#include "threadwise_tensor_slice_op.hip.hpp"

// define B = merge(N, Ho, Wo)
template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          index_t BPerBlock,
          index_t KPerBlock,
          index_t EPerBlock,
          index_t N1,
          index_t N2,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
          class InBlockCopySubLengths_E_N1_B_N2,
          class InBlockCopyClusterLengths_E_N1_B_N2,
          index_t InBlockCopySrcDataPerRead_B,
          index_t InBlockCopyDstDataPerWrite_N2,
          class WeiBlockCopySubLengths_E_K,
          class WeiBlockCopyClusterLengths_E_K,
          index_t WeiBlockCopySrcDataPerRead_E,
          index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // this is a mess
        // TODO: find a more elegant way of specifying (or calculating) performance parameters
        static_assert(N2 == GemmNPerThreadSubC, "wrong!");
        static_assert((N1 * N2 * BPerBlock) %
                              (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
                          0,
                      "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};
        constexpr auto I4 = Number<4>{};
        constexpr auto I5 = Number<5>{};
        constexpr auto I6 = Number<6>{};
        constexpr auto I7 = Number<7>{};

        constexpr auto TRUE = integral_constant<bool, true>{};

        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
        constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_h_w_global_desc.GetLength(I0);
        constexpr index_t C  = in_n_c_h_w_global_desc.GetLength(I1);
        constexpr index_t Hi = in_n_c_h_w_global_desc.GetLength(I2);
        constexpr index_t Wi = in_n_c_h_w_global_desc.GetLength(I3);

        constexpr index_t K  = out_n_k_h_w_global_desc.GetLength(I1);
        constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
        constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

        static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among threads");

        constexpr index_t N0 = N / (N1 * N2);

        constexpr index_t B = N0 * Ho * Wo;

        constexpr index_t E = C * Y * X;

        // divide block work by [K, B]
        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
                      "wrong! cannot divide work evenly among blocks");

        constexpr index_t KBlockWork = K / KPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr auto block_work_desc =
            make_ConstantTensorDescriptor_default_rank_packed(Sequence<KBlockWork, BBlockWork>{});

        const auto block_work_multi_id =
            block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());

        const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
        const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;

        // input tensor
        //     tensor descriptor in device memory [N0, N1, N2, H, W]
        constexpr auto in_n0_n1_n2_h_w_global_desc =
            in_n_c_h_w_global_desc.Slice(I2, Number<Hi>{})
                .Slice(I3, Number<Wi>{})
                .Fold(I0, Number<N1>{}, Number<N2>{})
                .Extract(Sequence<0, 1, 2, 4, 5>{});

        //     batch descriptor for device memory
        constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Y>{})
                                                  .Slice(I3, Number<X>{})
                                                  .Extract(Sequence<1, 2, 3>{});

        //     merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
        constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
            in_c_y_x_global_desc.Inject(in_n0_n1_n2_h_w_global_desc),
            Sequence<0, 1, 2>{},
            Sequence<4>{},
            Sequence<3, 6, 7>{},
            Sequence<5>{});

        //     memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
        //     be careful of LDS alignment
        constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_default_rank_aligned(
            Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with multiple alignment
        // requirements
        static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                      "GemmDataPerReadB alignment requirement is not satisfied");

        //     input blockwise copy
        //     slice a merged tensor, reorder and copy to a normal tensor
        //     this copy operator already has blockwise offset built-in
        auto blockwise_in_copy =
            BlockwiseGenericTensorSliceCopy_v1<BlockSize,
                                               Float,
                                               decltype(in_e_n1_b_n2_global_merged_desc),
                                               decltype(in_e_n1_b_n2_block_desc),
                                               decltype(in_e_n1_b_n2_block_desc.GetLengths()),
                                               InBlockCopySubLengths_E_N1_B_N2,
                                               InBlockCopyClusterLengths_E_N1_B_N2,
                                               Sequence<0, 1, 3, 2>, // thread_arrange_order [E, N1, N2, B]
                                               Sequence<0, 1, 3, 2>, // src_access_order [E, N1, N2, B]
                                               Sequence<0, 1, 2, 3>, // dst_access_order [E, N1, B, N2]
                                               InBlockCopySrcDataPerRead_B,
                                               InBlockCopyDstDataPerWrite_N2>(
                {0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});

        // weight tensor
        //     tensor descriptor in device memory, src of blockwise copy
        constexpr auto wei_e_k_global_desc =
            wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});

        //     tensor descriptor in LDS, dst of blockwise copy
        //     be careful of LDS alignment
        constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_default_rank_aligned(
            Sequence<EPerBlock, KPerBlock>{},
            Number<mod_conv::max(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});

        //     operator for blockwise copy of weight into LDS
        //     slice a tensor, and copy it into another tensor
        //     this copy operator already has blockwise offset built-in
        auto blockwise_wei_copy =
            BlockwiseGenericTensorSliceCopy_v1<BlockSize,
                                               Float,
                                               decltype(wei_e_k_global_desc),
                                               decltype(wei_e_k_block_desc),
                                               decltype(wei_e_k_block_desc.GetLengths()),
                                               WeiBlockCopySubLengths_E_K,
                                               WeiBlockCopyClusterLengths_E_K,
                                               Sequence<1, 0>, // thread_arrange_order [K, E]
                                               Sequence<1, 0>, // src_access_order [K, E]
                                               Sequence<0, 1>, // dst_access_order [E, K]
                                               WeiBlockCopySrcDataPerRead_E,
                                               WeiBlockCopyDstDataPerWrite_K>(
                {0, k_block_data_on_global}, {0, 0});

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[EPerBlock, KPerBlock] is in LDS
        //     b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
        //     c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
        //     register
        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<EPerBlock>{}, Number<KPerBlock>{}, Number<wei_e_k_block_desc.GetStride(I0)>{});

        constexpr auto b_e_n1bn2_block_mtx_desc =
            make_ConstantMatrixDescriptor(Number<EPerBlock>{},
                                          Number<N1 * BPerBlock * N2>{},
                                          Number<in_e_n1_b_n2_block_desc.GetStride(I0)>{});

        // sanity check
        static_assert(
            KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0,
            "wrong!");

        constexpr index_t GemmMRepeat =
            KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO: find a more elegant way of defining c_thread_mtx
        constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor(
            Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize,
            decltype(a_e_k_block_mtx_desc),
            decltype(b_e_n1bn2_block_mtx_desc),
            decltype(c_k0k2_n1n2_thread_mtx_desc),
            GemmMPerThreadSubC,
            GemmNPerThreadSubC,
            GemmMLevel0Cluster,
            GemmNLevel0Cluster,
            GemmMLevel1Cluster,
            GemmNLevel1Cluster,
            GemmKPerThreadLoop,
            GemmDataPerReadA,
            GemmDataPerReadB>{};

        // LDS allocation for input and weight: be careful of alignment
        constexpr index_t max_align = mod_conv::max(InBlockCopyDstDataPerWrite_N2,
                                                    WeiBlockCopyDstDataPerWrite_K,
                                                    GemmDataPerReadA,
                                                    GemmDataPerReadB);

        constexpr index_t in_block_space =
            in_e_n1_b_n2_block_desc.GetElementSpace(Number<max_align>{});

        constexpr index_t wei_block_space = wei_e_k_block_desc.GetElementSpace(Number<max_align>{});

        __shared__ Float p_in_block[in_block_space];
        __shared__ Float p_wei_block[wei_block_space];

        // register allocation for output
        Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);

        // do work
        for(index_t e = 0; e < E; e += EPerBlock)
        {
            // marching slicing window
            blockwise_in_copy.Run(p_in_global, p_in_block);
            blockwise_wei_copy.Run(p_wei_global, p_wei_block);

            __syncthreads();

            blockwise_gemm.Run(p_wei_block, p_in_block, p_out_thread);

            __syncthreads();

            blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, TRUE);
            blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, TRUE);
        }

        // copy output: register to global memory
        {
            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
            constexpr index_t K0 = K / (K1 * K2);

            // define tensor descriptor for threadwise copy
            //     output memory layout descriptor in register
            constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
                make_ConstantTensorDescriptor_default_rank_packed(
                    Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});

            //     output tensor descriptor in register, src of threadwise copy
            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
                out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
                    Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});

            //     output memory layout descriptor in device memory, dst of threadwise copy
            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
                out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
                    .Fold(I0, Number<N1>{}, Number<N2>{});

            // calculate origin of thread output tensor on global memory
            //     blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t k_thread_data_on_global =
                k_block_data_on_global + c_thread_mtx_on_block.row;

            const index_t b_thread_data_on_global =
                b_block_data_on_global + c_thread_mtx_on_block.col / N2;

            //     output merged global tensor descriptor, for calculating origin of thread tensor
            //     in global memory
            constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
                Sequence<3>{},
                Sequence<1>{},
                Sequence<0, 4, 5>{},
                Sequence<2>{});

            //     origin of dst in device memory
            Float* p_out_thread_on_global =
                p_out_global +
                out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
                    k_thread_data_on_global, 0, b_thread_data_on_global, 0);

            threadwise_generic_tensor_slice_copy(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
                                                 p_out_thread,
                                                 {0, 0, 0, 0, 0, 0, 0, 0},
                                                 out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc,
                                                 p_out_thread_on_global,
                                                 {0, 0, 0, 0, 0, 0, 0, 0},
                                                 out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
                                                 arithmetic_sequence_gen<0, 8, 1>::SeqType{});
        }
    }
};
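A hypothetical sanity check of the tiling above, using the driver's default tuning (the driver's CPerBlock = 8 binds to the kernel's EPerBlock parameter): each E-loop iteration stages an [EPerBlock x KPerBlock] weight tile and an [EPerBlock x N1*BPerBlock*N2] input tile in LDS, and the 256 threads each hold an 8 x 8 sub-tile of the [128 x 128] C tile in registers.

// Hypothetical sanity check (not part of the commit), using the driver's default tuning.
#include <cassert>
#include <cstdio>

int main()
{
    const unsigned BlockSize = 256, BPerBlock = 16, KPerBlock = 128, EPerBlock = 8;
    const unsigned N1 = 2, N2 = 4;
    const unsigned GemmMPerThreadSubC = 4, GemmMLevel0Cluster = 4, GemmMLevel1Cluster = 4;

    const unsigned GemmMRepeat =
        KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);

    const unsigned c_rows_per_thread = GemmMRepeat * GemmMPerThreadSubC; // 2 * 4 = 8
    const unsigned c_cols_per_thread = N1 * N2;                          // 2 * 4 = 8

    const unsigned c_tile_rows = KPerBlock;           // 128
    const unsigned c_tile_cols = N1 * BPerBlock * N2; // 128

    // the per-thread sub-tiles exactly cover the per-block C tile
    assert(c_rows_per_thread * c_cols_per_thread * BlockSize == c_tile_rows * c_tile_cols);

    std::printf("LDS floats per iteration: %u\n",
                EPerBlock * KPerBlock + EPerBlock * c_tile_cols);       // 2048
    std::printf("C accumulators per thread: %u\n",
                c_rows_per_thread * c_cols_per_thread);                  // 64
    return 0;
}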