Project: gaoqiong / composable_kernel — Commits

Commit 38a90b6e (unverified), authored Oct 20, 2021 by Chao Liu, committed via GitHub on Oct 20, 2021
Parents: 88833bd9, c3018794

    Merge pull request #43 from ROCmSoftwarePlatform/develop

    Merge develop into master

Changes: 71; showing 20 changed files with 4644 additions and 330 deletions (+4644, -330).
Changed files shown (20):

composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_partial_dims.cpp  (+279, -0)
host/driver_offline/include/debug.hpp  (+13, -0)
host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp  (+11, -34)
host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp  (+193, -78)
host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp  (+389, -0)
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp  (+258, -0)
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp  (+26, -20)
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp  (+290, -0)
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp  (+276, -0)
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp  (+458, -0)
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp  (+68, -8)
host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp  (+289, -45)
host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp  (+263, -0)
host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp  (+289, -45)
host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp  (+263, -0)
host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp  (+290, -46)
host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp  (+291, -0)
host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp  (+337, -48)
host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp  (+347, -0)
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp  (+14, -6)
composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_partial_dims.cpp (new file, mode 100644)
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include "config.hpp"
#include "number.hpp"
#include "sequence.hpp"
#include "tensor_descriptor_helper.hpp"
#include "data_type_enum_helper.hpp"
#include "reduction_common.hpp"
#include "gridwise_generic_2d_reduction_direct_warpwise.hpp"
using namespace ck;

using srcDataType =
    typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
using dstDataType =
    typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
using compType =
    typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;

constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable

constexpr index_t dstDims = CK_PARAM_OUT_DIMS;

constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);

constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
                                             ? NanPropagation_t::NOT_PROPAGATE_NAN
                                             : NanPropagation_t::PROPAGATE_NAN;

constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
                                                       ? ReduceTensorIndices_t::NO_INDICES
                                                       : ReduceTensorIndices_t::FLATTENED_INDICES;

constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);

constexpr bool indexable    = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);

constexpr index_t GredAccessesPerThreadInWarp = CK_PARAM_ACCESSES_PER_THREAD_INWARP; // tunable
// helper functions using variadic template arguments
template <index_t... Ns>
__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence<Ns...>)
{
    return make_tuple(static_cast<index_t>(lengths[Ns])...);
};

template <index_t arraySize>
__device__ static auto make_tuple_from_array(const int* lengths, Number<arraySize>)
{
    static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");

    constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{};

    return make_tuple_from_array_and_index_seq(lengths, index_seq);
};

template <index_t... Ns>
__device__ static constexpr auto make_tuple_from_seq(Sequence<Ns...>)
{
    return make_tuple(Ns...);
};
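// Illustrative note (added commentary, not in the original source): for
// int lengths[3] = {4, 5, 6}, make_tuple_from_array(lengths, Number<3>{})
// generates Sequence<0, 1, 2> and expands to make_tuple(4, 5, 6), turning a
// runtime integer array into a ck tuple of index_t values.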
extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize,
                                                             int BlkGroupSize,
                                                             int outLength0,
                                                             int outLength1,
                                                             int outLength2,
                                                             int outLength3,
                                                             int outLength4,
                                                             int outLength5,
                                                             int outStride0,
                                                             int outStride1,
                                                             int outStride2,
                                                             int outStride3,
                                                             int outStride4,
                                                             int outStride5,
                                                             void* __restrict__ ws_global)
{
    (void)BlkGroupSize;

    void* p_src2dDesc = ws_global;
    void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;

    const int dstLengths[6] = {outLength0, outLength1, outLength2, outLength3, outLength4, outLength5};
    const int dstStrides[6] = {outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};

    const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number<dstDims>{});
    const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});

    const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);

    auto dst1dDesc = transform_tensor_descriptor(
        dstDesc,
        make_tuple(make_merge_transform(tupleDstLengths)),
        make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
        make_tuple(Sequence<0>{}));

    const index_t invariantLen = dst1dDesc.GetLength(Number<0>{});
    const index_t toReduceLen  = BlkGroupSize;

    auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen));

    constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp;

    if constexpr(src2d_need_padding)
    {
        const auto srcPad1 = GridSize * BlockSize / warpSize - invariantLen;
        const auto srcPad2 =
            ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;

        auto src2dDesc_2 =
            transform_tensor_descriptor(src2dDesc,
                                        make_tuple(make_pad_transform(invariantLen, 0, srcPad1),
                                                   make_pad_transform(toReduceLen, 0, srcPad2)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        if(get_thread_local_1d_id() == 0)
            *static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
    }
    else
    {
        if(get_thread_local_1d_id() == 0)
            *static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
    }

    if constexpr(dst1d_need_padding)
    {
        const auto dstPad = GridSize * BlockSize / warpSize - invariantLen;

        auto dst1dDesc_2 =
            transform_tensor_descriptor(dst1dDesc,
                                        make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
                                        make_tuple(Sequence<0>{}),
                                        make_tuple(Sequence<0>{}));

        if(get_thread_local_1d_id() == 0)
            *static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
    }
    else
    {
        if(get_thread_local_1d_id() == 0)
            *static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
    }
};
template <index_t dstDims>
struct get_ref_desc_types
{
    static constexpr auto ref_tupleDstLengths =
        make_tuple_from_seq(typename uniform_sequence_gen<dstDims, 8>::type{});
    static constexpr auto ref_dstDesc =
        make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths);

    static constexpr auto ref_dst1dDesc = transform_tensor_descriptor(
        ref_dstDesc,
        make_tuple(make_merge_transform(ref_tupleDstLengths)),
        make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
        make_tuple(Sequence<0>{}));

    static constexpr index_t ref_invariantLen = ref_dst1dDesc.GetLength(Number<0>{});
    static constexpr index_t ref_toReduceLen  = 8;

    static constexpr auto ref_src2dDesc =
        make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen));

    using refType_src2dDesc = decltype(ref_src2dDesc);
    using refType_dst1dDesc = decltype(ref_dst1dDesc);

    // used by the DirectThreadWise and DirectWarpWise method
    using refType_src2dDesc_padded_12 =
        decltype(transform_tensor_descriptor(ref_src2dDesc,
                                             make_tuple(make_pad_transform(ref_invariantLen, 0, 2),
                                                        make_pad_transform(ref_toReduceLen, 0, 2)),
                                             make_tuple(Sequence<0>{}, Sequence<1>{}),
                                             make_tuple(Sequence<0>{}, Sequence<1>{})));

    using refType_dst1dDesc_padded =
        decltype(transform_tensor_descriptor(ref_dst1dDesc,
                                             make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
                                             make_tuple(Sequence<0>{}),
                                             make_tuple(Sequence<0>{})));
};
using refType_src2dDesc = typename get_ref_desc_types<dstDims>::refType_src2dDesc;
using refType_dst1dDesc = typename get_ref_desc_types<dstDims>::refType_dst1dDesc;
using refType_src2dDesc_padded_12 =
    typename get_ref_desc_types<dstDims>::refType_src2dDesc_padded_12;
using refType_dst1dDesc_padded = typename get_ref_desc_types<dstDims>::refType_dst1dDesc_padded;
template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
{
    if constexpr(need_padding)
        return (*reinterpret_cast<const refType_src2dDesc_padded_12*>(p_src2dDesc));
    else
        return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
};

template <bool need_padding>
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
{
    if constexpr(need_padding)
        return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
    else
        return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
};
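// Note (added commentary): the *_prepare kernel serializes the runtime tensor
// descriptors into the workspace, and these helpers read them back by
// reinterpret_cast'ing to "reference" descriptor types built above from dummy
// compile-time lengths. Only the type (i.e. the transform structure) has to
// match; the actual runtime lengths travel inside the stored descriptor object.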
extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
                                                     float alpha,
                                                     const void* __restrict__ p_src_global,
                                                     float beta,
                                                     void* __restrict__ p_dst_global,
                                                     const void CONSTANT* ws_global,
                                                     long ws_buf2_bytes_offset,
                                                     void* __restrict__ indices_global)
{
    (void)p_src_global;

    const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
    const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
    void* ws_buf1_global    = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);

    const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
    const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);

    using gridwise_2d_reduce = GridwiseReduction_xy_to_x_direct_warpwise<BlockSize,
                                                                         srcDataType,
                                                                         dstDataType,
                                                                         compType,
                                                                         decltype(src2dDesc),
                                                                         decltype(dst1dDesc),
                                                                         op,
                                                                         nanPropaOpt,
                                                                         reduceIndicesOpt,
                                                                         false,
                                                                         true,
                                                                         GredAccessesPerThreadInWarp>;

    void* const ws_buf2_global =
        ws_buf2_bytes_offset > 0
            ? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
            : nullptr;

    constexpr int RunId = need_indices ? 3 : 1;

    gridwise_2d_reduce::template Run<RunId>(
        src2dDesc,
        dst1dDesc,
        origReduceLen,
        alpha,
        static_cast<const srcDataType* const __restrict__>(ws_buf1_global),
        beta,
        static_cast<dstDataType* const __restrict__>(p_dst_global),
        static_cast<const int* const __restrict__>(ws_buf2_global),
        static_cast<int* const __restrict__>(indices_global));
};
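For orientation, here is a minimal host-side sketch of how the two kernels above could be chained. This is added commentary, not part of the commit: it assumes the kernels are visible in the same HIP translation unit with the CK_PARAM_* macros already fixed, that ws_global was allocated with the two descriptors at byte offsets 0 and 2048 plus the first-pass partial results at 4096, and it elides the CONSTANT address-space qualifier on ws_global.

// Illustrative host launcher for the two-pass warp-wise reduction (assumed plumbing).
#include <hip/hip_runtime.h>

void launch_second_call_reduction(int GridSize, int BlkGroupSize, int BlockSize,
                                  const int outLengths[6], const int outStrides[6],
                                  int origReduceLen, float alpha, float beta,
                                  void* ws_global, long ws_buf2_bytes_offset,
                                  void* p_dst_global, void* indices_global)
{
    // Pass 1: a single block suffices; thread 0 serializes the src2d/dst1d
    // descriptors into the first 4096 bytes of the workspace.
    hipLaunchKernelGGL(gridwise_generic_reduce_2_prepare,
                       dim3(1), dim3(BlockSize), 0, 0,
                       GridSize, BlkGroupSize,
                       outLengths[0], outLengths[1], outLengths[2],
                       outLengths[3], outLengths[4], outLengths[5],
                       outStrides[0], outStrides[1], outStrides[2],
                       outStrides[3], outStrides[4], outStrides[5],
                       ws_global);

    // Pass 2: each warp reduces one (padded) row of the invariantLen x
    // BlkGroupSize matrix of partial results left behind by the first call;
    // p_src_global is unused by this kernel, so nullptr is passed.
    hipLaunchKernelGGL(gridwise_generic_reduce_2,
                       dim3(GridSize), dim3(BlockSize), 0, 0,
                       origReduceLen, alpha, /*p_src_global=*/nullptr, beta,
                       p_dst_global, ws_global, ws_buf2_bytes_offset, indices_global);

    hipDeviceSynchronize();
}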
host/driver_offline/include/debug.hpp (new file, mode 100644)
#ifndef DEBUG_HPP
#define DEBUG_HPP

namespace debug {
namespace debug_driver_gemm_xdlops_v2r3 {

// these vars are on host, they control block_id to C matrix tile idx (m0, n0) mapping
static ck::index_t M01 = 1;
static ck::index_t N01 = 1;

} // namespace debug_driver_gemm_xdlops_v2r3
} // namespace debug

#endif
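As the diffs below show, every updated driver call site in this commit now threads debug::debug_driver_gemm_xdlops_v2r3::M01 and ::N01 through to the grid-wise GEMM (right after the three grid descriptors), so the block-id to C-tile (m0, n0) mapping can be adjusted in this one header instead of at each call site.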
host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp
@@ -3,6 +3,7 @@
 #include "host_tensor.hpp"
 #include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp"
 #include "driver_gemm_xdlops_v2r3.hpp"
+#include "debug.hpp"

 template <typename TInWei,
           typename TAcc,
@@ -48,8 +49,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
     const auto wei_k_y_x_c_desc   = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
     const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);

-#if 1
-    // [M, N, K0, K1] = [128, 128, 4, 4] for fp32
+#if 0
+    // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
     constexpr index_t BlockSize     = 256;
     constexpr index_t GemmMPerBlock = 128;
@@ -76,7 +77,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
     constexpr index_t GemmCThreadTransferDstScalarPerVector       = 4;
-#elif 1
+#elif 0
     // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
     constexpr index_t BlockSize = 256;
@@ -105,7 +106,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
     constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
 #elif 1
-    // [M, N, K0, K1] = [256, 128, 4, 8] for fp16
+    // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
     constexpr index_t BlockSize     = 256;
     constexpr index_t GemmMPerBlock = 256;
@@ -133,7 +134,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
     constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
 #elif 1
-    // [M, N, K0, K1] = [128, 256, 4, 8] for fp16
+    // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
     constexpr index_t BlockSize     = 256;
     constexpr index_t GemmMPerBlock = 128;
@@ -159,34 +160,6 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
     constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
     constexpr index_t GemmCThreadTransferDstScalarPerVector       = 4;
-#elif 0
-    // [M, N, K0, K1] = [256, 128, 4, 4]
-    constexpr index_t BlockSize     = 256;
-    constexpr index_t GemmMPerBlock = 256;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 4;
-    constexpr index_t GemmMPerXDL   = 32;
-    constexpr index_t GemmNPerXDL   = 32;
-    constexpr index_t GemmK1        = 4;
-    constexpr index_t MRepeat       = 4;
-    constexpr index_t NRepeat       = 2;
-    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 4, 4>;
-    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM  = 4;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
-    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 2, 4>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector       = 4;
 #endif
@@ -294,13 +267,17 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
         decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
         decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
         decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
-        false               // CAccessOrderMRepeatNRepeat
+        false,              // CAccessOrderMRepeatNRepeat
+        false,              // ABlockLdsExtraM
+        false               // BBlockLdsExtraN
         >(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
           static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
          static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
           wei_gemmk0_gemmm_gemmk1_grid_desc,
           out_gemmk0_gemmn_gemmk1_grid_desc,
           in_gemmm_gemmn_grid_desc,
+          debug::debug_driver_gemm_xdlops_v2r3::M01,
+          debug::debug_driver_gemm_xdlops_v2r3::N01,
           wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
           out_gemmk0_gemmn_gemmk1_grid_step_hacks,
           in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp
@@ -49,7 +49,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
     const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);

 #if 0
-    // [M, N, K0, K1] = [256, 128, 4, 4] for fp32
+    // [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32
     constexpr index_t BlockSize = 256;
     constexpr index_t GemmMPerBlock = 256;
@@ -77,7 +77,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
     constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
 #elif 0
-    // [M, N, K0, K1] = [128, 128, 4, 4] for fp32
+    // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
     constexpr index_t BlockSize     = 256;
     constexpr index_t GemmMPerBlock = 128;
@@ -104,8 +104,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
     constexpr index_t GemmCThreadTransferDstScalarPerVector       = 1;
-#elif 1
-    // [M, N, K0, K1] = [256, 128, 4, 8] for fp16
+#elif 0
+    // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
     constexpr index_t BlockSize     = 256;
     constexpr index_t GemmMPerBlock = 256;
@@ -133,7 +133,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
     constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
 #elif 1
-    // [M, N, K0, K1] = [128, 256, 4, 8] for fp16
+    // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
     constexpr index_t BlockSize     = 256;
     constexpr index_t GemmMPerBlock = 128;
@@ -160,23 +160,91 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
     constexpr index_t GemmCThreadTransferDstScalarPerVector       = 1;
-#endif
-
-    const auto descs = transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
-        out_n_ho_wo_k_desc,
-        wei_k_y_x_c_desc,
-        in_n_hi_wi_c_desc,
-        conv_strides,
-        conv_dilations,
-        in_left_pads,
-        in_right_pads,
-        I0,
-        I0,
-        Number<GemmK1>{});
-
-    const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
-    const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
-    const auto in_gemmm_gemmn_grid_desc          = descs[I2];
+#elif 1
+    // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
+    constexpr index_t BlockSize     = 256;
+    constexpr index_t GemmMPerBlock = 128;
+    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t GemmKPerBlock = 4;
+    constexpr index_t GemmMPerWave  = 32;
+    constexpr index_t GemmNPerWave  = 32;
+    constexpr index_t GemmK1        = 8;
+    constexpr index_t MRepeat       = 2;
+    constexpr index_t NRepeat       = 2;
+    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 2, 8>;
+    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
+    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
+    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
+    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 2, 8>;
+    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
+    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN  = 2;
+    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
+    constexpr index_t GemmCThreadTransferDstScalarPerVector       = 1;
+#elif 0
+    // [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16
+    constexpr index_t BlockSize     = 128;
+    constexpr index_t GemmMPerBlock = 128;
+    constexpr index_t GemmNPerBlock = 64;
+    constexpr index_t GemmKPerBlock = 4;
+    constexpr index_t GemmMPerWave  = 32;
+    constexpr index_t GemmNPerWave  = 32;
+    constexpr index_t GemmK1        = 8;
+    constexpr index_t MRepeat       = 2;
+    constexpr index_t NRepeat       = 2;
+    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 4, 8>;
+    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>;
+    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
+    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
+    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 2, 8>;
+    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>;
+    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN  = 2;
+    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
+    constexpr index_t GemmCThreadTransferDstScalarPerVector       = 1;
+#elif 0
+    // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
+    constexpr index_t BlockSize     = 256;
+    constexpr index_t GemmMPerBlock = 128;
+    constexpr index_t GemmNPerBlock = 64;
+    constexpr index_t GemmKPerBlock = 4;
+    constexpr index_t GemmMPerWave  = 32;
+    constexpr index_t GemmNPerWave  = 32;
+    constexpr index_t GemmK1        = 8;
+    constexpr index_t MRepeat       = 2;
+    constexpr index_t NRepeat       = 1;
+    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 2, 8>;
+    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
+    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
+    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
+    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 1, 8>;
+    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
+    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN  = 1;
+    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
+    constexpr index_t GemmCThreadTransferDstScalarPerVector       = 1;
+#endif

     // HACK: hacks that control index calculation when iterating over A, B, C matrix
     constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
@@ -185,7 +253,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
             Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),   // 2+: gemmk1
         make_tuple(
             Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},    // 0-: gemmk0
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},    // 1-: gemmm
-            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
+            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-:
+                                                                                 // gemmk1

     constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
         make_tuple(
             Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},   // 0+: gemmk0
@@ -215,7 +284,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
             Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
             Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
             Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
-    //clang-format on
+    // clang-format on

     constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
         Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
@@ -225,64 +294,110 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
     for(index_t i = 0; i < 5; ++i)
     {
-        float ave_time = driver_gemm_xdlops_v2r3<
-            BlockSize,
-            TInWei,
-            TAcc,
-            TOut,
-            InMemoryDataOperationEnum_t::Set,
-            decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
-            decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
-            decltype(in_gemmm_gemmn_grid_desc),
-            GemmMPerBlock,
-            GemmNPerBlock,
-            GemmKPerBlock,
-            GemmMPerWave,
-            GemmNPerWave,
-            GemmK1,
-            MRepeat,
-            NRepeat,
-            GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
-            GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
-            Sequence<1, 0, 2>,
-            Sequence<1, 0, 2>,
-            2,
-            GemmABlockTransferSrcScalarPerVector_GemmK1,
-            GemmABlockTransferDstScalarPerVector_GemmK1,
-            false, // don't move back src coordinate after threadwise copy
-            GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
-            GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
-            Sequence<2, 0, 1>,
-            Sequence<0, 2, 1>,
-            1,
-            GemmBBlockTransferSrcScalarPerVector_GemmN,
-            GemmBBlockTransferDstScalarPerVector_GemmK1,
-            false, // don't move back src coordinate after threadwise copy
+        const auto ConvStrideH = conv_strides[I0];
+        const auto ConvStrideW = conv_strides[I1];
+
+        const auto ConvDilationH = conv_dilations[I0];
+        const auto ConvDilationW = conv_dilations[I1];
+
+        const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
+        const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
+
+        const auto YTilda = ConvStrideH / GcdStrideDilationH;
+        const auto XTilda = ConvStrideW / GcdStrideDilationW;
+
+        float ave_time = 0;
+
+        for(index_t i_ytilda = 0; i_ytilda < YTilda; ++i_ytilda)
+        {
+            for(index_t i_xtilda = 0; i_xtilda < XTilda; ++i_xtilda)
+            {
+                const auto descs =
+                    transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
+                        out_n_ho_wo_k_desc,
+                        wei_k_y_x_c_desc,
+                        in_n_hi_wi_c_desc,
+                        conv_strides,
+                        conv_dilations,
+                        in_left_pads,
+                        in_right_pads,
+                        i_ytilda,
+                        i_xtilda,
+                        Number<GemmK1>{});
+
+                const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
+                const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
+                const auto in_gemmm_gemmn_grid_desc          = descs[I2];
+
+                const auto GemmK0 = out_gemmk0_gemmm_gemmk1_grid_desc.GetLength(I0);
+
+                if(GemmK0 != 0)
+                {
+                    ave_time += driver_gemm_xdlops_v2r3<
+                        BlockSize,
+                        TInWei,
+                        TAcc,
+                        TOut,
+                        InMemoryDataOperationEnum_t::Set,
+                        decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
+                        decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
+                        decltype(in_gemmm_gemmn_grid_desc),
+                        GemmMPerBlock,
+                        GemmNPerBlock,
+                        GemmKPerBlock,
+                        GemmMPerWave,
+                        GemmNPerWave,
+                        GemmK1,
+                        MRepeat,
+                        NRepeat,
+                        GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
+                        GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
+                        Sequence<1, 0, 2>,
+                        Sequence<1, 0, 2>,
+                        2,
+                        GemmABlockTransferSrcScalarPerVector_GemmK1,
+                        GemmABlockTransferDstScalarPerVector_GemmK1,
+                        false, // don't move back src coordinate after threadwise copy
+                        GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
+                        GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
+                        Sequence<2, 0, 1>,
+                        Sequence<0, 2, 1>,
+                        1,
+                        GemmBBlockTransferSrcScalarPerVector_GemmN,
+                        GemmBBlockTransferDstScalarPerVector_GemmK1,
+                        false, // don't move back src coordinate after threadwise copy
 #if 0
                         Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
                         Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
 #else
                         Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                         Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
 #endif
-            7,
-            GemmCThreadTransferDstScalarPerVector,
-            decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
-            decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
-            decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
-            decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
-            decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
-            true // CAccessOrderMRepeatNRepeat
-            >(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
-              static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
-              static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
-              out_gemmk0_gemmm_gemmk1_grid_desc,
-              wei_gemmk0_gemmn_gemmk1_grid_desc,
-              in_gemmm_gemmn_grid_desc,
-              out_gemmk0_gemmm_gemmk1_grid_step_hacks,
-              wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
-              in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
-              out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
-              wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
-              nrepeat);
+                        7,
+                        GemmCThreadTransferDstScalarPerVector,
+                        decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
+                        decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
+                        decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
+                        decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
+                        decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
+                        true,  // CAccessOrderMRepeatNRepeat
+                        false, // ABlockLdsExtraM
+                        false  // BBlockLdsExtraN
+                        >(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
+                          static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
+                          static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
+                          out_gemmk0_gemmm_gemmk1_grid_desc,
+                          wei_gemmk0_gemmn_gemmk1_grid_desc,
+                          in_gemmm_gemmn_grid_desc,
+                          debug::debug_driver_gemm_xdlops_v2r3::M01,
+                          debug::debug_driver_gemm_xdlops_v2r3::N01,
+                          out_gemmk0_gemmm_gemmk1_grid_step_hacks,
+                          wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
+                          in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
+                          out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
+                          wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
+                          nrepeat);
+                }
+            }
+        }

         {
             const auto N = out_n_ho_wo_k_lengths[I0];
host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp (new file, mode 100644)
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template <typename TInWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1(
    const InLengths& in_n_hi_wi_c_lengths,
    const WeiLengths& wei_k_y_x_c_lengths,
    const OutLengths& out_n_ho_wo_k_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations&,
    const InLeftPads&,
    const InRightPads&,
    Tensor<TInWei>& in_n_hi_wi_c,
    const Tensor<TInWei>& wei_k_y_x_c,
    const Tensor<TOut>& out_n_ho_wo_k,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
    DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
    DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());

    in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
    wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
    out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());

    const auto in_n_hi_wi_c_desc  = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
    const auto wei_k_y_x_c_desc   = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
    const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
#if 0
    // [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32
    constexpr index_t BlockSize     = 256;
    constexpr index_t GemmMPerBlock = 256;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave  = 32;
    constexpr index_t GemmNPerWave  = 32;
    constexpr index_t GemmK1        = 4;
    constexpr index_t MRepeat       = 4;
    constexpr index_t NRepeat       = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 4, 4>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 2, 4>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN  = 2;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmCThreadTransferDstScalarPerVector       = 1;
#elif 0
    // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
    constexpr index_t BlockSize     = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave  = 32;
    constexpr index_t GemmNPerWave  = 32;
    constexpr index_t GemmK1        = 4;
    constexpr index_t MRepeat       = 2;
    constexpr index_t NRepeat       = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 2, 4>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 2, 4>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN  = 2;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmCThreadTransferDstScalarPerVector       = 1;
#elif 1
    // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
    constexpr index_t BlockSize     = 256;
    constexpr index_t GemmMPerBlock = 256;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave  = 32;
    constexpr index_t GemmNPerWave  = 32;
    constexpr index_t GemmK1        = 8;
    constexpr index_t MRepeat       = 4;
    constexpr index_t NRepeat       = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 4, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN  = 2;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmCThreadTransferDstScalarPerVector       = 1;
#elif 0
    // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
    constexpr index_t BlockSize     = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 256;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave  = 32;
    constexpr index_t GemmNPerWave  = 32;
    constexpr index_t GemmK1        = 8;
    constexpr index_t MRepeat       = 2;
    constexpr index_t NRepeat       = 4;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 2, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 4, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN  = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmCThreadTransferDstScalarPerVector       = 1;
#elif 0
    // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
    constexpr index_t BlockSize     = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave  = 32;
    constexpr index_t GemmNPerWave  = 32;
    constexpr index_t GemmK1        = 8;
    constexpr index_t MRepeat       = 2;
    constexpr index_t NRepeat       = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 2, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN  = 2;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmCThreadTransferDstScalarPerVector       = 1;
#elif 0
    // [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16
    constexpr index_t BlockSize     = 128;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 64;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave  = 32;
    constexpr index_t GemmNPerWave  = 32;
    constexpr index_t GemmK1        = 8;
    constexpr index_t MRepeat       = 2;
    constexpr index_t NRepeat       = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 4, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN  = 2;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmCThreadTransferDstScalarPerVector       = 1;
#elif 0
    // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
    constexpr index_t BlockSize     = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 64;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave  = 32;
    constexpr index_t GemmNPerWave  = 32;
    constexpr index_t GemmK1        = 8;
    constexpr index_t MRepeat       = 2;
    constexpr index_t NRepeat       = 1;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 2, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 1, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN  = 1;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmCThreadTransferDstScalarPerVector       = 1;
#endif
    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0>{},   // 0+: gemmk0
                              Sequence<0, 0, 0>{},   // 1+: gemmm
                              Sequence<0, 0, 0>{}),  // 2+: gemmk1
                   make_tuple(Sequence<0, 0, 0>{},   // 0-: gemmk0
                              Sequence<0, 0, 0>{},   // 1-: gemmm
                              Sequence<0, 0, 0>{})); // 2-: gemmk1

    constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0>{},   // 0+: gemmk0
                              Sequence<0, 0, 0>{},   // 1+: gemmn
                              Sequence<0, 0, 0>{}),  // 2+: gemmk1
                   make_tuple(Sequence<0, 0, 0>{},   // 0-: Gemmk0
                              Sequence<0, 0, 0>{},   // 1-: Gemmn
                              Sequence<0, 0, 0>{})); // 2-: Gemmk1

    // clang-format off
    constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
        make_tuple(
            Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
            Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
            Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
            Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
            Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
        make_tuple(
            Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
            Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
            Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
            Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
            Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
    // clang-format on

    constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
    constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};

    for(index_t i = 0; i < 5; ++i)
    {
        const auto descs = transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1(
            out_n_ho_wo_k_desc,
            wei_k_y_x_c_desc,
            in_n_hi_wi_c_desc,
            conv_strides,
            Number<GemmK1>{});
        const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
        const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
        const auto in_gemmm_gemmn_grid_desc          = descs[I2];

        float ave_time = driver_gemm_xdlops_v2r3<
            BlockSize,
            TInWei,
            TAcc,
            TOut,
            InMemoryDataOperationEnum_t::Set,
            decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
            decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
            decltype(in_gemmm_gemmn_grid_desc),
            GemmMPerBlock,
            GemmNPerBlock,
            GemmKPerBlock,
            GemmMPerWave,
            GemmNPerWave,
            GemmK1,
            MRepeat,
            NRepeat,
            GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
            GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
            Sequence<1, 0, 2>,
            Sequence<1, 0, 2>,
            2,
            GemmABlockTransferSrcScalarPerVector_GemmK1,
            GemmABlockTransferDstScalarPerVector_GemmK1,
            false, // don't move back src coordinate after threadwise copy
            GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
            GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
            Sequence<2, 0, 1>,
            Sequence<0, 2, 1>,
            1,
            GemmBBlockTransferSrcScalarPerVector_GemmN,
            GemmBBlockTransferDstScalarPerVector_GemmK1,
            false, // don't move back src coordinate after threadwise copy
#if 0
            Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
#else
            Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
#endif
            7,
            GemmCThreadTransferDstScalarPerVector,
            decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
            decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
            decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
            decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
            decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
            true,  // CAccessOrderMRepeatNRepeat
            false, // ABlockLdsExtraM
            false  // BBlockLdsExtraN
            >(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
              static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
              static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
              out_gemmk0_gemmm_gemmk1_grid_desc,
              wei_gemmk0_gemmn_gemmk1_grid_desc,
              in_gemmm_gemmn_grid_desc,
              debug::debug_driver_gemm_xdlops_v2r3::M01,
              debug::debug_driver_gemm_xdlops_v2r3::N01,
              out_gemmk0_gemmm_gemmk1_grid_step_hacks,
              wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
              in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
              out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
              wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
              nrepeat);
        {
            const auto N  = out_n_ho_wo_k_lengths[I0];
            const auto K  = out_n_ho_wo_k_lengths[I3];
            const auto C  = wei_k_y_x_c_lengths[I3];
            const auto Ho = out_n_ho_wo_k_lengths[I1];
            const auto Wo = out_n_ho_wo_k_lengths[I2];
            const auto Y  = wei_k_y_x_c_lengths[I1];
            const auto X  = wei_k_y_x_c_lengths[I2];

            float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time;

            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                      << std::endl;
        }
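        // Added note: the quotient is indeed in TFlop/s because ave_time is in ms:
        // flop / 1e9 / ms = flop / (1e12 * s).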
}
    // copy result back to host
    in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data());
}
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp (new file, mode 100644)
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp"
#include "driver_gemm_xdlops_v2r4.hpp"
template <typename TIn,
          typename TWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          typename GridSizeType>
void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw(
    const InLengths& in_n_c_hi_wi_lengths,
    const WeiLengths& wei_k_c_y_x_lengths,
    const OutLengths& out_n_k_ho_wo_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TIn>& in_n_c_hi_wi,
    Tensor<TWei>& wei_k_c_y_x,
    const Tensor<TOut>& out_n_k_ho_wo,
    GridSizeType desired_grid_size,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    DeviceMem in_n_c_hi_wi_device_buf(sizeof(TIn) * in_n_c_hi_wi.mDesc.GetElementSpace());
    DeviceMem wei_k_c_y_x_device_buf(sizeof(TWei) * wei_k_c_y_x.mDesc.GetElementSpace());
    DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());

    in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
    wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
    out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());

    const auto in_n_c_hi_wi_desc  = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
    const auto wei_k_c_y_x_desc   = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
    const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
#if 1
    // [M, N, K0, K1] = [128, 128, 4, 8] for fp32
    constexpr index_t BlockSize     = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave  = 32;
    constexpr index_t GemmNPerWave  = 32;
    constexpr index_t GemmK1        = 8;
    constexpr index_t MRepeat       = 2;
    constexpr index_t NRepeat       = 2;
    using GemmABlockTransferThreadSliceLengths_GemmB_GemmK0_GemmM_GemmK1   = Sequence<1, 1, 2, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmB_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 64, 1>;
    // using vector load 4, so config's wo*ho must be a multiple of 4
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
    using GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1   = Sequence<1, 1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 64, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN  = 1;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmCThreadTransferDstScalarPerVector       = 1;
#endif
    const auto N  = in_n_c_hi_wi_desc.GetLength(I0);
    const auto C  = in_n_c_hi_wi_desc.GetLength(I1);
    const auto K  = out_n_k_ho_wo_desc.GetLength(I1);
    const auto Ho = out_n_k_ho_wo_desc.GetLength(I2);
    const auto Wo = out_n_k_ho_wo_desc.GetLength(I3);
    const auto Y  = wei_k_c_y_x_desc.GetLength(I2);
    const auto X  = wei_k_c_y_x_desc.GetLength(I3);

    const auto GemmM      = K;
    const auto GemmN      = Y * X * C;
    const auto GemmKTotal = N * Ho * Wo;
    const auto GemmK      = GemmKTotal / GemmK1;

    const auto GridMN        = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);
    const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1);
    const index_t BatchLen   = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch));
    const index_t GemmK0     = BatchLen * GemmKPerBlock;
    const index_t GemmKPad   = GemmKBatch * GemmK0 * GemmK1;

    std::cout << "GemmKTotal: " << GemmKTotal << " GridSizeMN: " << GridMN
              << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0
              << " GemmKPad: " << GemmKPad << std::endl;
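    // Added worked example (hypothetical sizes, not from this commit):
    // with N = 128, Ho = Wo = 14 and GemmK1 = 8, GemmKTotal = 128 * 14 * 14 = 25088
    // and GemmK = 25088 / 8 = 3136. If GridMN = 4 and desired_grid_size = 256:
    //   GemmKBatch = max(256 / 4, 1)       = 64
    //   BatchLen   = ceil(3136 / (4 * 64)) = 13
    //   GemmK0     = 13 * 4                = 52
    //   GemmKPad   = 64 * 52 * 8           = 26624   (>= GemmKTotal, i.e. padded)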
    const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad(
        wei_k_c_y_x_desc,
        in_n_c_hi_wi_desc,
        out_n_k_ho_wo_desc,
        conv_strides,
        conv_dilations,
        in_left_pads,
        in_right_pads,
        Number<GemmK1>{},
        GemmKBatch,
        GemmKPad);

    const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
    const auto in_gemmk0_gemmn_gemmk1_grid_desc  = descs[I1];
    const auto wei_gemmm_gemmn_grid_desc         = descs[I2];
    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0, 0, 0>{},   // 0+: GemmB
                              Sequence<0, 0, 1, 0, 0, 0, 0>{},   // 1+: GemmK0
                              Sequence<0, 0, 0, 0, 0, 0, 0>{},   // 2+: GemmM
                              Sequence<0, 0, 1, 0, 0, 0, 0>{}),  // 3+: GemmK1
                   make_tuple(Sequence<0, 0, 2, 0, 0, 0, 0>{},   // 0-: GemB
                              Sequence<0, 0, 2, 0, 0, 0, 0>{},   // 1-: GemmK0
                              Sequence<0, 0, 0, 0, 0, 0, 0>{},   // 2-: GemmM
                              Sequence<0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1

    constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},   // 0+: GemmB
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},   // 1+: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},   // 2+: GemmN
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}),  // 3+: GemmK1
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},   // 0-: GemmB
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},   // 1-: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},   // 2-: GemmN
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1

    constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

    constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 1, 0, 0, 0, 0>{};

    constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{};
    const auto driver_gemm_xdlops =
        driver_gemm_xdlops_v2r4<BlockSize,
                                TIn,
                                TAcc,
                                TWei,
                                InMemoryDataOperationEnum_t::AtomicAdd,
                                decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
                                decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
                                decltype(wei_gemmm_gemmn_grid_desc),
                                GemmMPerBlock,
                                GemmNPerBlock,
                                GemmKPerBlock,
                                GemmMPerWave,
                                GemmNPerWave,
                                GemmK1,
                                MRepeat,
                                NRepeat,
                                GemmABlockTransferThreadSliceLengths_GemmB_GemmK0_GemmM_GemmK1,
                                GemmABlockTransferThreadClusterLengths_GemmB_GemmK0_GemmM_GemmK1,
                                Sequence<0, 2, 1, 3>,
                                Sequence<0, 2, 1, 3>,
                                3,
                                GemmABlockTransferSrcScalarPerVector_GemmK1,
                                GemmABlockTransferDstScalarPerVector_GemmK1,
                                false, // don't move back src coordinate after threadwise copy
                                GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1,
                                GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1,
                                Sequence<0, 2, 1, 3>,
                                Sequence<0, 2, 1, 3>,
                                3,
                                GemmBBlockTransferSrcScalarPerVector_GemmN,
                                GemmBBlockTransferDstScalarPerVector_GemmK1,
                                false, // don't move back src coordinate after threadwise copy
                                Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
                                7,
                                GemmCThreadTransferDstScalarPerVector,
                                decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
                                decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
                                decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
                                decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
                                decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
                                false,
                                true,
                                true>;
    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time = driver_gemm_xdlops(
            static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
            static_cast<TIn*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
            static_cast<TWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
            out_gemmk0_gemmm_gemmk1_grid_desc,
            in_gemmk0_gemmn_gemmk1_grid_desc,
            wei_gemmm_gemmn_grid_desc,
            debug::debug_driver_gemm_xdlops_v2r3::M01,
            debug::debug_driver_gemm_xdlops_v2r3::N01,
            out_gemmk0_gemmm_gemmk1_grid_step_hacks,
            in_gemmk0_gemmn_gemmk1_grid_step_hacks,
            wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
            out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
            in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
            nrepeat);

        float perf = static_cast<float>(calculate_convolution_flops(
                         in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                  << std::endl;
    }
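The perf computation above turns a flop count into TFlop/s: dividing by 1000^3 gives GFlop, and GFlop per millisecond is numerically equal to TFlop per second. A minimal standalone sketch of that arithmetic, with hypothetical numbers (not taken from this file):

#include <cstddef>
#include <iostream>

int main()
{
    // hypothetical: one convolution's flop count and a 10 ms average kernel time
    const std::size_t flops = std::size_t(2) * 128 * 256 * 28 * 28 * 256 * 3 * 3;
    const float ave_time_ms = 10.0f;

    // flops / 1e9 = GFlop; GFlop per ms = TFlop/s
    const float tflops =
        static_cast<float>(flops) / (std::size_t(1000) * 1000 * 1000) / ave_time_ms;

    std::cout << tflops << " TFlop/s" << std::endl; // ~11.8 TFlop/s for these numbers
}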
    wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());

    driver_gemm_xdlops(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
                       static_cast<TIn*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
                       static_cast<TWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
                       out_gemmk0_gemmm_gemmk1_grid_desc,
                       in_gemmk0_gemmn_gemmk1_grid_desc,
                       wei_gemmm_gemmn_grid_desc,
                       debug::debug_driver_gemm_xdlops_v2r3::M01,
                       debug::debug_driver_gemm_xdlops_v2r3::N01,
                       out_gemmk0_gemmm_gemmk1_grid_step_hacks,
                       in_gemmk0_gemmn_gemmk1_grid_step_hacks,
                       wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                       out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
                       in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
                       0);

    // copy result back to host
    wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data());
}
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
...
@@ -4,7 +4,8 @@
 #include "transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
 #include "driver_gemm_xdlops_v2r3.hpp"
 
-template <typename TInWei,
+template <typename TIn,
+          typename TWei,
           typename TAcc,
           typename TOut,
           typename InLengths,
...
@@ -22,8 +23,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
     const ConvDilations& conv_dilations,
     const InLeftPads& in_left_pads,
     const InRightPads& in_right_pads,
-    const Tensor<TInWei>& in_n_c_hi_wi,
-    Tensor<TInWei>& wei_k_c_y_x,
+    const Tensor<TIn>& in_n_c_hi_wi,
+    Tensor<TWei>& wei_k_c_y_x,
     const Tensor<TOut>& out_n_k_ho_wo,
     ck::index_t nrepeat)
 {
...
@@ -35,8 +36,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
     constexpr auto I1 = Number<1>{};
     constexpr auto I2 = Number<2>{};
 
-    DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
-    DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
+    DeviceMem in_n_c_hi_wi_device_buf(sizeof(TIn) * in_n_c_hi_wi.mDesc.GetElementSpace());
+    DeviceMem wei_k_c_y_x_device_buf(sizeof(TWei) * wei_k_c_y_x.mDesc.GetElementSpace());
     DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
 
     in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
...
@@ -47,7 +48,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
     const auto wei_k_c_y_x_desc   = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
     const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
 
-#if 1
+#if 0
     // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
     constexpr index_t BlockSize = 256;
...
@@ -164,9 +165,9 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
     {
         float ave_time = driver_gemm_xdlops_v2r3<BlockSize,
-                                                 TInWei,
+                                                 TIn,
                                                  TAcc,
-                                                 TOut,
+                                                 TWei,
                                                  InMemoryDataOperationEnum_t::Set,
                                                  decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
                                                  decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
...
@@ -203,18 +204,23 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
             decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
             decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
             decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
-            false>(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
-                   static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
-                   static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
-                   out_gemmk0_gemmm_gemmk1_grid_desc,
-                   in_gemmk0_gemmn_gemmk1_grid_desc,
-                   wei_gemmm_gemmn_grid_desc,
-                   out_gemmk0_gemmm_gemmk1_grid_step_hacks,
-                   in_gemmk0_gemmn_gemmk1_grid_step_hacks,
-                   wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
-                   out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
-                   in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
-                   nrepeat);
+            false, // CAccessOrderMRepeatNRepeat
+            true,  // ABlockLdsExtraM
+            true   // BBlockLdsExtraN
+            >(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
+              static_cast<TIn*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
+              static_cast<TWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
+              out_gemmk0_gemmm_gemmk1_grid_desc,
+              in_gemmk0_gemmn_gemmk1_grid_desc,
+              wei_gemmm_gemmn_grid_desc,
+              debug::debug_driver_gemm_xdlops_v2r3::M01,
+              debug::debug_driver_gemm_xdlops_v2r3::N01,
+              out_gemmk0_gemmm_gemmk1_grid_step_hacks,
+              in_gemmk0_gemmn_gemmk1_grid_step_hacks,
+              wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
+              out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
+              in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
+              nrepeat);
 
         float perf = static_cast<float>(calculate_convolution_flops(
                          in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
...
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp
0 → 100644
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r4.hpp"
template <typename TIn,
          typename TWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          typename GridSizeType>
void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk(
    const InLengths& in_n_hi_wi_c_lengths,
    const WeiLengths& wei_k_y_x_c_lengths,
    const OutLengths& out_n_ho_wo_k_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TIn>& in_n_hi_wi_c,
    Tensor<TWei>& wei_k_y_x_c,
    const Tensor<TOut>& out_n_ho_wo_k,
    GridSizeType desired_grid_size,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace());
    DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace());
    DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());

    in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
    wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
    out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());

    const auto in_n_hi_wi_c_desc  = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
    const auto wei_k_y_x_c_desc   = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
    const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
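make_naive_tensor_descriptor_packed builds a descriptor for the contiguous layout: each dimension's stride is the product of the lengths to its right. For the NHWC input above, the linear offset of (n, hi, wi, c) is therefore n*Hi*Wi*C + hi*Wi*C + wi*C + c, as in this small sketch with hypothetical sizes (not code from this commit):

#include <cstddef>
#include <iostream>

int main()
{
    const std::size_t N = 2, Hi = 4, Wi = 4, C = 8; // hypothetical packed NHWC lengths
    const std::size_t n = 1, hi = 2, wi = 3, c = 5;

    // strides of a packed layout: (Hi*Wi*C, Wi*C, C, 1)
    const std::size_t offset = ((n * Hi + hi) * Wi + wi) * C + c;

    std::cout << "linear offset: " << offset << std::endl; // 1*128 + 2*32 + 3*8 + 5 = 221
}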
#if 0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1 = 4;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
    const auto N = in_n_hi_wi_c_desc.GetLength(I0);
    const auto C = in_n_hi_wi_c_desc.GetLength(I3);

    const auto K  = out_n_ho_wo_k_desc.GetLength(I3);
    const auto Ho = out_n_ho_wo_k_desc.GetLength(I1);
    const auto Wo = out_n_ho_wo_k_desc.GetLength(I2);

    const auto Y = wei_k_y_x_c_desc.GetLength(I1);
    const auto X = wei_k_y_x_c_desc.GetLength(I2);

    const auto GemmM      = Y * X * C;
    const auto GemmN      = K;
    const auto GemmKTotal = N * Ho * Wo;

    const auto GemmK  = GemmKTotal / GemmK1;
    const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);

    const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1);
    const index_t BatchLen   = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch));
    const index_t GemmK0     = BatchLen * GemmKPerBlock;
    const index_t GemmKPad   = GemmKBatch * GemmK0 * GemmK1;

    std::cout << "GemmKTotal: " << GemmKTotal << " GridSizeMN: " << GridMN
              << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0
              << " GemmKPad: " << GemmKPad << std::endl;
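The block above splits the long reduction dimension into GemmKBatch chunks so that the launch reaches roughly desired_grid_size workgroups; each chunk's partial product is combined into C through the AtomicAdd path, and GemmKPad rounds the reduction length up to GemmKBatch * GemmK0 * GemmK1. A worked example of the same arithmetic with hypothetical sizes:

#include <algorithm>
#include <cmath>
#include <iostream>

int main()
{
    // hypothetical problem: N = 128, Ho = Wo = 14, GemmK1 = 4,
    // GemmKPerBlock = 4, GridMN = 16, desired_grid_size = 256
    const int GemmKTotal    = 128 * 14 * 14;       // 25088
    const int GemmK1        = 4;
    const int GemmK         = GemmKTotal / GemmK1; // 6272
    const int GridMN        = 16;
    const int GemmKPerBlock = 4;

    const int GemmKBatch = std::max(256 / GridMN, 1); // 16 reduction chunks
    const int BatchLen =
        static_cast<int>(std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch))); // 98
    const int GemmK0   = BatchLen * GemmKPerBlock;     // 392
    const int GemmKPad = GemmKBatch * GemmK0 * GemmK1; // 25088, no padding needed here

    std::cout << GemmKBatch << ", " << GemmK0 << ", " << GemmKPad << std::endl;
}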
    const auto descs =
        transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad(
            in_n_hi_wi_c_desc,
            wei_k_y_x_c_desc,
            out_n_ho_wo_k_desc,
            conv_strides,
            conv_dilations,
            in_left_pads,
            in_right_pads,
            Number<GemmK1>{},
            GemmKBatch,
            GemmKPad);

    const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc  = descs[I0];
    const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
    const auto wei_gemmm_gemmn_grid_desc                    = descs[I2];

    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},   // 0+: GemmKBatch
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},   // 1+: GemmK0
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},   // 2+: GemmM
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}),  // 3+: GemmK1
        make_tuple(
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},   // 0-: GemmKBatch
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},   // 1-: GemmK0
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},   // 2-: GemmM
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1

    constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0>{},   // 0+: GemmKBatch
                   Sequence<0, 0, 0, 0, 0>{},   // 1+: GemmK0
                   Sequence<0, 0, 0, 0, 0>{},   // 2+: GemmN
                   Sequence<0, 0, 0, 0, 0>{}),  // 3+: GemmK1
        make_tuple(Sequence<0, 0, 0, 0, 0>{},   // 0-: GemmKBatch
                   Sequence<0, 0, 0, 0, 0>{},   // 1-: GemmK0
                   Sequence<0, 0, 0, 0, 0>{},   // 2-: GemmN
                   Sequence<0, 0, 0, 0, 0>{})); // 3-: GemmK1
    constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

    constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{};

    constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0>{};

    const auto driver_gemm_xdlops =
        driver_gemm_xdlops_v2r4<BlockSize,
                                TIn,
                                TAcc,
                                TWei,
                                InMemoryDataOperationEnum_t::AtomicAdd,
                                decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc),
                                decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc),
                                decltype(wei_gemmm_gemmn_grid_desc),
                                GemmMPerBlock,
                                GemmNPerBlock,
                                GemmKPerBlock,
                                GemmMPerXDL,
                                GemmNPerXDL,
                                GemmK1,
                                MRepeat,
                                NRepeat,
                                GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
                                GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
                                Sequence<0, 1, 2, 3>,
                                Sequence<0, 1, 2, 3>,
                                2,
                                GemmABlockTransferSrcScalarPerVector_GemmM,
                                GemmABlockTransferDstScalarPerVector_GemmK1,
                                false, // don't move back src coordinate after threadwise copy
                                GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
                                GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
                                Sequence<0, 1, 2, 3>,
                                Sequence<0, 1, 2, 3>,
                                2,
                                GemmBBlockTransferSrcScalarPerVector_GemmN,
                                GemmBBlockTransferDstScalarPerVector_GemmK1,
                                false, // don't move back src coordinate after threadwise copy
                                Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
                                6,
                                GemmCThreadTransferDstScalarPerVector,
                                decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks),
                                decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks),
                                decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
                                decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
                                decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
                                false, // CAccessOrderMRepeatNRepeat
                                true,
                                true>;
    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time = driver_gemm_xdlops(
            static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
            static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
            static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
            in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
            out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
            wei_gemmm_gemmn_grid_desc,
            debug::debug_driver_gemm_xdlops_v2r3::M01,
            debug::debug_driver_gemm_xdlops_v2r3::N01,
            in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
            out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
            wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
            in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
            out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
            nrepeat);

        {
            float perf = static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time;

            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                      << std::endl;
        }
    }
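The flop count in the loop above comes straight from the convolution definition: each of the N * K * Ho * Wo output elements is an inner product over C * Y * X input/weight pairs, and each pair costs one multiply plus one add. A short sketch with hypothetical sizes:

#include <cstddef>
#include <iostream>

int main()
{
    const std::size_t N = 128, K = 256, Ho = 14, Wo = 14; // hypothetical output shape
    const std::size_t C = 192, Y = 3, X = 3;              // hypothetical filter shape

    const std::size_t macs  = N * K * Ho * Wo * C * Y * X; // multiply-accumulates
    const std::size_t flops = 2 * macs;                    // 2 flops per MAC

    std::cout << flops << " flops" << std::endl;
}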
    wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());

    driver_gemm_xdlops(static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
                       static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
                       static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
                       in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                       out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                       wei_gemmm_gemmn_grid_desc,
                       debug::debug_driver_gemm_xdlops_v2r3::M01,
                       debug::debug_driver_gemm_xdlops_v2r3::N01,
                       in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
                       out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
                       wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                       in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
                       out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
                       0);

    // copy result back to host
    wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data());
}
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp
0 → 100644
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
#include "debug.hpp"
template <typename TIn,
          typename TWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
    const InLengths& in_n_hi_wi_c_lengths,
    const WeiLengths& wei_k_y_x_c_lengths,
    const OutLengths& out_n_ho_wo_k_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TIn>& in_n_hi_wi_c,
    Tensor<TWei>& wei_k_y_x_c,
    const Tensor<TOut>& out_n_ho_wo_k,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace());
    DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace());
    DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());

    in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
    wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
    out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());

    const auto in_n_hi_wi_c_desc  = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
    const auto wei_k_y_x_c_desc   = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
    const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1 = 4;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 2>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 2>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 2>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 2>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
    const auto descs =
        transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
            in_n_hi_wi_c_desc,
            wei_k_y_x_c_desc,
            out_n_ho_wo_k_desc,
            conv_strides,
            conv_dilations,
            in_left_pads,
            in_right_pads,
            Number<GemmK1>{});

    const auto in_gemmk0_gemmm_gemmk1_grid_desc  = descs[I0];
    const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
    const auto wei_gemmm_gemmn_grid_desc         = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},   // 0+: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},   // 1+: GemmM
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),  // 2+: GemmK1
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},   // 0-: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},   // 1-: GemmM
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1

    constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0>{},   // 0+: GemmK0
                   Sequence<0, 0, 0, 0, 0>{},   // 1+: GemmN
                   Sequence<0, 0, 0, 0, 0>{}),  // 2+: GemmK1
        make_tuple(Sequence<0, 0, 0, 0, 0>{},   // 0-: GemmK0
                   Sequence<0, 0, 0, 0, 0>{},   // 1-: GemmN
                   Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1

    constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

    constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{};

    constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0>{};
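Reading the pattern of these bundles (their semantics are not documented in this commit, so treat this as an interpretation rather than a specification): the outer make_tuple pairs forward (+) and backward (-) iteration, each inner Sequence belongs to one dimension of the matching grid descriptor, its length tracks that descriptor's hidden transforms, and nonzero entries mark transforms whose index update may take a specialized path. An illustrative, made-up bundle for a hypothetical two-dimensional descriptor with three hidden transforms:

// illustrative only; the values carry no meaning outside this sketch
constexpr auto example_grid_step_hacks =
    ck::make_tuple(ck::make_tuple(ck::Sequence<0, 0, 1>{},   // dim 0, + step
                                  ck::Sequence<0, 1, 0>{}),  // dim 1, + step
                   ck::make_tuple(ck::Sequence<0, 0, 2>{},   // dim 0, - step
                                  ck::Sequence<0, 2, 0>{})); // dim 1, - step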
    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time = driver_gemm_xdlops_v2r3<
            BlockSize,
            TIn,
            TAcc,
            TWei,
            InMemoryDataOperationEnum_t::Set,
            decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
            decltype(out_gemmk0_gemmn_gemmk1_grid_desc),
            decltype(wei_gemmm_gemmn_grid_desc),
            GemmMPerBlock,
            GemmNPerBlock,
            GemmKPerBlock,
            GemmMPerXDL,
            GemmNPerXDL,
            GemmK1,
            MRepeat,
            NRepeat,
            GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
            GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
            Sequence<0, 2, 1>,
            Sequence<0, 2, 1>,
            1,
            GemmABlockTransferSrcScalarPerVector_GemmM,
            GemmABlockTransferDstScalarPerVector_GemmK1,
            false, // don't move back src coordinate after threadwise copy
            GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
            GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
            Sequence<0, 2, 1>,
            Sequence<0, 2, 1>,
            1,
            GemmBBlockTransferSrcScalarPerVector_GemmN,
            GemmBBlockTransferDstScalarPerVector_GemmK1,
            false, // don't move back src coordinate after threadwise copy
            Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
            7,
            GemmCThreadTransferDstScalarPerVector,
            decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks),
            decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks),
            decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
            decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
            decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
            false, // CAccessOrderMRepeatNRepeat
            true,
            true>(static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
                  static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
                  static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
                  in_gemmk0_gemmm_gemmk1_grid_desc,
                  out_gemmk0_gemmn_gemmk1_grid_desc,
                  wei_gemmm_gemmn_grid_desc,
                  debug::debug_driver_gemm_xdlops_v2r3::M01,
                  debug::debug_driver_gemm_xdlops_v2r3::N01,
                  in_gemmk0_gemmm_gemmk1_grid_step_hacks,
                  out_gemmk0_gemmn_gemmk1_grid_step_hacks,
                  wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                  in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
                  out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
                  nrepeat);

        {
            const auto N  = out_n_ho_wo_k_lengths[I0];
            const auto K  = out_n_ho_wo_k_lengths[I3];
            const auto C  = wei_k_y_x_c_lengths[I3];
            const auto Ho = out_n_ho_wo_k_lengths[I1];
            const auto Wo = out_n_ho_wo_k_lengths[I2];
            const auto Y  = wei_k_y_x_c_lengths[I1];
            const auto X  = wei_k_y_x_c_lengths[I2];

            float perf = static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time;

            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                      << std::endl;
        }
    }
// copy result back to host
    wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data());
}
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp
0 → 100644
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r4.hpp"
template <typename TIn,
          typename TWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          typename GridSizeType>
void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk(
    const InLengths& in_n_hi_wi_c_lengths,
    const WeiLengths& wei_k_y_x_c_lengths,
    const OutLengths& out_n_ho_wo_k_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TIn>& in_n_hi_wi_c,
    Tensor<TWei>& wei_k_y_x_c,
    const Tensor<TOut>& out_n_ho_wo_k,
    GridSizeType desired_grid_size,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace());
    DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace());
    DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());

    in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
    wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
    out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());

    const auto in_n_hi_wi_c_desc  = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
    const auto wei_k_y_x_c_desc   = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
    const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4], C 128, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 4], C 128, for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 256;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1 = 4;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 4;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 4], C 64, for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1 = 4;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8], C 128, for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 256;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1 = 8;
    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 16, 2>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 16, 4>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 16, 4>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 8], C 64, for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 16, 4>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 16, 4>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 8], C 64, for fp16
    constexpr index_t BlockSize = 128;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 64;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 16, 2>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [64, 128, 4, 8], C 64, for fp16
    constexpr index_t BlockSize = 128;
    constexpr index_t GemmMPerBlock = 64;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 16, 2>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [64, 64, 4, 8], C 32, for fp16
    constexpr index_t BlockSize = 128;
    constexpr index_t GemmMPerBlock = 64;
    constexpr index_t GemmNPerBlock = 64;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 1;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
    const auto N = in_n_hi_wi_c_desc.GetLength(I0);
    const auto C = in_n_hi_wi_c_desc.GetLength(I3);

    const auto K  = out_n_ho_wo_k_desc.GetLength(I3);
    const auto Ho = out_n_ho_wo_k_desc.GetLength(I1);
    const auto Wo = out_n_ho_wo_k_desc.GetLength(I2);

    const auto Y = wei_k_y_x_c_desc.GetLength(I1);
    const auto X = wei_k_y_x_c_desc.GetLength(I2);

    const auto GemmM      = K;
    const auto GemmN      = Y * X * C;
    const auto GemmKTotal = N * Ho * Wo;

    const auto GemmK  = GemmKTotal / GemmK1;
    const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);

    const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1);
    const index_t BatchLen   = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch));
    const index_t GemmK0     = BatchLen * GemmKPerBlock;
    const index_t GemmKPad   = GemmKBatch * GemmK0 * GemmK1;

    std::cout << "GemmKTotal: " << GemmKTotal << " GridSizeMN: " << GridMN
              << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0
              << " GemmKPad: " << GemmKPad << std::endl;
    const auto descs =
        transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad(
            in_n_hi_wi_c_desc,
            wei_k_y_x_c_desc,
            out_n_ho_wo_k_desc,
            conv_strides,
            conv_dilations,
            in_left_pads,
            in_right_pads,
            Number<GemmK1>{},
            GemmKBatch,
            GemmKPad);

    const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
    const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc  = descs[I1];
    const auto wei_gemmm_gemmn_grid_desc                    = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0>{},   // 0+: GemmKBatch
                   Sequence<0, 0, 0, 0, 0>{},   // 1+: GemmK0
                   Sequence<0, 0, 0, 0, 0>{},   // 2+: GemmM
                   Sequence<0, 0, 0, 0, 0>{}),  // 3+: GemmK1
        make_tuple(Sequence<0, 0, 0, 0, 0>{},   // 0-: GemmKBatch
                   Sequence<0, 0, 0, 0, 0>{},   // 1-: GemmK0
                   Sequence<0, 0, 0, 0, 0>{},   // 2-: GemmM
                   Sequence<0, 0, 0, 0, 0>{})); // 3-: GemmK1

    constexpr auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},   // 0+: GemmKBatch
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},   // 1+: GemmK0
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},   // 2+: GemmN
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}),  // 3+: GemmK1
        make_tuple(
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},   // 0-: GemmKBatch
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},   // 1-: GemmK0
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},   // 2-: GemmN
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1

    constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

    constexpr auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0>{};

    constexpr auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{};
    const auto driver_gemm_xdlops =
        driver_gemm_xdlops_v2r4<BlockSize,
                                TIn,
                                TAcc,
                                TWei,
                                InMemoryDataOperationEnum_t::AtomicAdd,
                                decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc),
                                decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc),
                                decltype(wei_gemmm_gemmn_grid_desc),
                                GemmMPerBlock,
                                GemmNPerBlock,
                                GemmKPerBlock,
                                GemmMPerXDL,
                                GemmNPerXDL,
                                GemmK1,
                                MRepeat,
                                NRepeat,
                                GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
                                GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
                                Sequence<0, 1, 2, 3>,
                                Sequence<0, 1, 2, 3>,
                                2,
                                GemmABlockTransferSrcScalarPerVector_GemmM,
                                GemmABlockTransferDstScalarPerVector_GemmK1,
                                false, // don't move back src coordinate after threadwise copy
                                GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
                                GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
                                Sequence<0, 1, 2, 3>,
                                Sequence<0, 1, 3, 2>,
                                2,
                                GemmBBlockTransferSrcScalarPerVector_GemmN,
                                GemmBBlockTransferDstScalarPerVector_GemmK1,
                                false, // don't move back src coordinate after threadwise copy
                                Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
                                7,
                                GemmCThreadTransferDstScalarPerVector,
                                decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks),
                                decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks),
                                decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
                                decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
                                decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
                                false, // CAccessOrderMRepeatNRepeat
                                true,
                                true>;
// timing
    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time = driver_gemm_xdlops(
            static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
            static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
            static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
            out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
            in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
            wei_gemmm_gemmn_grid_desc,
            debug::debug_driver_gemm_xdlops_v2r3::M01,
            debug::debug_driver_gemm_xdlops_v2r3::N01,
            out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
            in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
            wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
            out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
            in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
            nrepeat);

        {
            float perf = static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time;

            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                      << std::endl;
        }
    }
// verification
    wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());

    driver_gemm_xdlops(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
                       static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
                       static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
                       out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                       in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                       wei_gemmm_gemmn_grid_desc,
                       debug::debug_driver_gemm_xdlops_v2r3::M01,
                       debug::debug_driver_gemm_xdlops_v2r3::N01,
                       out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
                       in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
                       wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                       out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
                       in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
                       0);

    // copy result back to host
    wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data());
}
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp
...
@@ -49,7 +49,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
     const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
 
 #if 0
-    // [M, N, K0, K1] = [256, 128, 4, 4] for fp32
+    // [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32
     constexpr index_t BlockSize = 256;
     constexpr index_t GemmMPerBlock = 256;
...
@@ -77,7 +77,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
     constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
 #elif 0
-    // [M, N, K0, K1] = [128, 128, 4, 4] for fp32
+    // [M, N, K0, K1] = [128, 128, 4, 4], C = 128, for fp32
     constexpr index_t BlockSize = 256;
     constexpr index_t GemmMPerBlock = 128;
...
@@ -105,7 +105,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
     constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
 #elif 0
-    // [M, N, K0, K1] = [256, 256, 4, 8] for fp16
+    // [M, N, K0, K1] = [256, 256, 4, 8], C = 256, for fp16
     constexpr index_t BlockSize = 256;
     constexpr index_t GemmMPerBlock = 256;
...
@@ -133,7 +133,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
     constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
 #elif 0
-    // [M, N, K0, K1] = [256, 128, 4, 8] for fp16
+    // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
     constexpr index_t BlockSize = 256;
     constexpr index_t GemmMPerBlock = 256;
...
@@ -160,8 +160,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
     constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
-#elif 1
-    // [M, N, K0, K1] = [128, 256, 4, 8] for fp16
+#elif 0
+    // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
     constexpr index_t BlockSize = 256;
     constexpr index_t GemmMPerBlock = 128;
...
@@ -189,7 +189,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
     constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
 #elif 1
-    // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
+    // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
     constexpr index_t BlockSize = 256;
     constexpr index_t GemmMPerBlock = 128;
...
@@ -215,6 +215,62 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
     constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
     constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
+#elif 0
+    // [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16
+    constexpr index_t BlockSize = 128;
+    constexpr index_t GemmMPerBlock = 128;
+    constexpr index_t GemmNPerBlock = 64;
+    constexpr index_t GemmKPerBlock = 4;
+    constexpr index_t GemmMPerXDL = 32;
+    constexpr index_t GemmNPerXDL = 32;
+    constexpr index_t GemmK1 = 8;
+    constexpr index_t MRepeat = 2;
+    constexpr index_t NRepeat = 2;
+    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
+    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>;
+    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
+    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
+    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
+    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>;
+    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
+    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
+    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
+#elif 1
+    // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
+    constexpr index_t BlockSize = 256;
+    constexpr index_t GemmMPerBlock = 128;
+    constexpr index_t GemmNPerBlock = 64;
+    constexpr index_t GemmKPerBlock = 4;
+    constexpr index_t GemmMPerXDL = 32;
+    constexpr index_t GemmNPerXDL = 32;
+    constexpr index_t GemmK1 = 8;
+    constexpr index_t MRepeat = 2;
+    constexpr index_t NRepeat = 1;
+    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
+    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
+    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
+    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
+    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>;
+    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
+    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
+    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
+    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
 #endif
...
@@ -316,13 +372,17 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
             decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
             decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
             decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
-            false // CAccessOrderMRepeatNRepeat
+            false, // CAccessOrderMRepeatNRepeat
+            true,  // ABlockLdsExtraM
+            true   // BBlockLdsExtraN
             >(static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
               static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
               static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
               in_gemmk0_gemmm_gemmk1_grid_desc,
               wei_gemmk0_gemmn_gemmk1_grid_desc,
               out_gemmm_gemmn_grid_desc,
               debug::debug_driver_gemm_xdlops_v2r3::M01,
               debug::debug_driver_gemm_xdlops_v2r3::N01,
               in_gemmk0_gemmm_gemmk1_grid_step_hacks,
               wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
               out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
...
host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp
...
@@ -4,16 +4,8 @@
 #include "host_tensor.hpp"
 #include "driver_gemm_xdlops_v2r3.hpp"
 
-template <typename ABType,
-          typename AccType,
-          typename CType,
-          typename ADesc,
-          typename BDesc,
-          typename CDesc>
-void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc,
-                                 const BDesc& b_k_n_grid_desc,
-                                 const CDesc& c_m_n_grid_desc,
-                                 const Tensor<ABType>& a_k_m,
+template <typename ABType, typename AccType, typename CType>
+void device_gemm_xdlops_km_kn_mn(const Tensor<ABType>& a_k_m,
                                  const Tensor<ABType>& b_k_n,
                                  Tensor<CType>& c_m_n,
                                  ck::index_t nrepeat)
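With the descriptor template parameters gone, the caller passes only host tensors and a repeat count; the grid descriptors are now built inside the function. A hypothetical call site (the HostTensorDescriptor construction below is an assumption modeled on the host_tensor utilities, not code from this commit):

#include <cstddef>
#include <vector>
#include "host_tensor.hpp"

void run_km_kn_mn_example()
{
    const std::size_t K = 1024, M = 512, N = 256; // hypothetical GEMM size

    // A is K x M, B is K x N, C is M x N (all row-major host tensors)
    Tensor<float> a_k_m(HostTensorDescriptor(std::vector<std::size_t>{K, M}));
    Tensor<float> b_k_n(HostTensorDescriptor(std::vector<std::size_t>{K, N}));
    Tensor<float> c_m_n(HostTensorDescriptor(std::vector<std::size_t>{M, N}));

    // ABType = float, AccType = float, CType = float
    device_gemm_xdlops_km_kn_mn<float, float, float>(a_k_m, b_k_n, c_m_n, /*nrepeat=*/5);
}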
...
@@ -22,9 +14,6 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc,
 
     std::cout << __func__ << std::endl;
 
-    constexpr auto I0 = Number<0>{};
-    constexpr auto I1 = Number<1>{};
-
     DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace());
     DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace());
     DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace());
...
@@ -60,9 +49,121 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc,
    constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 256, 4, 4], C = 128, for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 256;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 4;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 4;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 4;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 64;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 4;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 1;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_N = 1;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 64;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 4;
    constexpr index_t MRepeat = 1;
    constexpr index_t NRepeat = 2;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 1;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
-    // [M, N, K0, K1] = [256, 128, 4, 8] for fp16
+    // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 256;
...
...
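Each #if/#elif block above is a complete compile-time tuning point, and only internal consistency among its values makes it legal: the thread-cluster lengths must multiply out to BlockSize, and slice-per-thread times cluster must reproduce the per-block K0 x M x K1 (resp. K0 x N x K1) tile. A standalone arithmetic check for the fp32 [128, 256, 4, 4] block above; the invariants are inferred from the values in this commit, not from documentation.

// Standalone consistency check (plain C++, compiles on its own).
constexpr int BlockSize = 256;
constexpr int MPerBlock = 128, NPerBlock = 256, KPerBlock = 4, K1 = 4;

constexpr int ASlice[3]   = {1, 2, 4};  // ABlockTransferThreadSliceLengths_K0_M_K1
constexpr int ACluster[3] = {4, 64, 1}; // ABlockTransferThreadClusterLengths_K0_M_K1
constexpr int BSlice[3]   = {1, 4, 4};  // BBlockTransferThreadSliceLengths_K0_N_K1
constexpr int BCluster[3] = {4, 64, 1}; // BBlockTransferThreadClusterLengths_K0_N_K1

static_assert(ACluster[0] * ACluster[1] * ACluster[2] == BlockSize, "A copy uses all threads");
static_assert(BCluster[0] * BCluster[1] * BCluster[2] == BlockSize, "B copy uses all threads");
static_assert(ASlice[0] * ACluster[0] == KPerBlock, "A: K0 tile");
static_assert(ASlice[1] * ACluster[1] == MPerBlock, "A: M tile");
static_assert(ASlice[2] * ACluster[2] == K1,        "A: K1 tile");
static_assert(BSlice[0] * BCluster[0] == KPerBlock, "B: K0 tile");
static_assert(BSlice[1] * BCluster[1] == NPerBlock, "B: N tile");
static_assert(BSlice[2] * BCluster[2] == K1,        "B: K1 tile");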
@@ -88,46 +189,185 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc,
    constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 256, 4, 8] for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 256;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 4;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
    constexpr index_t BlockSize = 128;
    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 8;
    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 64;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 1;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_N = 1;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 64;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 8;
    constexpr index_t MRepeat = 1;
    constexpr index_t NRepeat = 2;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 1;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#endif
-    const auto K = a_k_m_grid_desc.GetLength(I0);
-    const auto M = a_k_m_grid_desc.GetLength(I1);
-    const auto N = b_k_n_grid_desc.GetLength(I1);
+    const auto K = a_k_m.mDesc.GetLengths()[0];
+    const auto M = a_k_m.mDesc.GetLengths()[1];
+    const auto N = b_k_n.mDesc.GetLengths()[1];

    constexpr auto K1Number = Number<K1>{};

    const auto K0 = K / K1Number;

-    const auto a_k0_m_k1_grid_desc = transform_tensor_descriptor(
-        a_k_m_grid_desc,
-        make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), make_pass_through_transform(M)),
-        make_tuple(Sequence<0>{}, Sequence<1>{}),
-        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+    const auto a_k0_m_k1_grid_desc = make_naive_tensor_descriptor(
+        make_tuple(K0, M, K1Number),
+        make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0],
+                   a_k_m.mDesc.GetStrides()[1],
+                   a_k_m.mDesc.GetStrides()[0]));

-    const auto b_k0_n_k1_grid_desc = transform_tensor_descriptor(
-        b_k_n_grid_desc,
-        make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), make_pass_through_transform(N)),
-        make_tuple(Sequence<0>{}, Sequence<1>{}),
-        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+    const auto b_k0_n_k1_grid_desc = make_naive_tensor_descriptor(
+        make_tuple(K0, N, K1Number),
+        make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0],
+                   b_k_n.mDesc.GetStrides()[1],
+                   b_k_n.mDesc.GetStrides()[0]));

+    const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
+        make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
    // HACK: hacks that control index calculation when iterating over A, B, C matrix
-    constexpr auto a_k0_m_k1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0>{},   // 0+: K0
-                              Sequence<0, 0, 0>{},   // 1+: M
-                              Sequence<0, 0, 0>{}),  // 2+: K1
-                   make_tuple(Sequence<0, 0, 0>{},   // 0-: K0
-                              Sequence<0, 0, 0>{},   // 1-: M
-                              Sequence<0, 0, 0>{})); // 2-: K1
-    constexpr auto b_k0_n_k1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0>{},   // 0+: K0
-                              Sequence<0, 0, 0>{},   // 1+: N
-                              Sequence<0, 0, 0>{}),  // 2+: K1
-                   make_tuple(Sequence<0, 0, 0>{},   // 0-: K0
-                              Sequence<0, 0, 0>{},   // 1-: N
-                              Sequence<0, 0, 0>{})); // 2-: K1
+    constexpr auto a_k0_m_k1_grid_step_hacks =
+        make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
+                              Sequence<0>{},   // 1+: M
+                              Sequence<0>{}),  // 2+: K1
+                   make_tuple(Sequence<0>{},   // 0-: K0
+                              Sequence<0>{},   // 1-: M
+                              Sequence<0>{})); // 2-: K1
+    constexpr auto b_k0_n_k1_grid_step_hacks =
+        make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
+                              Sequence<0>{},   // 1+: N
+                              Sequence<0>{}),  // 2+: K1
+                   make_tuple(Sequence<0>{},   // 0-: K0
+                              Sequence<0>{},   // 1-: N
+                              Sequence<0>{})); // 2-: K1

    constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
...
...
@@ -147,9 +387,9 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc,
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

-    constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
+    constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};

-    constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
+    constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};

    for(index_t i = 0; i < 5; ++i)
    {
...
...
@@ -194,13 +434,17 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc,
            decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
            decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
            decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
-            false  // CAccessOrderMRepeatNRepeat
+            false, // CAccessOrderMRepeatNRepeat
+            true,  // ABlockLdsExtraM
+            true   // BBlockLdsExtraN
            >(static_cast<ABType*>(a_k_m_device_buf.GetDeviceBuffer()),
              static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()),
              static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
              a_k0_m_k1_grid_desc,
              b_k0_n_k1_grid_desc,
              c_m_n_grid_desc,
+              debug::debug_driver_gemm_xdlops_v2r3::M01,
+              debug::debug_driver_gemm_xdlops_v2r3::N01,
              a_k0_m_k1_grid_step_hacks,
              b_k0_n_k1_grid_step_hacks,
              c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
...
...
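The descriptor change in this file is a pure rewiring: the old path took a K x M grid descriptor and unmerged K into (K0, K1) via transform_tensor_descriptor, while the new path builds a naive K0 x M x K1 descriptor directly from the host tensor's strides. The two agree because element (k0, m, k1) maps to k = k0*K1 + k1. A plain-C++ sketch checking the offset arithmetic, with example strides assumed for a row-major K x M tensor:

// Offset-equivalence check for the descriptor rewrite above.
// unmerge of K into (K0, K1):   offset = (k0*K1 + k1)*stride_k + m*stride_m
// naive K0 x M x K1 descriptor: offset = k0*(K1*stride_k) + m*stride_m + k1*stride_k
#include <cassert>

int main()
{
    const int K1 = 8, stride_k = 512, stride_m = 1; // assumed row-major K x M strides

    for(int k0 = 0; k0 < 4; ++k0)
        for(int m = 0; m < 3; ++m)
            for(int k1 = 0; k1 < K1; ++k1)
            {
                const int via_unmerge = (k0 * K1 + k1) * stride_k + m * stride_m;
                const int via_naive   = k0 * (K1 * stride_k) + m * stride_m + k1 * stride_k;
                assert(via_unmerge == via_naive);
            }
    return 0;
}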
host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp
0 → 100644
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"

template <typename ABType, typename AccType, typename CType>
void device_gemm_xdlops_km_kn_nm(const Tensor<ABType>& a_k_m,
                                 const Tensor<ABType>& b_k_n,
                                 Tensor<CType>& c_n_m,
                                 ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace());
    DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace());
    DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace());

    a_k_m_device_buf.ToDevice(a_k_m.mData.data());
    b_k_n_device_buf.ToDevice(b_k_n.mData.data());
    c_n_m_device_buf.ToDevice(c_n_m.mData.data());

#if 0
    // [M, N, K0, K1] = [256, 128, 4, 4] for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 256;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 4;
    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
    constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 0
    // [M, N, K0, K1] = [128, 256, 4, 4] for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 256;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 4;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 4;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
    constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 1
    // [M, N, K0, K1] = [256, 128, 4, 8] for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 256;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 8;
    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
    constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 1
    // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
    constexpr index_t BlockSize = 128;
    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 8;
    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
    constexpr index_t CThreadTransferDstScalarPerVector = 4;
#endif

    const auto K = a_k_m.mDesc.GetLengths()[0];
    const auto M = a_k_m.mDesc.GetLengths()[1];
    const auto N = b_k_n.mDesc.GetLengths()[1];

    constexpr auto K1Number = Number<K1>{};

    const auto K0 = K / K1Number;

    const auto a_k0_m_k1_grid_desc = make_naive_tensor_descriptor(
        make_tuple(K0, M, K1Number),
        make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0],
                   a_k_m.mDesc.GetStrides()[1],
                   a_k_m.mDesc.GetStrides()[0]));

    const auto b_k0_n_k1_grid_desc = make_naive_tensor_descriptor(
        make_tuple(K0, N, K1Number),
        make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0],
                   b_k_n.mDesc.GetStrides()[1],
                   b_k_n.mDesc.GetStrides()[0]));

    const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
        make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0]));

    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto a_k0_m_k1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
                              Sequence<0>{},   // 1+: M
                              Sequence<0>{}),  // 2+: K1
                   make_tuple(Sequence<0>{},   // 0-: K0
                              Sequence<0>{},   // 1-: M
                              Sequence<0>{})); // 2-: K1

    constexpr auto b_k0_n_k1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
                              Sequence<0>{},   // 1+: N
                              Sequence<0>{}),  // 2+: K1
                   make_tuple(Sequence<0>{},   // 0-: K0
                              Sequence<0>{},   // 1-: N
                              Sequence<0>{})); // 2-: K1

    constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

    constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};

    constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};

    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time = driver_gemm_xdlops_v2r3<
            BlockSize,
            ABType,
            AccType,
            CType,
            InMemoryDataOperationEnum_t::Set,
            decltype(a_k0_m_k1_grid_desc),
            decltype(b_k0_n_k1_grid_desc),
            decltype(c_m_n_grid_desc),
            MPerBlock,
            NPerBlock,
            KPerBlock,
            MPerXDL,
            NPerXDL,
            K1,
            MRepeat,
            NRepeat,
            ABlockTransferThreadSliceLengths_K0_M_K1,
            ABlockTransferThreadClusterLengths_K0_M_K1,
            Sequence<0, 2, 1>,
            Sequence<0, 2, 1>,
            1,
            ABlockTransferSrcScalarPerVector_M,
            ABlockTransferDstScalarPerVector_K1,
            false, // don't move back src coordinate after threadwise copy
            BBlockTransferThreadSliceLengths_K0_N_K1,
            BBlockTransferThreadClusterLengths_K0_N_K1,
            Sequence<0, 2, 1>,
            Sequence<0, 2, 1>,
            1,
            BBlockTransferSrcScalarPerVector_N,
            BBlockTransferDstScalarPerVector_K1,
            false, // don't move back src coordinate after threadwise copy
            Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
            6,
            CThreadTransferDstScalarPerVector,
            decltype(a_k0_m_k1_grid_step_hacks),
            decltype(b_k0_n_k1_grid_step_hacks),
            decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
            decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
            decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
            false // CAccessOrderMRepeatNRepeat
            >(static_cast<ABType*>(a_k_m_device_buf.GetDeviceBuffer()),
              static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()),
              static_cast<CType*>(c_n_m_device_buf.GetDeviceBuffer()),
              a_k0_m_k1_grid_desc,
              b_k0_n_k1_grid_desc,
              c_m_n_grid_desc,
              a_k0_m_k1_grid_step_hacks,
              b_k0_n_k1_grid_step_hacks,
              c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
              a_k0_m_k1_grid_move_slice_window_step_hacks,
              b_k0_n_k1_grid_move_slice_window_step_hacks,
              nrepeat);

        float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
    }

    // copy result back to host
    c_n_m_device_buf.FromDevice(c_n_m.mData.data());
}
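The performance line above reports TFlop/s: a GEMM performs 2*M*N*K flops, and dividing flops/1e9 by a time in milliseconds is the same as flops/1e12 per second. A minimal standalone check of that unit conversion, with an example timing assumed:

// Minimal standalone reproduction of the perf formula (example numbers).
#include <cstddef>
#include <cstdio>

int main()
{
    const std::size_t M = 256, N = 128, K = 4096;
    const float ave_time_ms = 0.5f; // assumed example timing

    // (flops / 1e9) / t_ms == flops / (1e12 * t_s), i.e. TFlop/s
    const float tflops = static_cast<float>(std::size_t(2) * M * N * K) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time_ms;

    std::printf("%f TFlop/s\n", tflops);
    return 0;
}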
host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp
...
...
@@ -4,16 +4,8 @@
 #include "host_tensor.hpp"
 #include "driver_gemm_xdlops_v2r3.hpp"

-template <typename ABType,
-          typename AccType,
-          typename CType,
-          typename ADesc,
-          typename BDesc,
-          typename CDesc>
-void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc,
-                                 const BDesc& b_n_k_grid_desc,
-                                 const CDesc& c_m_n_grid_desc,
-                                 const Tensor<ABType>& a_k_m,
+template <typename ABType, typename AccType, typename CType>
+void device_gemm_xdlops_km_nk_mn(const Tensor<ABType>& a_k_m,
                                  const Tensor<ABType>& b_n_k,
                                  Tensor<CType>& c_m_n,
                                  ck::index_t nrepeat)
...
...
@@ -22,9 +14,6 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc,
    std::cout << __func__ << std::endl;

-    constexpr auto I0 = Number<0>{};
-    constexpr auto I1 = Number<1>{};
-
    DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace());
    DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace());
    DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace());
...
...
@@ -60,9 +49,121 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc,
    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 256, 4, 4] for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 256;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 4;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 4;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 4;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 64;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 4;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 1;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 64;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 4;
    constexpr index_t MRepeat = 1;
    constexpr index_t NRepeat = 2;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 1;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
-    // [M, N, K0, K1] = [256, 128, 4, 8] for fp16
+    // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 256;
...
...
@@ -88,46 +189,185 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc,
    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 256, 4, 8] for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 256;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 4;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
    constexpr index_t BlockSize = 128;
    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 8;
    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 64;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 1;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 64;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 8;
    constexpr index_t MRepeat = 1;
    constexpr index_t NRepeat = 2;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 1;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#endif
-    const auto K = a_k_m_grid_desc.GetLength(I0);
-    const auto M = a_k_m_grid_desc.GetLength(I1);
-    const auto N = b_n_k_grid_desc.GetLength(I0);
+    const auto K = a_k_m.mDesc.GetLengths()[0];
+    const auto M = a_k_m.mDesc.GetLengths()[1];
+    const auto N = b_n_k.mDesc.GetLengths()[0];

    constexpr auto K1Number = Number<K1>{};

    const auto K0 = K / K1Number;

-    const auto a_k0_m_k1_grid_desc = transform_tensor_descriptor(
-        a_k_m_grid_desc,
-        make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), make_pass_through_transform(M)),
-        make_tuple(Sequence<0>{}, Sequence<1>{}),
-        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+    const auto a_k0_m_k1_grid_desc = make_naive_tensor_descriptor(
+        make_tuple(K0, M, K1Number),
+        make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0],
+                   a_k_m.mDesc.GetStrides()[1],
+                   a_k_m.mDesc.GetStrides()[0]));

-    const auto b_k0_n_k1_grid_desc = transform_tensor_descriptor(
-        b_n_k_grid_desc,
-        make_tuple(make_pass_through_transform(N), make_unmerge_transform(make_tuple(K0, K1Number))),
-        make_tuple(Sequence<0>{}, Sequence<1>{}),
-        make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
+    const auto b_k0_n_k1_grid_desc = make_naive_tensor_descriptor(
+        make_tuple(K0, N, K1Number),
+        make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
+                   b_n_k.mDesc.GetStrides()[0],
+                   b_n_k.mDesc.GetStrides()[1]));

+    const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
+        make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
    // HACK: hacks that control index calculation when iterating over A, B, C matrix
-    constexpr auto a_k0_m_k1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0>{},   // 0+: K0
-                              Sequence<0, 0, 0>{},   // 1+: M
-                              Sequence<0, 0, 0>{}),  // 2+: K1
-                   make_tuple(Sequence<0, 0, 0>{},   // 0-: K0
-                              Sequence<0, 0, 0>{},   // 1-: M
-                              Sequence<0, 0, 0>{})); // 2-: K1
-    constexpr auto b_k0_n_k1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0>{},   // 0+: K0
-                              Sequence<0, 0, 0>{},   // 1+: N
-                              Sequence<0, 0, 0>{}),  // 2+: K1
-                   make_tuple(Sequence<0, 0, 0>{},   // 0-: K0
-                              Sequence<0, 0, 0>{},   // 1-: N
-                              Sequence<0, 0, 0>{})); // 2-: K1
+    constexpr auto a_k0_m_k1_grid_step_hacks =
+        make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
+                              Sequence<0>{},   // 1+: M
+                              Sequence<0>{}),  // 2+: K1
+                   make_tuple(Sequence<0>{},   // 0-: K0
+                              Sequence<0>{},   // 1-: M
+                              Sequence<0>{})); // 2-: K1
+    constexpr auto b_k0_n_k1_grid_step_hacks =
+        make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
+                              Sequence<0>{},   // 1+: N
+                              Sequence<0>{}),  // 2+: K1
+                   make_tuple(Sequence<0>{},   // 0-: K0
+                              Sequence<0>{},   // 1-: N
+                              Sequence<0>{})); // 2-: K1

    constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
...
...
@@ -147,9 +387,9 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc,
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

-    constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
+    constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};

-    constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
+    constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};

    for(index_t i = 0; i < 5; ++i)
    {
...
...
@@ -194,13 +434,17 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc,
            decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
            decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
            decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
-            false  // CAccessOrderMRepeatNRepeat
+            false, // CAccessOrderMRepeatNRepeat
+            true,  // ABlockLdsExtraM
+            true   // BBlockLdsExtraN
            >(static_cast<ABType*>(a_k_m_device_buf.GetDeviceBuffer()),
              static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()),
              static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
              a_k0_m_k1_grid_desc,
              b_k0_n_k1_grid_desc,
              c_m_n_grid_desc,
+              debug::debug_driver_gemm_xdlops_v2r3::M01,
+              debug::debug_driver_gemm_xdlops_v2r3::N01,
              a_k0_m_k1_grid_step_hacks,
              b_k0_n_k1_grid_step_hacks,
              c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
...
...
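In the km_nk variants above, B is stored N x K, so the unit-stride dimension is K rather than N; accordingly the B copy vectorizes along K1 (the parameter is named BBlockTransferSrcScalarPerVector_K1 instead of _N), and the naive descriptor swaps which host stride feeds K0/K1 versus N. This reading is inferred from the code in this commit, not from documentation. A self-contained sketch of the two stride wirings:

// Stride wiring for the B descriptor, K x N layout versus N x K layout.
struct Strides3 { long k0, n, k1; };

Strides3 b_k0_n_k1_from_kn(long stride_row, long stride_col, int K1)
{   // B stored K x N: rows step K, columns step N
    return {K1 * stride_row, stride_col, stride_row};
}

Strides3 b_k0_n_k1_from_nk(long stride_row, long stride_col, int K1)
{   // B stored N x K: rows step N, columns step K -> K1 gets the column stride,
    // which is why K1 is the dimension the copy can read as a vector
    return {K1 * stride_col, stride_row, stride_col};
}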
host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp
0 → 100644
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"

template <typename ABType, typename AccType, typename CType>
void device_gemm_xdlops_km_nk_nm(const Tensor<ABType>& a_k_m,
                                 const Tensor<ABType>& b_n_k,
                                 Tensor<CType>& c_n_m,
                                 ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace());
    DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace());
    DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace());

    a_k_m_device_buf.ToDevice(a_k_m.mData.data());
    b_n_k_device_buf.ToDevice(b_n_k.mData.data());
    c_n_m_device_buf.ToDevice(c_n_m.mData.data());

#if 0
    // [M, N, K0, K1] = [256, 128, 4, 4] for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 256;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 4;
    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
    constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 0
    // [M, N, K0, K1] = [128, 256, 4, 4] for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 256;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 4;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 4;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
    constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 1
    // [M, N, K0, K1] = [256, 128, 4, 8] for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 256;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 8;
    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
    constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 1
    // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
    constexpr index_t BlockSize = 128;
    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 8;
    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
    constexpr index_t CThreadTransferDstScalarPerVector = 4;
#endif

    const auto K = a_k_m.mDesc.GetLengths()[0];
    const auto M = a_k_m.mDesc.GetLengths()[1];
    const auto N = b_n_k.mDesc.GetLengths()[0];

    constexpr auto K1Number = Number<K1>{};

    const auto K0 = K / K1Number;

    const auto a_k0_m_k1_grid_desc = make_naive_tensor_descriptor(
        make_tuple(K0, M, K1Number),
        make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0],
                   a_k_m.mDesc.GetStrides()[1],
                   a_k_m.mDesc.GetStrides()[0]));

    const auto b_k0_n_k1_grid_desc = make_naive_tensor_descriptor(
        make_tuple(K0, N, K1Number),
        make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
                   b_n_k.mDesc.GetStrides()[0],
                   b_n_k.mDesc.GetStrides()[1]));

    const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
        make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0]));

    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto a_k0_m_k1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
                              Sequence<0>{},   // 1+: M
                              Sequence<0>{}),  // 2+: K1
                   make_tuple(Sequence<0>{},   // 0-: K0
                              Sequence<0>{},   // 1-: M
                              Sequence<0>{})); // 2-: K1

    constexpr auto b_k0_n_k1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
                              Sequence<0>{},   // 1+: N
                              Sequence<0>{}),  // 2+: K1
                   make_tuple(Sequence<0>{},   // 0-: K0
                              Sequence<0>{},   // 1-: N
                              Sequence<0>{})); // 2-: K1

    constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

    constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};

    constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};

    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time = driver_gemm_xdlops_v2r3<
            BlockSize,
            ABType,
            AccType,
            CType,
            InMemoryDataOperationEnum_t::Set,
            decltype(a_k0_m_k1_grid_desc),
            decltype(b_k0_n_k1_grid_desc),
            decltype(c_m_n_grid_desc),
            MPerBlock,
            NPerBlock,
            KPerBlock,
            MPerXDL,
            NPerXDL,
            K1,
            MRepeat,
            NRepeat,
            ABlockTransferThreadSliceLengths_K0_M_K1,
            ABlockTransferThreadClusterLengths_K0_M_K1,
            Sequence<0, 2, 1>,
            Sequence<0, 2, 1>,
            1,
            ABlockTransferSrcScalarPerVector_M,
            ABlockTransferDstScalarPerVector_K1,
            false, // don't move back src coordinate after threadwise copy
            BBlockTransferThreadSliceLengths_K0_N_K1,
            BBlockTransferThreadClusterLengths_K0_N_K1,
            Sequence<1, 0, 2>,
            Sequence<1, 0, 2>,
            2,
            BBlockTransferSrcScalarPerVector_K1,
            BBlockTransferDstScalarPerVector_K1,
            false, // don't move back src coordinate after threadwise copy
            Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
            6,
            CThreadTransferDstScalarPerVector,
            decltype(a_k0_m_k1_grid_step_hacks),
            decltype(b_k0_n_k1_grid_step_hacks),
            decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
            decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
            decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
            false // CAccessOrderMRepeatNRepeat
            >(static_cast<ABType*>(a_k_m_device_buf.GetDeviceBuffer()),
              static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()),
              static_cast<CType*>(c_n_m_device_buf.GetDeviceBuffer()),
              a_k0_m_k1_grid_desc,
              b_k0_n_k1_grid_desc,
              c_m_n_grid_desc,
              a_k0_m_k1_grid_step_hacks,
              b_k0_n_k1_grid_step_hacks,
              c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
              a_k0_m_k1_grid_move_slice_window_step_hacks,
              b_k0_n_k1_grid_move_slice_window_step_hacks,
              nrepeat);

        float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
    }

    // copy result back to host
    c_n_m_device_buf.FromDevice(c_n_m.mData.data());
}
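A detail worth noting in the _nm variants: c_m_n_grid_desc is built with the strides of the host c_n_m tensor swapped, so the kernel's M coordinate walks c_n_m's contiguous axis. That is presumably why these variants set CThreadTransferDstScalarPerVector = 4 where the fp16 _mn variants above use 1; this causal reading is an inference from the strides, not something the commit states. A self-contained check of the stride reasoning:

// Inferred rationale for the vectorized C write in the _nm variants.
#include <cassert>

int main()
{
    const long c_n_m_strides[2] = {256, 1}; // row-major N x M host tensor, M = 256

    // c_m_n_grid_desc = make_naive_tensor_descriptor(make_tuple(M, N),
    //     make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0]));
    const long stride_m = c_n_m_strides[1]; // 1: the M coordinate is unit-stride
    const long stride_n = c_n_m_strides[0]; // 256

    // consecutive m at fixed n are adjacent in memory, so a vectorized
    // destination write along the M-side dimension is legal
    assert(((1 * stride_m + 5 * stride_n) - (0 * stride_m + 5 * stride_n)) == 1);
    return 0;
}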
host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp
...
...
@@ -4,16 +4,8 @@
 #include "host_tensor.hpp"
 #include "driver_gemm_xdlops_v2r3.hpp"

-template <typename ABType,
-          typename AccType,
-          typename CType,
-          typename ADesc,
-          typename BDesc,
-          typename CDesc>
-void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc,
-                                 const BDesc& b_k_n_grid_desc,
-                                 const CDesc& c_m_n_grid_desc,
-                                 const Tensor<ABType>& a_m_k,
+template <typename ABType, typename AccType, typename CType>
+void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
                                  const Tensor<ABType>& b_k_n,
                                  Tensor<CType>& c_m_n,
                                  ck::index_t nrepeat)
...
...
@@ -22,9 +14,6 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc,
    std::cout << __func__ << std::endl;

-    constexpr auto I0 = Number<0>{};
-    constexpr auto I1 = Number<1>{};
-
    DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace());
    DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace());
    DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace());
...
...
@@ -33,8 +22,148 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc,
    b_k_n_device_buf.ToDevice(b_k_n.mData.data());
    c_m_n_device_buf.ToDevice(c_m_n.mData.data());

-#if 1
-    // [M, N, K0, K1] = [256, 128, 4, 8] for fp16
+#if 0
+    // [M, N, K0, K1] = [256, 128, 4, 4] for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 256;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 4;
    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 256, 4, 4] for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 256;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 4;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 4;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 4;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 64;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 4;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 1;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_N = 1;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 64;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 4;
    constexpr index_t MRepeat = 1;
    constexpr index_t NRepeat = 2;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 256;
...
...
@@ -88,46 +217,157 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc,
    constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
    constexpr index_t BlockSize = 128;
    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 8;
    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 64;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 1;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_N = 1;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t MPerBlock = 64;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;
    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 8;
    constexpr index_t MRepeat = 1;
    constexpr index_t NRepeat = 2;
    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
    constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
    constexpr index_t CThreadTransferDstScalarPerVector = 1;
#endif
-    const auto K = a_m_k_grid_desc.GetLength(I1);
-    const auto M = a_m_k_grid_desc.GetLength(I0);
-    const auto N = b_k_n_grid_desc.GetLength(I1);
+    const auto K = a_m_k.mDesc.GetLengths()[1];
+    const auto M = a_m_k.mDesc.GetLengths()[0];
+    const auto N = b_k_n.mDesc.GetLengths()[1];

    constexpr auto K1Number = Number<K1>{};

    const auto K0 = K / K1Number;

-    const auto a_k0_m_k1_grid_desc = transform_tensor_descriptor(
-        a_m_k_grid_desc,
-        make_tuple(make_pass_through_transform(M), make_unmerge_transform(make_tuple(K0, K1Number))),
-        make_tuple(Sequence<0>{}, Sequence<1>{}),
-        make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
+    const auto a_k0_m_k1_grid_desc = make_naive_tensor_descriptor(
+        make_tuple(K0, M, K1Number),
+        make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
+                   a_m_k.mDesc.GetStrides()[0],
+                   a_m_k.mDesc.GetStrides()[1]));

-    const auto b_k0_n_k1_grid_desc = transform_tensor_descriptor(
-        b_k_n_grid_desc,
-        make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), make_pass_through_transform(N)),
-        make_tuple(Sequence<0>{}, Sequence<1>{}),
-        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+    const auto b_k0_n_k1_grid_desc = make_naive_tensor_descriptor(
+        make_tuple(K0, N, K1Number),
+        make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0],
+                   b_k_n.mDesc.GetStrides()[1],
+                   b_k_n.mDesc.GetStrides()[0]));

+    const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
+        make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
-constexpr auto a_k0_m_k1_grid_step_hacks =
-    make_tuple(make_tuple(Sequence<0, 0, 0>{},   // 0+: K0
-                          Sequence<0, 0, 0>{},   // 1+: M
-                          Sequence<0, 0, 0>{}),  // 2+: K1
-               make_tuple(Sequence<0, 0, 0>{},   // 0-: K0
-                          Sequence<0, 0, 0>{},   // 1-: M
-                          Sequence<0, 0, 0>{})); // 2-: K1
-constexpr auto b_k0_n_k1_grid_step_hacks =
-    make_tuple(make_tuple(Sequence<0, 0, 0>{},   // 0+: K0
-                          Sequence<0, 0, 0>{},   // 1+: N
-                          Sequence<0, 0, 0>{}),  // 2+: K1
-               make_tuple(Sequence<0, 0, 0>{},   // 0-: K0
-                          Sequence<0, 0, 0>{},   // 1-: N
-                          Sequence<0, 0, 0>{})); // 2-: K1
+constexpr auto a_k0_m_k1_grid_step_hacks =
+    make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
+                          Sequence<0>{},   // 1+: M
+                          Sequence<0>{}),  // 2+: K1
+               make_tuple(Sequence<0>{},   // 0-: K0
+                          Sequence<0>{},   // 1-: M
+                          Sequence<0>{})); // 2-: K1
+constexpr auto b_k0_n_k1_grid_step_hacks =
+    make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
+                          Sequence<0>{},   // 1+: N
+                          Sequence<0>{}),  // 2+: K1
+               make_tuple(Sequence<0>{},   // 0-: K0
+                          Sequence<0>{},   // 1-: N
+                          Sequence<0>{})); // 2-: K1
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
    make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
...
@@ -147,9 +387,9 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc,
               Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
               Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
-constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
+constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
-constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
+constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
for(index_t i = 0; i < 5; ++i)
{
...
@@ -194,13 +434,17 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc,
            decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
            decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
            decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
-           false  // CAccessOrderMRepeatNRepeat
+           false, // CAccessOrderMRepeatNRepeat
+           true,  // ABlockLdsExtraM
+           true   // BBlockLdsExtraN
            >(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
              static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()),
              static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
              a_k0_m_k1_grid_desc,
              b_k0_n_k1_grid_desc,
              c_m_n_grid_desc,
              debug::debug_driver_gemm_xdlops_v2r3::M01,
              debug::debug_driver_gemm_xdlops_v2r3::N01,
              a_k0_m_k1_grid_step_hacks,
              b_k0_n_k1_grid_step_hacks,
              c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
...
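The descriptor change in this file swaps the unmerge-transform view of A and B for make_naive_tensor_descriptor calls that compute the [K0, M, K1] strides directly from the host tensor. A minimal standalone sketch of that stride arithmetic (the helper name below is hypothetical, not part of the library): splitting k = k0 * K1 + k1 leaves element (k0, m, k1) of a row-major M x K matrix at offset m * K + k0 * K1 + k1, so the strides must be {K1 * strideK, strideM, strideK}.

#include <array>
#include <cstddef>
#include <iostream>

// Hypothetical helper: the [K0, M, K1] strides handed to make_naive_tensor_descriptor,
// derived from the original M x K strides {strideM, strideK}.
std::array<std::size_t, 3> k0_m_k1_strides(std::size_t K1, std::size_t strideM, std::size_t strideK)
{
    return {K1 * strideK, strideM, strideK};
}

int main()
{
    constexpr std::size_t M = 4, K = 16, K1 = 8;
    const auto s = k0_m_k1_strides(K1, /*strideM=*/K, /*strideK=*/1); // row-major M x K
    for(std::size_t k0 = 0; k0 < K / K1; ++k0)
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t k1 = 0; k1 < K1; ++k1)
                // The 3-D offset must equal the row-major offset of element (m, k0 * K1 + k1).
                if(k0 * s[0] + m * s[1] + k1 * s[2] != m * K + k0 * K1 + k1)
                    return 1;
    std::cout << "[K0, M, K1] view matches the row-major M x K layout" << std::endl;
    return 0;
}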
host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp
0 → 100644
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template <typename ABType, typename AccType, typename CType>
void device_gemm_xdlops_mk_kn_nm(const Tensor<ABType>& a_m_k,
                                 const Tensor<ABType>& b_k_n,
                                 Tensor<CType>& c_n_m,
                                 ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace());
    DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace());
    DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace());

    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
    b_k_n_device_buf.ToDevice(b_k_n.mData.data());
    c_n_m_device_buf.ToDevice(c_n_m.mData.data());
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 256;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 256;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 128;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#endif
const auto K = a_m_k.mDesc.GetLengths()[1];
const auto M = a_m_k.mDesc.GetLengths()[0];
const auto N = b_k_n.mDesc.GetLengths()[1];

constexpr auto K1Number = Number<K1>{};

const auto K0 = K / K1Number;
const auto a_k0_m_k1_grid_desc = make_naive_tensor_descriptor(
    make_tuple(K0, M, K1Number),
    make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
               a_m_k.mDesc.GetStrides()[0],
               a_m_k.mDesc.GetStrides()[1]));

const auto b_k0_n_k1_grid_desc = make_naive_tensor_descriptor(
    make_tuple(K0, N, K1Number),
    make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0],
               b_k_n.mDesc.GetStrides()[1],
               b_k_n.mDesc.GetStrides()[0]));

const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
    make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto a_k0_m_k1_grid_step_hacks =
    make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
                          Sequence<0>{},   // 1+: M
                          Sequence<0>{}),  // 2+: K1
               make_tuple(Sequence<0>{},   // 0-: K0
                          Sequence<0>{},   // 1-: M
                          Sequence<0>{})); // 2-: K1

constexpr auto b_k0_n_k1_grid_step_hacks =
    make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
                          Sequence<0>{},   // 1+: N
                          Sequence<0>{}),  // 2+: K1
               make_tuple(Sequence<0>{},   // 0-: K0
                          Sequence<0>{},   // 1-: N
                          Sequence<0>{})); // 2-: K1
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
    make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
               make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
for(index_t i = 0; i < 5; ++i)
{
    float ave_time =
        driver_gemm_xdlops_v2r3<BlockSize,
                                ABType,
                                AccType,
                                CType,
                                InMemoryDataOperationEnum_t::Set,
                                decltype(a_k0_m_k1_grid_desc),
                                decltype(b_k0_n_k1_grid_desc),
                                decltype(c_m_n_grid_desc),
                                MPerBlock,
                                NPerBlock,
                                KPerBlock,
                                MPerXDL,
                                NPerXDL,
                                K1,
                                MRepeat,
                                NRepeat,
                                ABlockTransferThreadSliceLengths_K0_M_K1,
                                ABlockTransferThreadClusterLengths_K0_M_K1,
                                Sequence<1, 0, 2>,
                                Sequence<1, 0, 2>,
                                2,
                                ABlockTransferSrcScalarPerVector_K1,
                                ABlockTransferDstScalarPerVector_K1,
                                false, // don't move back src coordinate after threadwise copy
                                BBlockTransferThreadSliceLengths_K0_N_K1,
                                BBlockTransferThreadClusterLengths_K0_N_K1,
                                Sequence<0, 2, 1>,
                                Sequence<0, 2, 1>,
                                1,
                                BBlockTransferSrcScalarPerVector_N,
                                BBlockTransferDstScalarPerVector_K1,
                                false, // don't move back src coordinate after threadwise copy
                                Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
                                6,
                                CThreadTransferDstScalarPerVector,
                                decltype(a_k0_m_k1_grid_step_hacks),
                                decltype(b_k0_n_k1_grid_step_hacks),
                                decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
                                decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
                                decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
                                false // CAccessOrderMRepeatNRepeat
                                >(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
                                  static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()),
                                  static_cast<CType*>(c_n_m_device_buf.GetDeviceBuffer()),
                                  a_k0_m_k1_grid_desc,
                                  b_k0_n_k1_grid_desc,
                                  c_m_n_grid_desc,
                                  a_k0_m_k1_grid_step_hacks,
                                  b_k0_n_k1_grid_step_hacks,
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                                  a_k0_m_k1_grid_move_slice_window_step_hacks,
                                  b_k0_n_k1_grid_move_slice_window_step_hacks,
                                  nrepeat);
    float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
                 (std::size_t(1000) * 1000 * 1000) / ave_time;

    std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
// copy result back to host
c_n_m_device_buf.FromDevice(c_n_m.mData.data());
}
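For reference, the perf line inside the timing loop above is plain arithmetic: an M x N x K GEMM performs 2 * M * N * K flops (one multiply and one add per inner-product term), and with the kernel time measured in milliseconds, flop / (1e9 * time_ms) is TFlop/s. A minimal sketch of the same formula (function name and sample numbers are illustrative only):

#include <cstddef>
#include <iostream>

// flop / (1e9 * ms): the 1e9 converts flop-per-millisecond into TFlop-per-second.
double gemm_tflops(std::size_t M, std::size_t N, std::size_t K, double ave_time_ms)
{
    const double flop = 2.0 * M * N * K;
    return flop / (1e9 * ave_time_ms);
}

int main()
{
    // Example: a 4096 x 4096 x 4096 GEMM finishing in 1.5 ms -> ~91.6 TFlop/s.
    std::cout << gemm_tflops(4096, 4096, 4096, 1.5) << " TFlop/s" << std::endl;
}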
host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp
...
@@ -4,16 +4,8 @@
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
-template <typename ABType,
-          typename AccType,
-          typename CType,
-          typename ADesc,
-          typename BDesc,
-          typename CDesc>
-void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
-                                 const BDesc& b_n_k_grid_desc,
-                                 const CDesc& c_m_n_grid_desc,
-                                 const Tensor<ABType>& a_m_k,
+template <typename ABType, typename AccType, typename CType>
+void device_gemm_xdlops_mk_nk_mn(const Tensor<ABType>& a_m_k,
                                  const Tensor<ABType>& b_n_k,
                                  Tensor<CType>& c_m_n,
                                  ck::index_t nrepeat)
...
@@ -22,9 +14,6 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
    std::cout << __func__ << std::endl;

-   constexpr auto I0 = Number<0>{};
-   constexpr auto I1 = Number<1>{};

    DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace());
    DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace());
    DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace());
...
@@ -34,6 +23,34 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
c_m_n_device_buf.ToDevice(c_m_n.mData.data());
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr index_t BlockSize = 256;
...
@@ -60,9 +77,93 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 64;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 64;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
-// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
+// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
...
@@ -90,7 +191,7 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
-// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
+// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
...
@@ -117,8 +218,36 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
-#elif 1
-// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
+#elif 0
+// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 128;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
...
@@ -144,46 +273,131 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [64, 128, 4, 8], C = 64, for fp16
constexpr index_t BlockSize = 128;
constexpr index_t MPerBlock = 64;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 64;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 64;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#endif
-const auto K = a_m_k_grid_desc.GetLength(I1);
-const auto M = a_m_k_grid_desc.GetLength(I0);
-const auto N = b_n_k_grid_desc.GetLength(I0);
+const auto K = a_m_k.mDesc.GetLengths()[1];
+const auto M = a_m_k.mDesc.GetLengths()[0];
+const auto N = b_n_k.mDesc.GetLengths()[0];
constexpr auto K1Number = Number<K1>{};
const auto K0 = K / K1Number;
#if 1
// non-padded GEMM
-const auto a_k0_m_k1_grid_desc = transform_tensor_descriptor(
-    a_m_k_grid_desc,
-    make_tuple(make_pass_through_transform(M), make_unmerge_transform(make_tuple(K0, K1Number))),
-    make_tuple(Sequence<0>{}, Sequence<1>{}),
-    make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
+const auto a_k0_m_k1_grid_desc = make_naive_tensor_descriptor(
+    make_tuple(K0, M, K1Number),
+    make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
+               a_m_k.mDesc.GetStrides()[0],
+               a_m_k.mDesc.GetStrides()[1]));
-const auto b_k0_n_k1_grid_desc = transform_tensor_descriptor(
-    b_n_k_grid_desc,
-    make_tuple(make_pass_through_transform(N), make_unmerge_transform(make_tuple(K0, K1Number))),
-    make_tuple(Sequence<0>{}, Sequence<1>{}),
-    make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
+const auto b_k0_n_k1_grid_desc = make_naive_tensor_descriptor(
+    make_tuple(K0, N, K1Number),
+    make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
+               b_n_k.mDesc.GetStrides()[0],
+               b_n_k.mDesc.GetStrides()[1]));
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
    make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
-constexpr auto a_k0_m_k1_grid_step_hacks =
-    make_tuple(make_tuple(Sequence<0, 0, 0>{},   // 0+: K0
-                          Sequence<0, 0, 0>{},   // 1+: M
-                          Sequence<0, 0, 0>{}),  // 2+: K1
-               make_tuple(Sequence<0, 0, 0>{},   // 0-: K0
-                          Sequence<0, 0, 0>{},   // 1-: M
-                          Sequence<0, 0, 0>{})); // 2-: K1
-constexpr auto b_k0_n_k1_grid_step_hacks =
-    make_tuple(make_tuple(Sequence<0, 0, 0>{},   // 0+: K0
-                          Sequence<0, 0, 0>{},   // 1+: N
-                          Sequence<0, 0, 0>{}),  // 2+: K1
-               make_tuple(Sequence<0, 0, 0>{},   // 0-: K0
-                          Sequence<0, 0, 0>{},   // 1-: N
-                          Sequence<0, 0, 0>{})); // 2-: K1
+constexpr auto a_k0_m_k1_grid_step_hacks =
+    make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
+                          Sequence<0>{},   // 1+: M
+                          Sequence<0>{}),  // 2+: K1
+               make_tuple(Sequence<0>{},   // 0-: K0
+                          Sequence<0>{},   // 1-: M
+                          Sequence<0>{})); // 2-: K1
+constexpr auto b_k0_n_k1_grid_step_hacks =
+    make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
+                          Sequence<0>{},   // 1+: N
+                          Sequence<0>{}),  // 2+: K1
+               make_tuple(Sequence<0>{},   // 0-: K0
+                          Sequence<0>{},   // 1-: N
+                          Sequence<0>{})); // 2-: K1
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
    make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
...
@@ -203,9 +417,80 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
               Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
               Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
-constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
+constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
-constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
+constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
#else
// padded GEMM
const auto a_k0_m_k1_grid_desc_tmp = make_naive_tensor_descriptor(
    make_tuple(K0, M, K1Number),
    make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
               a_m_k.mDesc.GetStrides()[0],
               a_m_k.mDesc.GetStrides()[1]));

const auto MRightPad = math::integer_divide_ceil(M, MPerBlock) * MPerBlock - M;

const auto a_k0_m_k1_grid_desc = transform_tensor_descriptor(
    a_k0_m_k1_grid_desc_tmp,
    make_tuple(make_pass_through_transform(K0),
               make_right_pad_transform(M, MRightPad),
               make_pass_through_transform(K1Number)),
    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

const auto b_k0_n_k1_grid_desc = make_naive_tensor_descriptor(
    make_tuple(K0, N, K1Number),
    make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
               b_n_k.mDesc.GetStrides()[0],
               b_n_k.mDesc.GetStrides()[1]));

const auto c_m_n_grid_desc_tmp = make_naive_tensor_descriptor(
    make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));

const auto c_m_n_grid_desc = transform_tensor_descriptor(
    c_m_n_grid_desc_tmp,
    make_tuple(make_right_pad_transform(M, MRightPad), make_pass_through_transform(N)),
    make_tuple(Sequence<0>{}, Sequence<1>{}),
    make_tuple(Sequence<0>{}, Sequence<1>{}));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto a_k0_m_k1_grid_step_hacks =
    make_tuple(make_tuple(Sequence<0, 0, 0, 0>{},   // 0+: K0
                          Sequence<0, 0, 0, 0>{},   // 1+: M
                          Sequence<0, 0, 0, 0>{}),  // 2+: K1
               make_tuple(Sequence<0, 0, 0, 0>{},   // 0-: K0
                          Sequence<0, 0, 0, 0>{},   // 1-: M
                          Sequence<0, 0, 0, 0>{})); // 2-: K1

constexpr auto b_k0_n_k1_grid_step_hacks =
    make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
                          Sequence<0>{},   // 1+: N
                          Sequence<0>{}),  // 2+: K1
               make_tuple(Sequence<0>{},   // 0-: K0
                          Sequence<0>{},   // 1-: N
                          Sequence<0>{})); // 2-: K1
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
    make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
               make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0>{};
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
#endif
for(index_t i = 0; i < 5; ++i)
{
...
@@ -250,13 +535,17 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
            decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
            decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
            decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
-           false  // CAccessOrderMRepeatNRepeat
+           false, // CAccessOrderMRepeatNRepeat
+           true,  // ABlockLdsExtraM
+           true   // BBlockLdsExtraN
            >(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
              static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()),
              static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
              a_k0_m_k1_grid_desc,
              b_k0_n_k1_grid_desc,
              c_m_n_grid_desc,
              debug::debug_driver_gemm_xdlops_v2r3::M01,
              debug::debug_driver_gemm_xdlops_v2r3::N01,
              a_k0_m_k1_grid_step_hacks,
              b_k0_n_k1_grid_step_hacks,
              c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
...
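The padded-GEMM branch above (#else) exists so M need not divide evenly into MPerBlock tiles: MRightPad rounds M up to the next MPerBlock boundary and make_right_pad_transform extends the A and C descriptors by that amount. A minimal standalone sketch of the pad arithmetic, with integer_divide_ceil expanded inline:

#include <iostream>

int right_pad(int M, int MPerBlock)
{
    // integer_divide_ceil(M, MPerBlock) * MPerBlock, spelled out.
    const int MPadded = ((M + MPerBlock - 1) / MPerBlock) * MPerBlock;
    return MPadded - M;
}

int main()
{
    std::cout << right_pad(1000, 128) << std::endl; // 24: 1000 is padded up to 1024
    std::cout << right_pad(1024, 128) << std::endl; // 0: already a multiple of MPerBlock
}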
host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp
0 → 100644
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template <typename ABType, typename AccType, typename CType>
void device_gemm_xdlops_mk_nk_nm(const Tensor<ABType>& a_m_k,
                                 const Tensor<ABType>& b_n_k,
                                 Tensor<CType>& c_n_m,
                                 ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace());
    DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace());
    DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace());

    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
    b_n_k_device_buf.ToDevice(b_n_k.mData.data());
    c_n_m_device_buf.ToDevice(c_n_m.mData.data());
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 256;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 0
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 256;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 128;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 64;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#endif
const auto K = a_m_k.mDesc.GetLengths()[1];
const auto M = a_m_k.mDesc.GetLengths()[0];
const auto N = b_n_k.mDesc.GetLengths()[0];

constexpr auto K1Number = Number<K1>{};

const auto K0 = K / K1Number;
const auto a_k0_m_k1_grid_desc = make_naive_tensor_descriptor(
    make_tuple(K0, M, K1Number),
    make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
               a_m_k.mDesc.GetStrides()[0],
               a_m_k.mDesc.GetStrides()[1]));

const auto b_k0_n_k1_grid_desc = make_naive_tensor_descriptor(
    make_tuple(K0, N, K1Number),
    make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
               b_n_k.mDesc.GetStrides()[0],
               b_n_k.mDesc.GetStrides()[1]));

const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
    make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto a_k0_m_k1_grid_step_hacks =
    make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
                          Sequence<0>{},   // 1+: M
                          Sequence<0>{}),  // 2+: K1
               make_tuple(Sequence<0>{},   // 0-: K0
                          Sequence<0>{},   // 1-: M
                          Sequence<0>{})); // 2-: K1

constexpr auto b_k0_n_k1_grid_step_hacks =
    make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
                          Sequence<0>{},   // 1+: N
                          Sequence<0>{}),  // 2+: K1
               make_tuple(Sequence<0>{},   // 0-: K0
                          Sequence<0>{},   // 1-: N
                          Sequence<0>{})); // 2-: K1
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
    make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
               make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                          Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
for(index_t i = 0; i < 5; ++i)
{
    float ave_time =
        driver_gemm_xdlops_v2r3<BlockSize,
                                ABType,
                                AccType,
                                CType,
                                InMemoryDataOperationEnum_t::Set,
                                decltype(a_k0_m_k1_grid_desc),
                                decltype(b_k0_n_k1_grid_desc),
                                decltype(c_m_n_grid_desc),
                                MPerBlock,
                                NPerBlock,
                                KPerBlock,
                                MPerXDL,
                                NPerXDL,
                                K1,
                                MRepeat,
                                NRepeat,
                                ABlockTransferThreadSliceLengths_K0_M_K1,
                                ABlockTransferThreadClusterLengths_K0_M_K1,
                                Sequence<1, 0, 2>,
                                Sequence<1, 0, 2>,
                                2,
                                ABlockTransferSrcScalarPerVector_K1,
                                ABlockTransferDstScalarPerVector_K1,
                                false, // don't move back src coordinate after threadwise copy
                                BBlockTransferThreadSliceLengths_K0_N_K1,
                                BBlockTransferThreadClusterLengths_K0_N_K1,
                                Sequence<1, 0, 2>,
                                Sequence<1, 0, 2>,
                                2,
                                BBlockTransferSrcScalarPerVector_K1,
                                BBlockTransferDstScalarPerVector_K1,
                                false, // don't move back src coordinate after threadwise copy
                                Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
                                6,
                                CThreadTransferDstScalarPerVector,
                                decltype(a_k0_m_k1_grid_step_hacks),
                                decltype(b_k0_n_k1_grid_step_hacks),
                                decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
                                decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
                                decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
                                false // CAccessOrderMRepeatNRepeat
                                >(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
                                  static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()),
                                  static_cast<CType*>(c_n_m_device_buf.GetDeviceBuffer()),
                                  a_k0_m_k1_grid_desc,
                                  b_k0_n_k1_grid_desc,
                                  c_m_n_grid_desc,
                                  a_k0_m_k1_grid_step_hacks,
                                  b_k0_n_k1_grid_step_hacks,
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                                  a_k0_m_k1_grid_move_slice_window_step_hacks,
                                  b_k0_n_k1_grid_move_slice_window_step_hacks,
                                  nrepeat);
    float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
                 (std::size_t(1000) * 1000 * 1000) / ave_time;

    std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
// copy result back to host
c_n_m_device_buf.FromDevice(c_n_m.mData.data());
}
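The *_nm variants such as the file above never transpose any data; they only swap the two strides handed to the C descriptor, so the kernel's (M, N) view writes directly into the N x M host tensor. A minimal standalone sketch of that stride swap (buffer and sizes are illustrative):

#include <iostream>
#include <vector>

int main()
{
    constexpr int M = 2, N = 3;
    std::vector<int> c_n_m(N * M, 0); // row-major N x M host buffer, strides {M, 1}
    const int stride_m = 1, stride_n = M; // the reversed strides of the (M, N) view
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            c_n_m[m * stride_m + n * stride_n] = m * 10 + n; // write as if C were M x N
    // Host element (n, m) now holds C[m][n]: row n = 1, column m = 0 yields C[0][1].
    std::cout << c_n_m[1 * M + 0] << std::endl; // prints 1
}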
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
-#ifndef DRIVER_GEMM_XDLOPS_V2R3
-#define DRIVER_GEMM_XDLOPS_V2R3
+#ifndef DRIVER_GEMM_XDLOPS_V2R3_HPP
+#define DRIVER_GEMM_XDLOPS_V2R3_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
...
@@ -46,13 +46,17 @@ template <ck::index_t BlockSize,
          typename CGridStepHacks,
          typename AGridMoveSliceWindowStepHacks,
          typename BGridMoveSliceWindowStepHacks,
-         bool CAccessOrderMRepeatNRepeat>
+         bool CAccessOrderMRepeatNRepeat,
+         bool ABlockLdsAddExtraM,
+         bool BBlockLdsAddExtraN>
__host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
                                       const FloatAB* p_b_grid,
                                       FloatC* p_c_grid,
                                       const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
                                       const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
                                       const CMNGridDesc& c_m_n_grid_desc,
+                                      ck::index_t M01,
+                                      ck::index_t N01,
                                       AGridStepHacks,
                                       BGridStepHacks,
                                       CGridStepHacks,
...
@@ -108,7 +112,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
                                        CGridStepHacks,
                                        AGridMoveSliceWindowStepHacks,
                                        BGridMoveSliceWindowStepHacks,
-                                       CAccessOrderMRepeatNRepeat>;
+                                       CAccessOrderMRepeatNRepeat,
+                                       ABlockLdsAddExtraM,
+                                       BBlockLdsAddExtraN>;
{
    std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", "
...
@@ -123,7 +129,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
              << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
}
-if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
+if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc, M01, N01))
{
    throw std::runtime_error("wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
...
@@ -134,7 +141,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);

-const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
+const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, M01, N01);

using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
...
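The new ABlockLdsAddExtraM / BBlockLdsAddExtraN parameters threaded through driver_gemm_xdlops_v2r3 above let the device-side wrappers request one extra element of padding along M / N in the LDS tiles (the true, // ABlockLdsExtraM arguments in the earlier files). Padding the leading dimension of a shared-memory tile is a common way to stagger rows across LDS banks so column-wise accesses avoid bank conflicts; a hedged sketch of the effect on tile size (the constants below are illustrative, not the kernel's actual LDS layout):

// Illustrative only: how an "add extra M" flag changes the LDS footprint of a
// K0 x M tile. With the pad, consecutive k0-rows start one element apart in
// bank index, so threads reading down a column no longer hit the same bank.
constexpr int KPerBlock  = 4;
constexpr int MPerBlock  = 128;
constexpr bool AddExtraM = true; // mirrors ABlockLdsExtraM in the calls above

constexpr int ALdsRowPitch = MPerBlock + (AddExtraM ? 1 : 0);
constexpr int ALdsElements = KPerBlock * ALdsRowPitch;

static_assert(ALdsElements == 4 * 129, "padded A tile holds 4 x 129 elements");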