Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
2724c519
Commit
2724c519
authored
Feb 24, 2024
by
Jing Zhang
Browse files
merge develop
parents
1fb4a474
2eb74a9c
Changes
470
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2659 additions
and
196 deletions
+2659
-196
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+2
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
...ce/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
+2
-3
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
...sor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
+9
-4
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp
...ration/gpu/device/impl/device_batchnorm_backward_impl.hpp
+3
-1
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp
...eration/gpu/device/impl/device_batchnorm_forward_impl.hpp
+188
-82
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp
...pu/device/impl/device_batchnorm_forward_impl_obsolete.hpp
+716
-0
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
...ation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
+17
-5
include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp
...operation/gpu/device/impl/device_column_to_image_impl.hpp
+646
-0
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
...ice/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
+852
-0
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
...evice/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
+89
-91
include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp
...or_operation/gpu/device/impl/device_contraction_utils.hpp
+89
-0
include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+8
-2
include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
...device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
+5
-0
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
..._fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
+5
-0
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp
...nv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp
+5
-0
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+5
-0
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
.../gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
+5
-0
include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
...u/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
+6
-1
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
...gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
+2
-3
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp
...pu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp
+5
-0
No files found.
Too many changes to show.
To preserve performance only
470 of 470+
files are displayed.
Plain diff
Email patch
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
2724c519
...
@@ -68,7 +68,7 @@ __global__ void
...
@@ -68,7 +68,7 @@ __global__ void
const
C0MatrixMask
c0_matrix_mask
)
const
C0MatrixMask
c0_matrix_mask
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__))
defined(__gfx94__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
...
@@ -723,9 +723,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -723,9 +723,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
arg
.
Print
();
arg
.
Print
();
#endif
#endif
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
if
(
!
ck
::
is_xdl_supported
())
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx942"
))
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
View file @
2724c519
...
@@ -63,7 +63,7 @@ __global__ void
...
@@ -63,7 +63,7 @@ __global__ void
const
C0MatrixMask
c0_matrix_mask
)
const
C0MatrixMask
c0_matrix_mask
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94
0
__))
defined(__gfx94__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
...
@@ -613,8 +613,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -613,8 +613,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
if
(
!
ck
::
is_xdl_supported
())
ck
::
get_device_name
()
==
"gfx940"
))
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
View file @
2724c519
...
@@ -53,7 +53,7 @@ __global__ void
...
@@ -53,7 +53,7 @@ __global__ void
kernel_batched_gemm_xdlops_v2r3
(
const
typename
DeviceOp
::
Argument
karg
)
kernel_batched_gemm_xdlops_v2r3
(
const
typename
DeviceOp
::
Argument
karg
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94
0
__))
defined(__gfx94__))
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
karg
.
Batch
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
karg
.
Batch
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
...
@@ -185,7 +185,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
...
@@ -185,7 +185,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
GemmSpecialization
::
MNPadding
,
GemmSpecialization
::
MN
K
Padding
,
MPerBlock
,
MPerBlock
,
NPerBlock
,
NPerBlock
,
K0PerBlock
,
K0PerBlock
,
...
@@ -310,7 +310,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
...
@@ -310,7 +310,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
static
bool
IsSupportedArgument
(
const
Problem
&
problem
)
static
bool
IsSupportedArgument
(
const
Problem
&
problem
)
{
{
if
(
problem
.
K
%
K1
!=
0
)
if
(
!
ck
::
is_xdl_supported
()
)
{
{
return
false
;
return
false
;
}
}
...
@@ -411,7 +411,12 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
...
@@ -411,7 +411,12 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
<<
BlockSize
<<
", "
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K0PerBlock
<<
K0PerBlock
<<
", "
<<
K1
<<
", "
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
">"
<<
">"
<<
" NumGemmKPrefetchStage: "
<<
" NumGemmKPrefetchStage: "
<<
NumGemmKPrefetchStage
<<
", "
<<
NumGemmKPrefetchStage
<<
", "
...
...
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp
View file @
2724c519
...
@@ -376,7 +376,9 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<XDataType,
...
@@ -376,7 +376,9 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<XDataType,
return
(
workspace_size
);
return
(
workspace_size
);
};
};
void
SetWorkSpacePointer
(
BaseArgument
*
pArg
,
void
*
p_workspace
)
const
override
void
SetWorkSpacePointer
(
BaseArgument
*
pArg
,
void
*
p_workspace
,
const
StreamConfig
&
=
StreamConfig
{})
const
override
{
{
Argument
*
pArg_
=
dynamic_cast
<
Argument
*>
(
pArg
);
Argument
*
pArg_
=
dynamic_cast
<
Argument
*>
(
pArg
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp
View file @
2724c519
...
@@ -10,12 +10,14 @@
...
@@ -10,12 +10,14 @@
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final
_obsolete
.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/hip_check_error.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -114,8 +116,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
...
@@ -114,8 +116,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
static
auto
MakeMeanVarCountOutputMG2dDescriptor
(
int
invariantLength
,
int
blkGroupSize
)
static
auto
MakeMeanVarCountOutputMG2dDescriptor
(
int
invariantLength
,
int
blkGroupSize
)
{
{
const
auto
grid_desc_m_g
=
const
auto
grid_desc_m_g
=
make_naive_tensor_descriptor
(
make_
naive_tensor_descriptor_packed
(
make_tuple
(
invariantLength
,
blkGroupSize
));
make_
tuple
(
invariantLength
,
blkGroupSize
),
make_tuple
(
1
,
invariantLength
));
const
auto
mPad
=
const
auto
mPad
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
...
@@ -132,9 +134,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
...
@@ -132,9 +134,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
static
auto
MakeMeanVarCountInputMK2dDescriptor
(
int
invariantLength
,
int
blkGroupSize
)
static
auto
MakeMeanVarCountInputMK2dDescriptor
(
int
invariantLength
,
int
blkGroupSize
)
{
{
const
auto
reduceLength
=
blkGroupSize
;
const
auto
reduceLength
=
blkGroupSize
;
const
auto
grid_desc_m_k
=
const
auto
grid_desc_m_k
=
make_naive_tensor_descriptor
(
make_
naive_tensor_descriptor_packed
(
make_tuple
(
invariantLength
,
reduceLength
));
make_
tuple
(
invariantLength
,
reduceLength
),
make_tuple
(
1
,
invariantLength
));
const
auto
mPad
=
const
auto
mPad
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
...
@@ -244,8 +246,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
...
@@ -244,8 +246,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
int
testBlkGroupSize
=
(
reduce_length_
+
(
K_BlockTileSize
*
iterations
)
-
1
)
/
int
testBlkGroupSize
=
(
reduce_length_
+
(
K_BlockTileSize
*
iterations
)
-
1
)
/
(
K_BlockTileSize
*
iterations
);
(
K_BlockTileSize
*
iterations
);
// we want the blkGroupSize be not more than 1
28
// we want the blkGroupSize be not more than 1
6
if
(
testBlkGroupSize
<=
1
28
)
if
(
testBlkGroupSize
<=
1
6
)
break
;
break
;
iterations
++
;
iterations
++
;
...
@@ -319,6 +321,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
...
@@ -319,6 +321,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
void
*
workspace_mean_
;
void
*
workspace_mean_
;
void
*
workspace_variance_
;
void
*
workspace_variance_
;
void
*
workspace_count_
;
void
*
workspace_count_
;
void
*
control_
;
};
};
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
pArg
)
const
override
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
pArg
)
const
override
...
@@ -340,12 +344,19 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
...
@@ -340,12 +344,19 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
// workspace for welford intermediate count
// workspace for welford intermediate count
workspace_size
+=
workspace_size
+=
pArg_
->
invariant_length_
*
pArg_
->
blkGroupSize_
*
sizeof
(
int32_t
)
+
64
;
pArg_
->
invariant_length_
*
pArg_
->
blkGroupSize_
*
sizeof
(
int32_t
)
+
64
;
// workspace for barrier objects, each barrier object consists of two integers
// TODO: allocate barrier object memory globally to reuse it by other operators
workspace_size
+=
(
pArg_
->
invariant_length_
+
M_BlockTileSize
-
1
)
/
M_BlockTileSize
*
sizeof
(
int
)
*
2
;
}
}
return
(
workspace_size
);
return
(
workspace_size
);
};
};
void
SetWorkSpacePointer
(
BaseArgument
*
pArg
,
void
*
p_workspace
)
const
override
void
SetWorkSpacePointer
(
BaseArgument
*
pArg
,
void
*
p_workspace
,
const
StreamConfig
&
=
StreamConfig
{})
const
override
{
{
Argument
*
pArg_
=
dynamic_cast
<
Argument
*>
(
pArg
);
Argument
*
pArg_
=
dynamic_cast
<
Argument
*>
(
pArg
);
...
@@ -353,7 +364,6 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
...
@@ -353,7 +364,6 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
if
(
UseMultiblockInK
&&
pArg_
->
blkGroupSize_
>
1
)
if
(
UseMultiblockInK
&&
pArg_
->
blkGroupSize_
>
1
)
{
{
// setup buffer used for intermediate welford mean
// setup buffer used for intermediate welford mean
pArg_
->
workspace_mean_
=
static_cast
<
char
*>
(
pArg_
->
p_workspace_
);
pArg_
->
workspace_mean_
=
static_cast
<
char
*>
(
pArg_
->
p_workspace_
);
...
@@ -374,6 +384,18 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
...
@@ -374,6 +384,18 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
// setup buffer used for intermediate welfor count
// setup buffer used for intermediate welfor count
pArg_
->
workspace_count_
=
pArg_
->
workspace_count_
=
reinterpret_cast
<
char
*>
(
pArg_
->
workspace_variance_
)
+
variance_space_sz
;
reinterpret_cast
<
char
*>
(
pArg_
->
workspace_variance_
)
+
variance_space_sz
;
index_t
count_space_sz
=
pArg_
->
invariant_length_
*
pArg_
->
blkGroupSize_
*
sizeof
(
int32_t
);
count_space_sz
=
math
::
integer_least_multiple
(
count_space_sz
,
64
);
pArg_
->
control_
=
reinterpret_cast
<
char
*>
(
pArg_
->
workspace_count_
)
+
count_space_sz
;
index_t
control_space_sz
=
(
pArg_
->
invariant_length_
+
M_BlockTileSize
-
1
)
/
M_BlockTileSize
*
sizeof
(
int
)
*
2
;
hip_check_error
(
hipMemset
(
pArg_
->
control_
,
0
,
control_space_sz
));
};
};
};
};
...
@@ -402,6 +424,32 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
...
@@ -402,6 +424,32 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
using
MeanVarCountGridDesc_M_G
=
decltype
(
mean_var_count_grid_desc_m_g
);
using
MeanVarCountGridDesc_M_G
=
decltype
(
mean_var_count_grid_desc_m_g
);
using
MeanVarCountGridDesc_M_K
=
decltype
(
mean_var_count_grid_desc_m_k
);
using
MeanVarCountGridDesc_M_K
=
decltype
(
mean_var_count_grid_desc_m_k
);
using
GridwiseMultiblockBatchNormForward_
=
GridwiseMultiblockBatchNormForward
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_G
,
MeanVarCountGridDesc_M_K
,
ScaleBiasMeanVarGridDesc_M
,
ScaleBiasMeanVarGridDesc_M
,
GetReduceCountPerThreadFunctor
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XSrcYDstVectorDim
,
XSrcVectorSize
,
YDstVectorSize
,
ScaleSrcVectorSize
,
BiasSrcVectorSize
,
MeanVarSrcDstVectorSize
>
;
using
GridwiseMultiblockWelfordFirstHalf_
=
using
GridwiseMultiblockWelfordFirstHalf_
=
GridwiseMultiblockWelfordFirstHalf
<
XDataType
,
GridwiseMultiblockWelfordFirstHalf
<
XDataType
,
AccDataType
,
AccDataType
,
...
@@ -441,78 +489,136 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
...
@@ -441,78 +489,136 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
BiasSrcVectorSize
,
BiasSrcVectorSize
,
MeanVarSrcDstVectorSize
>
;
MeanVarSrcDstVectorSize
>
;
index_t
numMeanVarCountBlockTileIteration
=
// It is found that:
(
arg
.
blkGroupSize_
+
KThreadClusterSize
-
1
)
/
KThreadClusterSize
;
// 1) gfx1030 does not support the GLC enabled vector load/store, so using the
// two-kernel method for gfx1030
const
auto
kern_multiblock_welford_first_half
=
// 2) Profiler on gfx908 could hang even though it works when running examples
kernel_multiblock_welford_first_half
<
GridwiseMultiblockWelfordFirstHalf_
,
// 3) Single-kernel method works on gfx1100, but the performance it not better
XDataType
,
// than two-kernel method (due to more warps participating the barrier)
MeanVarDataType
,
if
(
ck
::
get_device_name
()
==
"gfx90a"
)
XYGridDesc_M_K
,
{
MeanVarCountGridDesc_M_G
,
const
auto
kern_multiblock_batchnorm_fwd_
=
GetReduceCountPerThreadFunctor
>
;
kernel_multiblock_batchnorm_forward
<
GridwiseMultiblockBatchNormForward_
,
XDataType
,
const
auto
kern_welford_second_half_batchnorm_forward_final
=
YDataType
,
kernel_welford_second_half_batchnorm_forward_final
<
AccDataType
,
GridwiseWelfordSecondHalfBatchNormForwardFinal_
,
ScaleDataType
,
XDataType
,
BiasDataType
,
YDataType
,
MeanVarDataType
,
AccDataType
,
YElementwiseOp
,
ScaleDataType
,
XYGridDesc_M_K
,
BiasDataType
,
MeanVarCountGridDesc_M_G
,
MeanVarDataType
,
MeanVarCountGridDesc_M_K
,
YElementwiseOp
,
ScaleBiasMeanVarGridDesc_M
,
XYGridDesc_M_K
,
ScaleBiasMeanVarGridDesc_M
,
MeanVarCountGridDesc_M_K
,
GetReduceCountPerThreadFunctor
>
;
ScaleBiasMeanVarGridDesc_M
,
ScaleBiasMeanVarGridDesc_M
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
avg_time
+=
kern_multiblock_batchnorm_fwd_
,
launch_and_time_kernel
(
stream_config
,
dim3
(
arg
.
gridSize_
),
kern_multiblock_welford_first_half
,
dim3
(
BlockSize
),
dim3
(
arg
.
gridSize_
),
0
,
dim3
(
BlockSize
),
arg
.
x_grid_desc_m_k_
,
0
,
arg
.
y_grid_desc_m_k_
,
arg
.
x_grid_desc_m_k_
,
mean_var_count_grid_desc_m_g
,
// for writing to mean/variance/count
mean_var_count_grid_desc_m_g
,
// workspace by multiple workgroups
get_reduce_count_per_thread
,
mean_var_count_grid_desc_m_k
,
// for reading from mean/variance/count
arg
.
numBlockTileIteration_
,
// workspace by each workgroup
arg
.
p_x_
,
arg
.
scale_grid_desc_m_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_mean_
),
arg
.
bias_grid_desc_m_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_variance_
),
arg
.
mean_var_grid_desc_m_
,
static_cast
<
int32_t
*>
(
arg
.
workspace_count_
));
get_reduce_count_per_thread
,
arg
.
numBlockTileIteration_
,
avg_time
+=
arg
.
epsilon_
,
launch_and_time_kernel
(
stream_config
,
arg
.
p_x_
,
kern_welford_second_half_batchnorm_forward_final
,
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_mean_
),
dim3
(
arg
.
gridSize_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_variance_
),
dim3
(
BlockSize
),
static_cast
<
int32_t
*>
(
arg
.
workspace_count_
),
0
,
static_cast
<
int
*>
(
arg
.
control_
),
arg
.
x_grid_desc_m_k_
,
arg
.
p_scale_
,
arg
.
y_grid_desc_m_k_
,
arg
.
p_bias_
,
mean_var_count_grid_desc_m_k
,
arg
.
y_elementwise_op_
,
arg
.
scale_grid_desc_m_
,
arg
.
p_y_
,
arg
.
bias_grid_desc_m_
,
arg
.
updateMovingAverage_
,
// true or false
arg
.
mean_var_grid_desc_m_
,
arg
.
averageFactor_
,
arg
.
blkGroupSize_
,
arg
.
resultRunningMean_
,
arg
.
numBlockTileIteration_
,
arg
.
resultRunningVariance_
,
numMeanVarCountBlockTileIteration
,
arg
.
saveMeanInvVariance_
,
// true or false
arg
.
epsilon_
,
arg
.
resultSaveMean_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_mean_
),
arg
.
resultSaveInvVariance_
);
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_variance_
),
}
static_cast
<
int32_t
*>
(
arg
.
workspace_count_
),
else
arg
.
p_x_
,
{
arg
.
p_scale_
,
const
auto
kern_multiblock_welford_first_half
=
arg
.
p_bias_
,
kernel_multiblock_welford_first_half
<
GridwiseMultiblockWelfordFirstHalf_
,
arg
.
y_elementwise_op_
,
XDataType
,
arg
.
p_y_
,
MeanVarDataType
,
arg
.
updateMovingAverage_
,
XYGridDesc_M_K
,
arg
.
averageFactor_
,
MeanVarCountGridDesc_M_G
,
arg
.
resultRunningMean_
,
GetReduceCountPerThreadFunctor
>
;
arg
.
resultRunningVariance_
,
arg
.
saveMeanInvVariance_
,
const
auto
kern_welford_second_half_batchnorm_forward_final
=
arg
.
resultSaveMean_
,
kernel_welford_second_half_batchnorm_forward_final
<
arg
.
resultSaveInvVariance_
);
GridwiseWelfordSecondHalfBatchNormForwardFinal_
,
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_K
,
ScaleBiasMeanVarGridDesc_M
,
ScaleBiasMeanVarGridDesc_M
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_multiblock_welford_first_half
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k_
,
mean_var_count_grid_desc_m_g
,
get_reduce_count_per_thread
,
arg
.
numBlockTileIteration_
,
arg
.
p_x_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_variance_
),
static_cast
<
int32_t
*>
(
arg
.
workspace_count_
));
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_welford_second_half_batchnorm_forward_final
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k_
,
arg
.
y_grid_desc_m_k_
,
mean_var_count_grid_desc_m_k
,
arg
.
scale_grid_desc_m_
,
arg
.
bias_grid_desc_m_
,
arg
.
mean_var_grid_desc_m_
,
arg
.
blkGroupSize_
,
arg
.
numBlockTileIteration_
,
arg
.
epsilon_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_variance_
),
static_cast
<
int32_t
*>
(
arg
.
workspace_count_
),
arg
.
p_x_
,
arg
.
p_scale_
,
arg
.
p_bias_
,
arg
.
y_elementwise_op_
,
arg
.
p_y_
,
arg
.
updateMovingAverage_
,
arg
.
averageFactor_
,
arg
.
resultRunningMean_
,
arg
.
resultRunningVariance_
,
arg
.
saveMeanInvVariance_
,
arg
.
resultSaveMean_
,
arg
.
resultSaveInvVariance_
);
};
}
}
else
else
{
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp
0 → 100644
View file @
2724c519
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
XDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
YElementwiseOp
,
index_t
Rank
,
index_t
NumBatchNormReduceDim
,
bool
UseMultiblockInK
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
XSrcYDstVectorDim
,
index_t
XSrcVectorSize
,
index_t
YDstVectorSize
,
index_t
ScaleSrcVectorSize
,
index_t
BiasSrcVectorSize
,
index_t
MeanVarSrcDstVectorSize
>
struct
DeviceBatchNormFwdImpl
:
public
DeviceBatchNormFwd
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
Rank
,
NumBatchNormReduceDim
>
{
static_assert
(
Rank
<=
6
,
"Bigger Rank size is not supported!"
);
static_assert
(
BlockSize
==
MThreadClusterSize
*
KThreadClusterSize
,
"Invalid thread cluster size assignments!"
);
static_assert
((
XSrcYDstVectorDim
==
0
&&
MThreadSliceSize
%
XSrcVectorSize
==
0
)
||
(
XSrcYDstVectorDim
==
1
&&
KThreadSliceSize
%
XSrcVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumBatchNormReduceDim
;
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
auto
MakeXY2dDescriptor
(
const
std
::
array
<
index_t
,
Rank
>&
xyLengths
,
const
std
::
array
<
index_t
,
Rank
>&
xyStrides
,
int
blkGroupSize
,
int
numBlockTileIteration
)
{
const
auto
tupleXYLengths
=
generate_tuple
([
&
](
auto
I
)
{
return
xyLengths
[
I
];
},
Number
<
Rank
>
{});
const
auto
tupleXYStrides
=
generate_tuple
([
&
](
auto
I
)
{
return
xyStrides
[
I
];
},
Number
<
Rank
>
{});
const
auto
raw_grid_desc
=
make_naive_tensor_descriptor
(
tupleXYLengths
,
tupleXYStrides
);
const
auto
grid_desc_m_k
=
[
&
]()
{
using
InvariantDims
=
typename
arithmetic_sequence_gen
<
0
,
NumInvariantDim
,
1
>::
type
;
using
ReduceDims
=
typename
arithmetic_sequence_gen
<
NumInvariantDim
,
Rank
,
1
>::
type
;
const
auto
reduceDimLengths
=
generate_tuple
([
&
](
auto
I
)
{
return
xyLengths
[
NumInvariantDim
+
I
];
},
Number
<
NumBatchNormReduceDim
>
{});
const
auto
invariantDimLengths
=
generate_tuple
([
&
](
auto
I
)
{
return
xyLengths
[
I
];
},
Number
<
NumInvariantDim
>
{});
return
transform_tensor_descriptor
(
raw_grid_desc
,
make_tuple
(
make_merge_transform
(
invariantDimLengths
),
make_merge_transform
(
reduceDimLengths
)),
make_tuple
(
InvariantDims
{},
ReduceDims
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}();
const
auto
invariantLength
=
grid_desc_m_k
.
GetLength
(
Number
<
0
>
{});
const
auto
reduceLength
=
grid_desc_m_k
.
GetLength
(
Number
<
1
>
{});
const
int
workSizePerBlock
=
K_BlockTileSize
*
numBlockTileIteration
;
const
auto
mPad
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
const
auto
kPad
=
workSizePerBlock
*
blkGroupSize
-
reduceLength
;
auto
grid_desc_m_k_padded
=
transform_tensor_descriptor
(
grid_desc_m_k
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
mPad
),
make_right_pad_transform
(
reduceLength
,
kPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
(
grid_desc_m_k_padded
);
};
static
auto
MakeMeanVarCountOutputMG2dDescriptor
(
int
invariantLength
,
int
blkGroupSize
)
{
const
auto
grid_desc_m_g
=
make_naive_tensor_descriptor
(
make_tuple
(
invariantLength
,
blkGroupSize
),
make_tuple
(
1
,
invariantLength
));
const
auto
mPad
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
auto
grid_desc_m_g_padded
=
transform_tensor_descriptor
(
grid_desc_m_g
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
mPad
),
make_pass_through_transform
(
blkGroupSize
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
(
grid_desc_m_g_padded
);
};
static
auto
MakeMeanVarCountInputMK2dDescriptor
(
int
invariantLength
,
int
blkGroupSize
)
{
const
auto
reduceLength
=
blkGroupSize
;
const
auto
grid_desc_m_k
=
make_naive_tensor_descriptor
(
make_tuple
(
invariantLength
,
reduceLength
),
make_tuple
(
1
,
invariantLength
));
const
auto
mPad
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
const
auto
kPad
=
math
::
integer_least_multiple
(
reduceLength
,
KThreadClusterSize
)
-
reduceLength
;
auto
grid_desc_m_k_padded
=
transform_tensor_descriptor
(
grid_desc_m_k
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
mPad
),
make_right_pad_transform
(
reduceLength
,
kPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
(
grid_desc_m_k_padded
);
};
static
auto
MakeScaleBiasMeanVar1dDescriptor
(
const
std
::
array
<
index_t
,
NumInvariantDim
>&
lengths
,
const
std
::
array
<
index_t
,
NumInvariantDim
>&
strides
)
{
const
auto
tupleLengths
=
generate_tuple
([
&
](
auto
I
)
{
return
lengths
[
I
];
},
Number
<
NumInvariantDim
>
{});
const
auto
tupleStrides
=
generate_tuple
([
&
](
auto
I
)
{
return
strides
[
I
];
},
Number
<
NumInvariantDim
>
{});
auto
raw_grid_desc
=
make_naive_tensor_descriptor
(
tupleLengths
,
tupleStrides
);
auto
grid_desc_m
=
transform_tensor_descriptor
(
raw_grid_desc
,
make_tuple
(
make_merge_transform
(
tupleLengths
)),
make_tuple
(
typename
arithmetic_sequence_gen
<
0
,
NumInvariantDim
,
1
>::
type
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
invariantLength
=
grid_desc_m
.
GetLength
(
Number
<
0
>
{});
const
auto
mPad
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
auto
grid_desc_m_padded
=
transform_tensor_descriptor
(
grid_desc_m
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
mPad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
(
grid_desc_m_padded
);
};
using
XYGridDesc_M_K
=
decltype
(
MakeXY2dDescriptor
({
1
},
{
1
},
1
,
1
));
using
ScaleBiasMeanVarGridDesc_M
=
decltype
(
MakeScaleBiasMeanVar1dDescriptor
({
1
},
{
1
}));
struct
Argument
:
public
BaseArgument
{
Argument
(
const
std
::
array
<
index_t
,
Rank
>
xyLengths
,
const
std
::
array
<
index_t
,
Rank
>
xStrides
,
const
std
::
array
<
index_t
,
Rank
>
yStrides
,
const
std
::
array
<
int
,
NumBatchNormReduceDim
>
reduceDims
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleStrides
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnBiasStrides
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnMeanVarStrides
,
const
XDataType
*
p_x
,
const
ScaleDataType
*
p_scale
,
const
BiasDataType
*
p_bias
,
const
YElementwiseOp
y_elementwise_op
,
double
epsilon
,
YDataType
*
p_y
,
MeanVarDataType
*
resultSaveMean
,
MeanVarDataType
*
resultSaveInvVariance
,
double
averageFactor
,
MeanVarDataType
*
resultRunningMean
,
MeanVarDataType
*
resultRunningVariance
)
:
bnScaleBiasMeanVarLengths_
(
bnScaleBiasMeanVarLengths
),
bnScaleStrides_
(
bnScaleStrides
),
bnBiasStrides_
(
bnBiasStrides
),
bnMeanVarStrides_
(
bnMeanVarStrides
),
p_x_
(
p_x
),
p_scale_
(
p_scale
),
p_bias_
(
p_bias
),
y_elementwise_op_
(
y_elementwise_op
),
p_y_
(
p_y
),
resultSaveMean_
(
resultSaveMean
),
resultSaveInvVariance_
(
resultSaveInvVariance
),
resultRunningMean_
(
resultRunningMean
),
resultRunningVariance_
(
resultRunningVariance
)
{
xyLengths_
=
shuffle_tensor_dimensions
<
Rank
,
NumBatchNormReduceDim
>
(
xyLengths
,
reduceDims
);
xStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumBatchNormReduceDim
>
(
xStrides
,
reduceDims
);
yStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumBatchNormReduceDim
>
(
yStrides
,
reduceDims
);
std
::
tie
(
invariant_length_
,
reduce_length_
)
=
get_2d_lengths
<
Rank
,
NumBatchNormReduceDim
>
(
xyLengths_
);
epsilon_
=
type_convert
<
AccDataType
>
(
epsilon
);
averageFactor_
=
type_convert
<
AccDataType
>
(
averageFactor
);
updateMovingAverage_
=
(
resultRunningMean
!=
nullptr
&&
resultRunningVariance
!=
nullptr
);
saveMeanInvVariance_
=
(
resultSaveMean
!=
nullptr
&&
resultSaveInvVariance_
!=
nullptr
);
if
(
UseMultiblockInK
)
{
int
iterations
=
1
;
while
(
true
)
{
int
testBlkGroupSize
=
(
reduce_length_
+
(
K_BlockTileSize
*
iterations
)
-
1
)
/
(
K_BlockTileSize
*
iterations
);
// we want the blkGroupSize be not more than 16
if
(
testBlkGroupSize
<=
16
)
break
;
iterations
++
;
};
blkGroupSize_
=
(
reduce_length_
+
(
K_BlockTileSize
*
iterations
)
-
1
)
/
(
K_BlockTileSize
*
iterations
);
numBlockTileIteration_
=
iterations
;
}
else
{
blkGroupSize_
=
1
;
numBlockTileIteration_
=
(
reduce_length_
+
K_BlockTileSize
-
1
)
/
K_BlockTileSize
;
};
gridSize_
=
(
invariant_length_
+
M_BlockTileSize
-
1
)
/
M_BlockTileSize
*
blkGroupSize_
;
x_grid_desc_m_k_
=
MakeXY2dDescriptor
(
xyLengths_
,
xStrides_
,
blkGroupSize_
,
numBlockTileIteration_
);
y_grid_desc_m_k_
=
MakeXY2dDescriptor
(
xyLengths_
,
yStrides_
,
blkGroupSize_
,
numBlockTileIteration_
);
scale_grid_desc_m_
=
MakeScaleBiasMeanVar1dDescriptor
(
bnScaleBiasMeanVarLengths
,
bnScaleStrides_
);
bias_grid_desc_m_
=
MakeScaleBiasMeanVar1dDescriptor
(
bnScaleBiasMeanVarLengths
,
bnBiasStrides_
);
mean_var_grid_desc_m_
=
MakeScaleBiasMeanVar1dDescriptor
(
bnScaleBiasMeanVarLengths
,
bnMeanVarStrides_
);
}
AccDataType
epsilon_
;
AccDataType
averageFactor_
;
bool
updateMovingAverage_
;
bool
saveMeanInvVariance_
;
std
::
array
<
index_t
,
Rank
>
xyLengths_
;
std
::
array
<
index_t
,
Rank
>
xStrides_
;
std
::
array
<
index_t
,
Rank
>
yStrides_
;
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleBiasMeanVarLengths_
;
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleStrides_
;
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnBiasStrides_
;
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnMeanVarStrides_
;
const
XDataType
*
p_x_
;
const
ScaleDataType
*
p_scale_
;
const
BiasDataType
*
p_bias_
;
const
YElementwiseOp
y_elementwise_op_
;
YDataType
*
p_y_
;
MeanVarDataType
*
resultSaveMean_
;
MeanVarDataType
*
resultSaveInvVariance_
;
MeanVarDataType
*
resultRunningMean_
;
MeanVarDataType
*
resultRunningVariance_
;
long_index_t
invariant_length_
;
long_index_t
reduce_length_
;
int
blkGroupSize_
;
int
numBlockTileIteration_
;
size_t
gridSize_
;
XYGridDesc_M_K
x_grid_desc_m_k_
;
XYGridDesc_M_K
y_grid_desc_m_k_
;
ScaleBiasMeanVarGridDesc_M
scale_grid_desc_m_
;
ScaleBiasMeanVarGridDesc_M
bias_grid_desc_m_
;
ScaleBiasMeanVarGridDesc_M
mean_var_grid_desc_m_
;
void
*
workspace_mean_
;
void
*
workspace_variance_
;
void
*
workspace_count_
;
};
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
pArg
)
const
override
{
const
Argument
*
pArg_
=
dynamic_cast
<
const
Argument
*>
(
pArg
);
size_t
workspace_size
=
0
;
if
(
UseMultiblockInK
&&
pArg_
->
blkGroupSize_
>
1
)
{
// workspace for welford intermediate mean
workspace_size
+=
pArg_
->
invariant_length_
*
pArg_
->
blkGroupSize_
*
sizeof
(
MeanVarDataType
)
+
64
;
// workspace for welford intermediate variance
workspace_size
+=
pArg_
->
invariant_length_
*
pArg_
->
blkGroupSize_
*
sizeof
(
MeanVarDataType
)
+
64
;
// workspace for welford intermediate count
workspace_size
+=
pArg_
->
invariant_length_
*
pArg_
->
blkGroupSize_
*
sizeof
(
int32_t
)
+
64
;
}
return
(
workspace_size
);
};
void
SetWorkSpacePointer
(
BaseArgument
*
pArg
,
void
*
p_workspace
,
const
StreamConfig
&
=
StreamConfig
{})
const
override
{
Argument
*
pArg_
=
dynamic_cast
<
Argument
*>
(
pArg
);
pArg_
->
p_workspace_
=
p_workspace
;
if
(
UseMultiblockInK
&&
pArg_
->
blkGroupSize_
>
1
)
{
// setup buffer used for intermediate welford mean
pArg_
->
workspace_mean_
=
static_cast
<
char
*>
(
pArg_
->
p_workspace_
);
index_t
mean_space_sz
=
pArg_
->
invariant_length_
*
pArg_
->
blkGroupSize_
*
sizeof
(
MeanVarDataType
);
mean_space_sz
=
math
::
integer_least_multiple
(
mean_space_sz
,
64
);
// setup buffer used for intermediate welford varirance
pArg_
->
workspace_variance_
=
reinterpret_cast
<
char
*>
(
pArg_
->
workspace_mean_
)
+
mean_space_sz
;
index_t
variance_space_sz
=
pArg_
->
invariant_length_
*
pArg_
->
blkGroupSize_
*
sizeof
(
MeanVarDataType
);
variance_space_sz
=
math
::
integer_least_multiple
(
variance_space_sz
,
64
);
// setup buffer used for intermediate welfor count
pArg_
->
workspace_count_
=
reinterpret_cast
<
char
*>
(
pArg_
->
workspace_variance_
)
+
variance_space_sz
;
};
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
float
avg_time
=
0
;
if
(
UseMultiblockInK
&&
arg
.
blkGroupSize_
>
1
)
{
using
GetReduceCountPerThreadFunctor
=
GetReduceCountPerThreadForMultiblockWelford
<
K_BlockTileSize
,
KThreadSliceSize
>
;
GetReduceCountPerThreadFunctor
get_reduce_count_per_thread
(
arg
.
blkGroupSize_
,
arg
.
numBlockTileIteration_
,
arg
.
reduce_length_
);
const
auto
mean_var_count_grid_desc_m_g
=
DeviceBatchNormFwdImpl
::
MakeMeanVarCountOutputMG2dDescriptor
(
arg
.
invariant_length_
,
arg
.
blkGroupSize_
);
const
auto
mean_var_count_grid_desc_m_k
=
DeviceBatchNormFwdImpl
::
MakeMeanVarCountInputMK2dDescriptor
(
arg
.
invariant_length_
,
arg
.
blkGroupSize_
);
using
MeanVarCountGridDesc_M_G
=
decltype
(
mean_var_count_grid_desc_m_g
);
using
MeanVarCountGridDesc_M_K
=
decltype
(
mean_var_count_grid_desc_m_k
);
using
GridwiseMultiblockWelfordFirstHalf_
=
GridwiseMultiblockWelfordFirstHalf
<
XDataType
,
AccDataType
,
MeanVarDataType
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_G
,
GetReduceCountPerThreadFunctor
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XSrcYDstVectorDim
,
XSrcVectorSize
>
;
using
GridwiseWelfordSecondHalfBatchNormForwardFinal_
=
GridwiseWelfordSecondHalfBatchNormForwardFinal
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_K
,
ScaleBiasMeanVarGridDesc_M
,
ScaleBiasMeanVarGridDesc_M
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XSrcYDstVectorDim
,
XSrcVectorSize
,
YDstVectorSize
,
ScaleSrcVectorSize
,
BiasSrcVectorSize
,
MeanVarSrcDstVectorSize
>
;
const
auto
kern_multiblock_welford_first_half
=
kernel_multiblock_welford_first_half
<
GridwiseMultiblockWelfordFirstHalf_
,
XDataType
,
MeanVarDataType
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_G
,
GetReduceCountPerThreadFunctor
>
;
const
auto
kern_welford_second_half_batchnorm_forward_final
=
kernel_welford_second_half_batchnorm_forward_final
<
GridwiseWelfordSecondHalfBatchNormForwardFinal_
,
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_K
,
ScaleBiasMeanVarGridDesc_M
,
ScaleBiasMeanVarGridDesc_M
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_multiblock_welford_first_half
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k_
,
mean_var_count_grid_desc_m_g
,
get_reduce_count_per_thread
,
arg
.
numBlockTileIteration_
,
arg
.
p_x_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_variance_
),
static_cast
<
int32_t
*>
(
arg
.
workspace_count_
));
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_welford_second_half_batchnorm_forward_final
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k_
,
arg
.
y_grid_desc_m_k_
,
mean_var_count_grid_desc_m_k
,
arg
.
scale_grid_desc_m_
,
arg
.
bias_grid_desc_m_
,
arg
.
mean_var_grid_desc_m_
,
arg
.
blkGroupSize_
,
arg
.
numBlockTileIteration_
,
arg
.
epsilon_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_variance_
),
static_cast
<
int32_t
*>
(
arg
.
workspace_count_
),
arg
.
p_x_
,
arg
.
p_scale_
,
arg
.
p_bias_
,
arg
.
y_elementwise_op_
,
arg
.
p_y_
,
arg
.
updateMovingAverage_
,
arg
.
averageFactor_
,
arg
.
resultRunningMean_
,
arg
.
resultRunningVariance_
,
arg
.
saveMeanInvVariance_
,
arg
.
resultSaveMean_
,
arg
.
resultSaveInvVariance_
);
}
else
{
using
GetReduceCountPerThreadFunctor
=
GetReduceCountPerThreadForBlockwiseWelford
<
K_BlockTileSize
,
KThreadSliceSize
>
;
GetReduceCountPerThreadFunctor
get_reduce_count_per_thread
(
arg
.
numBlockTileIteration_
,
arg
.
reduce_length_
);
using
GridwiseBatchNormForwardWithBlockwiseWelford_
=
GridwiseBatchNormForwardWithBlockwiseWelford
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
XYGridDesc_M_K
,
ScaleBiasMeanVarGridDesc_M
,
ScaleBiasMeanVarGridDesc_M
,
GetReduceCountPerThreadFunctor
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XSrcYDstVectorDim
,
XSrcVectorSize
,
YDstVectorSize
,
ScaleSrcVectorSize
,
BiasSrcVectorSize
,
MeanVarSrcDstVectorSize
>
;
const
auto
kern_batchnorm_fwd
=
kernel_batchnorm_forward_with_blockwise_welford
<
GridwiseBatchNormForwardWithBlockwiseWelford_
,
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
XYGridDesc_M_K
,
ScaleBiasMeanVarGridDesc_M
,
ScaleBiasMeanVarGridDesc_M
,
GetReduceCountPerThreadFunctor
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_batchnorm_fwd
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k_
,
arg
.
y_grid_desc_m_k_
,
arg
.
scale_grid_desc_m_
,
arg
.
bias_grid_desc_m_
,
arg
.
mean_var_grid_desc_m_
,
get_reduce_count_per_thread
,
arg
.
numBlockTileIteration_
,
arg
.
epsilon_
,
arg
.
p_x_
,
arg
.
p_scale_
,
arg
.
p_bias_
,
arg
.
y_elementwise_op_
,
arg
.
p_y_
,
arg
.
updateMovingAverage_
,
// true or false
arg
.
averageFactor_
,
arg
.
resultRunningMean_
,
arg
.
resultRunningVariance_
,
arg
.
saveMeanInvVariance_
,
// true or false
arg
.
resultSaveMean_
,
arg
.
resultSaveInvVariance_
);
};
return
(
avg_time
);
};
float
Run
(
const
BaseArgument
*
pArg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
pArg
),
stream_config
);
};
};
bool
IsSupportedArgument
(
const
BaseArgument
*
pArg
)
override
{
const
Argument
*
pArg_
=
dynamic_cast
<
const
Argument
*>
(
pArg
);
if
constexpr
(
XSrcYDstVectorDim
==
0
)
{
if
(
pArg_
->
xStrides_
[
NumInvariantDim
-
1
]
!=
1
||
pArg_
->
yStrides_
[
NumInvariantDim
-
1
]
!=
1
)
return
false
;
if
(
pArg_
->
xyLengths_
[
NumInvariantDim
-
1
]
%
XSrcVectorSize
!=
0
||
pArg_
->
xyLengths_
[
NumInvariantDim
-
1
]
%
YDstVectorSize
!=
0
)
return
false
;
}
else
{
if
(
pArg_
->
xStrides_
[
Rank
-
1
]
!=
1
||
pArg_
->
yStrides_
[
Rank
-
1
]
!=
1
)
return
false
;
if
(
pArg_
->
xyLengths_
[
Rank
-
1
]
%
XSrcVectorSize
!=
0
||
pArg_
->
xyLengths_
[
Rank
-
1
]
%
YDstVectorSize
!=
0
)
return
false
;
};
if
(
pArg_
->
bnScaleStrides_
[
NumInvariantDim
-
1
]
!=
1
&&
ScaleSrcVectorSize
!=
1
)
return
false
;
if
(
pArg_
->
bnBiasStrides_
[
NumInvariantDim
-
1
]
!=
1
&&
BiasSrcVectorSize
!=
1
)
return
false
;
if
(
pArg_
->
bnScaleBiasMeanVarLengths_
[
NumInvariantDim
-
1
]
%
ScaleSrcVectorSize
!=
0
)
return
false
;
if
(
pArg_
->
bnScaleBiasMeanVarLengths_
[
NumInvariantDim
-
1
]
%
BiasSrcVectorSize
!=
0
)
return
false
;
if
(
pArg_
->
bnMeanVarStrides_
[
NumInvariantDim
-
1
]
!=
1
&&
MeanVarSrcDstVectorSize
!=
1
)
return
false
;
if
(
pArg_
->
bnScaleBiasMeanVarLengths_
[
NumInvariantDim
-
1
]
%
MeanVarSrcDstVectorSize
!=
0
)
return
false
;
bool
is_valid
=
true
;
static_for
<
0
,
NumInvariantDim
,
1
>
{}([
&
](
auto
I
)
{
if
(
pArg_
->
xyLengths_
[
I
]
!=
pArg_
->
bnScaleBiasMeanVarLengths_
[
I
])
is_valid
=
false
;
});
if
(
!
is_valid
)
return
false
;
return
true
;
};
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
array
<
index_t
,
Rank
>
xyLengths
,
const
std
::
array
<
index_t
,
Rank
>
xStrides
,
const
std
::
array
<
index_t
,
Rank
>
yStrides
,
const
std
::
array
<
int
,
NumBatchNormReduceDim
>
reduceDims
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleStrides
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnBiasStrides
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnMeanVarStrides
,
const
void
*
p_x
,
const
void
*
p_scale
,
const
void
*
p_bias
,
double
epsilon
,
const
YElementwiseOp
y_elementwise_op
,
void
*
p_y
,
void
*
resultSaveMean
,
void
*
resultSaveInvVariance
,
double
averageFactor
,
void
*
resultRunningMean
,
void
*
resultRunningVariance
)
override
{
return
std
::
make_unique
<
Argument
>
(
xyLengths
,
xStrides
,
yStrides
,
reduceDims
,
bnScaleBiasMeanVarLengths
,
bnScaleStrides
,
bnBiasStrides
,
bnMeanVarStrides
,
static_cast
<
const
XDataType
*>
(
p_x
),
static_cast
<
const
ScaleDataType
*>
(
p_scale
),
static_cast
<
const
BiasDataType
*>
(
p_bias
),
y_elementwise_op
,
epsilon
,
static_cast
<
YDataType
*>
(
p_y
),
static_cast
<
MeanVarDataType
*>
(
resultSaveMean
),
static_cast
<
MeanVarDataType
*>
(
resultSaveInvVariance
),
averageFactor
,
static_cast
<
MeanVarDataType
*>
(
resultRunningMean
),
static_cast
<
MeanVarDataType
*>
(
resultRunningVariance
));
};
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
();
};
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceBatchNormFwdImpl<"
<<
BlockSize
<<
","
;
str
<<
"M_C"
<<
MThreadClusterSize
<<
"_S"
<<
MThreadSliceSize
<<
","
;
str
<<
"K_C"
<<
KThreadClusterSize
<<
"_S"
<<
KThreadSliceSize
<<
","
;
str
<<
"XSrcYDstVectorDim_"
<<
XSrcYDstVectorDim
<<
","
;
str
<<
"VectorSize_X"
<<
XSrcVectorSize
<<
"_scale_"
<<
ScaleSrcVectorSize
<<
"_bias_"
<<
BiasSrcVectorSize
<<
"_mean_var_"
<<
MeanVarSrcDstVectorSize
<<
"_Y"
<<
YDstVectorSize
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
View file @
2724c519
...
@@ -123,7 +123,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -123,7 +123,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
ALayout
,
ALayout
,
BLayout
,
BLayout
,
CLayout
,
CLayout
,
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
BDataType
,
GemmAccDataType
,
GemmAccDataType
,
CShuffleDataType
,
CShuffleDataType
,
CDataType
,
CDataType
,
...
@@ -284,8 +285,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -284,8 +285,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
{
const
auto
kernel
=
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
kernel_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
ADataType
,
CDataType
,
true
>
;
ADataType
,
BDataType
,
CDataType
,
true
>
;
ave_time
+=
launch_and_time_kernel
(
stream_config
,
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
kernel
,
...
@@ -357,8 +361,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -357,8 +361,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
}
}
else
else
{
{
const
auto
kernel
=
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
kernel_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
ADataType
,
CDataType
,
false
>
;
ADataType
,
BDataType
,
CDataType
,
false
>
;
ave_time
+=
launch_and_time_kernel
(
stream_config
,
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
kernel
,
...
@@ -448,6 +455,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -448,6 +455,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
return
GridwiseGemm
::
CheckValidity
(
arg
);
return
GridwiseGemm
::
CheckValidity
(
arg
);
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp
0 → 100644
View file @
2724c519
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/io.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// Column to Image:
// input : gemm form [G, N * Do * Ho * Wo, Z * Y * X * C]
// output : input image [G, N, Di, Hi, Wi, C]
// input : gemm form [N * Do * Ho * Wo, G, Z * Y * X * C]
// output : input image [N, Di, Hi, Wi, G, C]
template
<
index_t
NDimSpatial
,
typename
ImageLayout
,
typename
InputDataType
,
typename
OutputDataType
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
KPerBlock
,
typename
ThreadClusterLengths
,
index_t
ScalarPerVector
,
typename
std
::
enable_if
<
NDimSpatial
>
=
1
&&
NDimSpatial
<=
3
,
bool
>::
type
=
false
>
struct
DeviceColumnToImageImpl
:
public
DeviceConvTensorRearrange
<
NDimSpatial
,
ImageLayout
,
InputDataType
,
OutputDataType
,
conv_tensor_rearrange_op
::
ColumnToImage
>
{
static
constexpr
bool
is_NSpatialGC
=
std
::
is_same_v
<
ImageLayout
,
tensor_layout
::
convolution
::
NWGC
>
||
std
::
is_same_v
<
ImageLayout
,
tensor_layout
::
convolution
::
NHWGC
>
||
std
::
is_same_v
<
ImageLayout
,
tensor_layout
::
convolution
::
NDHWGC
>
;
static
constexpr
bool
is_GNSpatialC
=
std
::
is_same_v
<
ImageLayout
,
tensor_layout
::
convolution
::
GNWC
>
||
std
::
is_same_v
<
ImageLayout
,
tensor_layout
::
convolution
::
GNHWC
>
||
std
::
is_same_v
<
ImageLayout
,
tensor_layout
::
convolution
::
GNDHWC
>
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
ZIdx
=
Number
<
I0
>
{};
static
constexpr
auto
YIdx
=
NDimSpatial
==
1
?
I0
:
Number
<
NDimSpatial
-
I2
>
{};
static
constexpr
auto
XIdx
=
Number
<
NDimSpatial
-
I1
>
{};
static
constexpr
auto
spatial_offset
=
Number
<
3
>
{};
static
constexpr
auto
conv_to_gemm_transformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvolutionForwardSpecialization
::
Default
>
{};
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpecialization
::
MKPadding
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
0
/* NPerBlock*/
,
KPerBlock
};
// Calculate number of independent filters for given conv params
static
index_t
GetNumberOfIndependentFilters
(
const
index_t
input_spatial_len
,
const
index_t
left_pad
,
const
index_t
right_pad
,
const
index_t
filter_len
,
const
index_t
filter_stride
,
const
index_t
filter_dilation
,
const
index_t
image_offset
)
{
const
index_t
x_eff
=
(
filter_len
-
1
)
*
filter_dilation
+
1
;
const
index_t
next_filter_padded
=
math
::
integer_divide_ceil
(
x_eff
,
filter_stride
)
*
filter_stride
;
// If filter_stride >= x_eff then each filter is independent
const
index_t
independent_filter_stride
=
filter_stride
>=
x_eff
?
filter_stride
:
next_filter_padded
;
const
index_t
w_eff
=
input_spatial_len
-
image_offset
+
left_pad
+
right_pad
-
x_eff
;
// There are no independent filters
if
(
w_eff
<
0
)
return
0
;
const
index_t
independent_kernels_num
=
w_eff
/
independent_filter_stride
+
1
;
return
independent_kernels_num
;
}
// Make column form descriptor
static
auto
MakeInputDescriptor_M_K
(
const
ck
::
index_t
N
,
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
3
>&
gemm_g_m_k_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
independent_filters
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
effs
)
{
const
index_t
DoHoWo
=
ck
::
accumulate_n
<
index_t
>
(
output_spatial_lengths
.
begin
(),
NDimSpatial
,
1
,
std
::
multiplies
<>
());
const
index_t
CZYX
=
C
*
ck
::
accumulate_n
<
index_t
>
(
filter_spatial_lengths
.
begin
(),
NDimSpatial
,
1
,
std
::
multiplies
<>
());
const
index_t
NStride
=
DoHoWo
*
gemm_g_m_k_strides
[
I1
]
*
gemm_g_m_k_strides
[
I2
];
// Calculate the appropriate stride for each set of independent filters
// in each dimension
const
index_t
WStride
=
math
::
integer_divide_ceil
(
effs
[
XIdx
],
conv_filter_strides
[
XIdx
])
*
gemm_g_m_k_strides
[
I1
];
const
index_t
HStride
=
math
::
integer_divide_ceil
(
effs
[
YIdx
],
conv_filter_strides
[
YIdx
])
*
output_spatial_lengths
[
XIdx
]
*
gemm_g_m_k_strides
[
I1
];
const
index_t
DStride
=
math
::
integer_divide_ceil
(
effs
[
ZIdx
],
conv_filter_strides
[
ZIdx
])
*
output_spatial_lengths
[
YIdx
]
*
output_spatial_lengths
[
XIdx
]
*
gemm_g_m_k_strides
[
I1
];
// Create descriptor for independent filters in each dimension and
// then merge them into column form
if
constexpr
(
NDimSpatial
==
1
)
{
const
auto
desc_gemm_form
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
independent_filters
[
XIdx
],
CZYX
),
make_tuple
(
NStride
,
WStride
,
gemm_g_m_k_strides
[
I2
]));
const
auto
desc_gemm_form_merged_filters
=
transform_tensor_descriptor
(
desc_gemm_form
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
independent_filters
[
XIdx
])),
make_pass_through_transform
(
CZYX
)),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
desc_m_k
=
matrix_padder
.
PadADescriptor_M_K
(
desc_gemm_form_merged_filters
);
return
desc_m_k
;
}
else
if
constexpr
(
NDimSpatial
==
2
)
{
const
auto
desc_gemm_form
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
independent_filters
[
YIdx
],
independent_filters
[
XIdx
],
CZYX
),
make_tuple
(
NStride
,
HStride
,
WStride
,
gemm_g_m_k_strides
[
I2
]));
const
auto
desc_gemm_form_merged_filters
=
transform_tensor_descriptor
(
desc_gemm_form
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
independent_filters
[
YIdx
],
independent_filters
[
XIdx
])),
make_pass_through_transform
(
CZYX
)),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
desc_m_k
=
matrix_padder
.
PadADescriptor_M_K
(
desc_gemm_form_merged_filters
);
return
desc_m_k
;
}
else
if
constexpr
(
NDimSpatial
==
3
)
{
const
auto
desc_gemm_form
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
independent_filters
[
ZIdx
],
independent_filters
[
YIdx
],
independent_filters
[
XIdx
],
CZYX
),
make_tuple
(
NStride
,
DStride
,
HStride
,
WStride
,
gemm_g_m_k_strides
[
I2
]));
const
auto
desc_gemm_form_merged_filters
=
transform_tensor_descriptor
(
desc_gemm_form
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
independent_filters
[
ZIdx
],
independent_filters
[
YIdx
],
independent_filters
[
XIdx
])),
make_pass_through_transform
(
CZYX
)),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
desc_m_k
=
matrix_padder
.
PadADescriptor_M_K
(
desc_gemm_form_merged_filters
);
return
desc_m_k
;
}
}
// Use MakeADescriptor_M_K from grouped convolution forward
static
auto
MakeOutDescriptor_M_K
(
const
ck
::
index_t
N
,
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
image_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
image_offsets
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
independent_filters
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
effs
)
{
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths
{
1
};
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths
{
1
};
std
::
array
<
index_t
,
NDimSpatial
+
3
>
c_g_n_k_wos_lengths
{
1
};
auto
copy
=
[](
const
auto
&
x
,
auto
&
y
,
index_t
dst_offset
)
{
std
::
copy
(
x
.
begin
(),
x
.
end
(),
y
.
begin
()
+
dst_offset
);
};
copy
(
input_spatial_lengths
,
a_g_n_c_wis_lengths
,
spatial_offset
);
copy
(
filter_spatial_lengths
,
b_g_k_c_xs_lengths
,
spatial_offset
);
// Calculate descriptor only for independent filters
copy
(
independent_filters
,
c_g_n_k_wos_lengths
,
spatial_offset
);
// fill only significant values (C and N)
a_g_n_c_wis_lengths
[
I1
]
=
N
;
a_g_n_c_wis_lengths
[
I2
]
=
C
;
b_g_k_c_xs_lengths
[
I2
]
=
C
;
c_g_n_k_wos_lengths
[
I1
]
=
N
;
// Modify pads to apply offsets
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_with_offset
;
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
i
++
)
{
input_left_pads_with_offset
[
i
]
=
math
::
max
(
0
,
input_left_pads
[
i
]
-
image_offsets
[
i
]);
}
// Modify input spatial lengths to apply offsets
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
i
++
)
{
a_g_n_c_wis_lengths
[
i
+
spatial_offset
]
-=
math
::
max
(
0
,
image_offsets
[
i
]
-
input_left_pads
[
i
]);
}
// Strides to next independent filters
std
::
array
<
index_t
,
NDimSpatial
>
independent_filter_strides
;
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
i
++
)
{
index_t
independent_filter_stride
=
math
::
integer_divide_ceil
(
effs
[
i
],
conv_filter_strides
[
i
])
*
conv_filter_strides
[
i
];
// If conv stride is greater than whole filter size, use conv stride
independent_filter_strides
[
i
]
=
conv_filter_strides
[
i
]
>=
effs
[
i
]
?
conv_filter_strides
[
i
]
:
independent_filter_stride
;
}
// Calculate image form descriptor for the modified convolution problem
const
auto
in_gemmmraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ImageLayout
>(
a_g_n_c_wis_lengths
,
image_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
{},
// not needed for A Descriptor
c_g_n_k_wos_lengths
,
{},
// not needed for A Descriptor
// conv_filter_strides,
independent_filter_strides
,
conv_filter_dilations
,
input_left_pads_with_offset
,
input_right_pads
);
const
auto
in_gemmm_gemmk_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_desc
);
return
in_gemmm_gemmk_desc
;
}
using
InputGridDesc
=
remove_cvref_t
<
decltype
(
MakeInputDescriptor_M_K
(
1
,
1
,
{},
{},
{},
{},
{},
{}))
>
;
using
OutputGridDesc
=
remove_cvref_t
<
decltype
(
MakeOutDescriptor_M_K
(
1
,
1
,
{},
{},
{},
{},
{},
{},
{},
{},
{},
{}))
>
;
using
Block2ETileMap
=
remove_cvref_t
<
decltype
(
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
KPerBlock
,
InputGridDesc
>
(
InputGridDesc
{}))
>
;
using
GridwiseTensorRearrangeKernel
=
GridwiseTensorRearrange
<
InputGridDesc
,
InputDataType
,
OutputGridDesc
,
OutputDataType
,
BlockSize
,
MPerBlock
,
KPerBlock
,
ThreadClusterLengths
,
ScalarPerVector
,
InMemoryDataOperationEnum
::
Add
,
Block2ETileMap
,
ComputePtrOffsetOfStridedBatch
<>>
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
void
*
p_in
,
// input image
void
*
p_out
,
// output image
const
ck
::
index_t
G
,
const
ck
::
index_t
N
,
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
image_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
3
>&
gemm_g_m_k_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
:
G_
(
G
),
C_
(
C
),
X_
(
filter_spatial_lengths
[
NDimSpatial
-
I1
]),
p_in_
{
static_cast
<
const
InputDataType
*>
(
p_in
)},
p_out_
{
static_cast
<
OutputDataType
*>
(
p_out
)},
image_g_n_c_wis_strides_
{
image_g_n_c_wis_strides
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
{
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
gemm_g_m_k_strides
[
I0
];
compute_ptr_offset_of_batch_
.
BatchStrideC_
=
image_g_n_c_wis_strides
[
I0
];
const
index_t
x_eff
=
(
filter_spatial_lengths
[
XIdx
]
-
1
)
*
conv_filter_dilations
[
XIdx
]
+
1
;
const
index_t
y_eff
=
NDimSpatial
<
2
?
I1
:
(
filter_spatial_lengths
[
YIdx
]
-
1
)
*
conv_filter_dilations
[
YIdx
]
+
1
;
const
index_t
z_eff
=
NDimSpatial
<
3
?
I1
:
(
filter_spatial_lengths
[
ZIdx
]
-
1
)
*
conv_filter_dilations
[
ZIdx
]
+
1
;
// Iterate over sets of independent filters
for
(
int
z_img_offset
=
0
;
z_img_offset
<
z_eff
;
z_img_offset
+=
conv_filter_strides
[
ZIdx
])
{
for
(
int
y_img_offset
=
0
;
y_img_offset
<
y_eff
;
y_img_offset
+=
conv_filter_strides
[
YIdx
])
{
for
(
int
x_img_offset
=
0
;
x_img_offset
<
x_eff
;
x_img_offset
+=
conv_filter_strides
[
XIdx
])
{
std
::
array
<
index_t
,
NDimSpatial
>
image_offsets
;
std
::
array
<
index_t
,
NDimSpatial
>
effs
;
// Calculate the starting offset for a given set of
// independent filters
if
constexpr
(
NDimSpatial
==
1
)
{
image_offsets
=
{
x_img_offset
};
effs
=
{
x_eff
};
}
if
constexpr
(
NDimSpatial
==
2
)
{
image_offsets
=
{
y_img_offset
,
x_img_offset
};
effs
=
{
y_eff
,
x_eff
};
}
else
if
constexpr
(
NDimSpatial
==
3
)
{
image_offsets
=
{
z_img_offset
,
y_img_offset
,
x_img_offset
};
effs
=
{
z_eff
,
y_eff
,
x_eff
};
}
std
::
array
<
index_t
,
NDimSpatial
>
independent_filters
;
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
i
++
)
{
independent_filters
[
i
]
=
GetNumberOfIndependentFilters
(
input_spatial_lengths
[
i
],
input_left_pads
[
i
],
input_right_pads
[
i
],
filter_spatial_lengths
[
i
],
conv_filter_strides
[
i
],
conv_filter_dilations
[
i
],
image_offsets
[
i
]);
}
const
index_t
independent_filters_acum
=
ck
::
accumulate_n
<
index_t
>
(
independent_filters
.
begin
(),
NDimSpatial
,
1
,
std
::
multiplies
<>
());
if
(
independent_filters_acum
<=
0
)
continue
;
const
auto
in_grid_desc_m_k
=
MakeInputDescriptor_M_K
(
N
,
C
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
gemm_g_m_k_strides
,
independent_filters
,
effs
);
const
auto
out_grid_desc_m_k
=
MakeOutDescriptor_M_K
(
N
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
image_g_n_c_wis_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
image_offsets
,
independent_filters
,
effs
);
in_grid_desc_m_k_container_
.
push_back
(
in_grid_desc_m_k
);
out_grid_desc_m_k_container_
.
push_back
(
out_grid_desc_m_k
);
const
index_t
x_idx
=
x_img_offset
/
conv_filter_strides
[
XIdx
];
const
index_t
y_idx
=
y_img_offset
/
conv_filter_strides
[
YIdx
];
const
index_t
z_idx
=
z_img_offset
/
conv_filter_strides
[
ZIdx
];
const
index_t
x_offset_with_pad
=
math
::
max
(
0
,
x_img_offset
-
input_left_pads
[
XIdx
]);
const
index_t
y_offset_with_pad
=
math
::
max
(
0
,
y_img_offset
-
input_left_pads
[
YIdx
]);
const
index_t
z_offset_with_pad
=
math
::
max
(
0
,
z_img_offset
-
input_left_pads
[
ZIdx
]);
// Memory offsets to next set of independent filters,
// move to independent filters in each dimension
const
index_t
in_offset
=
(
x_idx
+
y_idx
*
output_spatial_lengths
[
XIdx
]
+
z_idx
*
output_spatial_lengths
[
YIdx
]
*
output_spatial_lengths
[
XIdx
])
*
gemm_g_m_k_strides
[
I1
];
// Move to independent filters in appropriate dimensions
const
index_t
out_offset
=
x_offset_with_pad
*
image_g_n_c_wis_strides
[
spatial_offset
+
XIdx
]
+
y_offset_with_pad
*
image_g_n_c_wis_strides
[
spatial_offset
+
YIdx
]
+
z_offset_with_pad
*
image_g_n_c_wis_strides
[
spatial_offset
+
ZIdx
];
const
InputDataType
*
p_in_with_offset
=
static_cast
<
const
InputDataType
*>
(
p_in
)
+
in_offset
;
OutputDataType
*
p_out_with_offset
=
static_cast
<
OutputDataType
*>
(
p_out
)
+
out_offset
;
p_in_container_
.
push_back
(
p_in_with_offset
);
p_out_container_
.
push_back
(
p_out_with_offset
);
}
}
}
}
void
Print
()
const
{
for
(
std
::
size_t
i
=
0
;
i
<
in_grid_desc_m_k_container_
.
size
();
i
++
)
{
std
::
cout
<<
in_grid_desc_m_k_container_
[
i
]
<<
std
::
endl
;
std
::
cout
<<
out_grid_desc_m_k_container_
[
i
]
<<
std
::
endl
;
}
}
const
ck
::
index_t
G_
;
const
ck
::
index_t
C_
;
const
ck
::
index_t
X_
;
const
InputDataType
*
p_in_
;
OutputDataType
*
p_out_
;
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
image_g_n_c_wis_strides_
;
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides_
;
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations_
;
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads_
;
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads_
;
std
::
vector
<
InputGridDesc
>
in_grid_desc_m_k_container_
;
std
::
vector
<
OutputGridDesc
>
out_grid_desc_m_k_container_
;
std
::
vector
<
const
InputDataType
*>
p_in_container_
;
std
::
vector
<
OutputDataType
*>
p_out_container_
;
ComputePtrOffsetOfStridedBatch
<>
compute_ptr_offset_of_batch_
;
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
stream_config
.
log_level_
>
0
)
{
arg
.
Print
();
}
float
elapsed_time
=
0.
f
;
const
auto
kernel
=
kernel_tensor_rearrange
<
InputGridDesc
,
InputDataType
,
OutputGridDesc
,
OutputDataType
,
Block2ETileMap
,
ComputePtrOffsetOfStridedBatch
<>
,
GridwiseTensorRearrangeKernel
>
;
// Execute each set of independent filters
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
in_grid_desc_m_k_container_
.
size
();
i
++
)
{
const
auto
block_2_tile_map
=
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
KPerBlock
,
InputGridDesc
>
(
arg
.
out_grid_desc_m_k_container_
[
i
]);
const
index_t
grid_size
=
block_2_tile_map
.
CalculateGridSize
(
arg
.
in_grid_desc_m_k_container_
[
i
])
*
arg
.
G_
;
elapsed_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
in_grid_desc_m_k_container_
[
i
],
arg
.
p_in_container_
[
i
],
arg
.
out_grid_desc_m_k_container_
[
i
],
arg
.
p_out_container_
[
i
],
arg
.
G_
,
block_2_tile_map
,
arg
.
compute_ptr_offset_of_batch_
);
}
return
elapsed_time
;
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
using
namespace
tensor_layout
::
convolution
;
if
constexpr
(
!
(
is_NSpatialGC
||
is_GNSpatialC
))
{
return
false
;
}
const
auto
w_pad_left
=
arg
.
input_left_pads_
[
NDimSpatial
-
I1
];
const
auto
w_pad_right
=
arg
.
input_right_pads_
[
NDimSpatial
-
I1
];
const
auto
dilation_x
=
arg
.
conv_filter_dilations_
[
NDimSpatial
-
I1
];
const
auto
stride_x
=
arg
.
conv_filter_strides_
[
NDimSpatial
-
I1
];
bool
is_w_packed
=
arg
.
image_g_n_c_wis_strides_
[
NDimSpatial
+
I2
]
==
arg
.
C_
;
bool
is_c_packed
=
arg
.
image_g_n_c_wis_strides_
[
I2
]
==
1
;
// check vector acces with c not packed
if
(
!
is_c_packed
&&
ScalarPerVector
!=
1
)
return
false
;
// check vector access of filter window row (only C if C is not packed)
if
(
!
is_w_packed
&&
arg
.
C_
%
ScalarPerVector
!=
0
)
return
false
;
// check vector access of filter window row (X * C)
if
(
arg
.
X_
*
arg
.
C_
%
ScalarPerVector
!=
0
)
return
false
;
// check vector access of pads (w_pad_left/w_pad_right * C)
if
(
w_pad_left
*
arg
.
C_
%
ScalarPerVector
!=
0
||
w_pad_right
*
arg
.
C_
%
ScalarPerVector
!=
0
)
return
false
;
// check vector access of with stride and pad
if
((
w_pad_left
!=
0
||
w_pad_right
!=
0
)
&&
stride_x
>
1
&&
arg
.
C_
%
ScalarPerVector
!=
0
)
return
false
;
// check vector access of with dilation
if
(
dilation_x
>
1
&&
arg
.
C_
%
ScalarPerVector
!=
0
)
return
false
;
bool
valid
=
true
;
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
in_grid_desc_m_k_container_
.
size
();
i
++
)
{
valid
&=
GridwiseTensorRearrangeKernel
::
CheckValidity
(
arg
.
in_grid_desc_m_k_container_
[
i
],
arg
.
out_grid_desc_m_k_container_
[
i
]);
}
return
valid
;
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
void
*
p_in
,
// input image
void
*
p_out
,
// output image
const
ck
::
index_t
G
,
const
ck
::
index_t
N
,
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
image_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
3
>&
gemm_g_m_k_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
{
return
Argument
{
static_cast
<
const
InputDataType
*>
(
p_in
),
static_cast
<
OutputDataType
*>
(
p_out
),
G
,
N
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
image_g_n_c_wis_strides
,
gemm_g_m_k_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in
,
// input image
void
*
p_out
,
// output image
const
ck
::
index_t
G
,
const
ck
::
index_t
N
,
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
image_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
3
>&
gemm_g_m_k_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InputDataType
*>
(
p_in
),
static_cast
<
OutputDataType
*>
(
p_out
),
G
,
N
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
image_g_n_c_wis_strides
,
gemm_g_m_k_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceColumnToImage"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
ScalarPerVector
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
0 → 100644
View file @
2724c519
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
template
<
typename
GridwiseGemm
,
typename
AsPointer
,
typename
BsPointer
,
typename
DsPointer
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
AsGridDesc_AK0_M_AK1
,
typename
BsGridDesc_BK0_N_BK1
,
typename
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
Block2ETileMap
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_contraction_multiple_abd_xdl_cshuffle
(
AsPointer
p_as_grid
,
BsPointer
p_bs_grid
,
DsPointer
p_ds_grid
,
EDataType
*
__restrict__
p_e_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CDEElementwiseOperation
cde_element_op
,
const
AsGridDesc_AK0_M_AK1
as_grid_desc_ak0_m_ak1
,
const
BsGridDesc_BK0_N_BK1
bs_grid_desc_bk0_n_bk1
,
const
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
const
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2ETileMap
block_2_etile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_as_grid
,
p_bs_grid
,
p_ds_grid
,
p_e_grid
,
p_shared
,
a_element_op
,
b_element_op
,
cde_element_op
,
as_grid_desc_ak0_m_ak1
,
bs_grid_desc_bk0_n_bk1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock
,
block_2_etile_map
);
#else
ignore
=
p_as_grid
;
ignore
=
p_bs_grid
;
ignore
=
p_ds_grid
;
ignore
=
p_e_grid
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
cde_element_op
;
ignore
=
as_grid_desc_ak0_m_ak1
;
ignore
=
bs_grid_desc_bk0_n_bk1
;
ignore
=
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
e_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
block_2_etile_map
;
#endif
}
}
// namespace ck
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// GEMM:
// input : A[M, K]
// input : B[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template
<
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimK
,
typename
AsDataType
,
typename
BsDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
AK1
,
index_t
BK1
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
index_t
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
index_t
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
struct
DeviceContractionMultipleABD_Xdl_CShuffle
:
public
DeviceContractionMultipleABD
<
NumDimM
,
NumDimN
,
NumDimK
,
AsDataType
,
BsDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
{
using
DeviceOp
=
DeviceContractionMultipleABD_Xdl_CShuffle
;
static
constexpr
index_t
NumATensor
=
AsDataType
::
Size
();
static
constexpr
index_t
NumBTensor
=
BsDataType
::
Size
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
using
ComputeDataType
=
EDataType
;
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleABD_xdl_cshuffle
<
AsDataType
,
BsDataType
,
ComputeDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
AK1
,
BK1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
false
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
false
,
BBlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
PipelineVer
>
;
static
constexpr
auto
matrix_padder
=
ck
::
tensor_operation
::
device
::
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
static
auto
MakeAGridDescriptor_M_K
(
const
std
::
vector
<
index_t
>&
a_ms_ks_lengths_
,
const
std
::
vector
<
index_t
>&
a_ms_ks_strides_
)
{
assert
(
a_ms_ks_lengths_
.
size
()
==
NumDimM
+
NumDimK
&&
a_ms_ks_strides_
.
size
()
==
NumDimM
+
NumDimK
);
const
auto
to_tuple
=
[
&
](
auto
&
vec
,
auto
num
)
{
return
generate_tuple
([
&
](
auto
i
)
{
return
vec
[
i
];
},
num
);
};
const
auto
a_ms_ks_lengths
=
to_tuple
(
a_ms_ks_lengths_
,
Number
<
NumDimM
+
NumDimK
>
{});
const
auto
a_ms_ks_strides
=
to_tuple
(
a_ms_ks_strides_
,
Number
<
NumDimM
+
NumDimK
>
{});
// dimension Ids for M0, M1, ...
constexpr
auto
mDimIds
=
typename
arithmetic_sequence_gen
<
0
,
NumDimM
,
1
>::
type
{};
// dimension Ids for K0, K1, ...
constexpr
auto
kDimIds
=
typename
arithmetic_sequence_gen
<
NumDimM
,
NumDimM
+
NumDimK
,
1
>::
type
{};
// lengths for M0, M1, ...
const
auto
mLengths
=
get_container_subset
(
a_ms_ks_lengths
,
mDimIds
);
// lengths for K0, K1, ...
const
auto
kLengths
=
get_container_subset
(
a_ms_ks_lengths
,
kDimIds
);
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
const
auto
a_grid_desc_ms_ks
=
make_naive_tensor_descriptor
(
a_ms_ks_lengths
,
a_ms_ks_strides
);
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
const
auto
a_grid_desc_mraw_kraw
=
transform_tensor_descriptor
(
a_grid_desc_ms_ks
,
make_tuple
(
make_merge_transform
(
mLengths
),
make_merge_transform
(
kLengths
)),
make_tuple
(
mDimIds
,
kDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
}
__host__
__device__
static
auto
MakeAsGridDescriptor_M_K
(
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumATensor
>&
as_ms_ks_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumATensor
>&
as_ms_ks_strides
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
return
MakeAGridDescriptor_M_K
(
as_ms_ks_lengths
[
i
],
as_ms_ks_strides
[
i
]);
},
Number
<
NumATensor
>
{});
}
// Assume: B[N0, N1, N2, ..., K0, K1, K2, ...]
static
auto
MakeBGridDescriptor_N_K
(
const
std
::
vector
<
index_t
>&
b_ns_ks_lengths_
,
const
std
::
vector
<
index_t
>&
b_ns_ks_strides_
)
{
assert
(
b_ns_ks_lengths_
.
size
()
==
NumDimN
+
NumDimK
&&
b_ns_ks_strides_
.
size
()
==
NumDimN
+
NumDimK
);
const
auto
to_tuple
=
[
&
](
auto
&
vec
,
auto
num
)
{
return
generate_tuple
([
&
](
auto
i
)
{
return
vec
[
i
];
},
num
);
};
const
auto
b_ns_ks_lengths
=
to_tuple
(
b_ns_ks_lengths_
,
Number
<
NumDimN
+
NumDimK
>
{});
const
auto
b_ns_ks_strides
=
to_tuple
(
b_ns_ks_strides_
,
Number
<
NumDimN
+
NumDimK
>
{});
// dimension Ids for N0, N1, ...
constexpr
auto
nDimIds
=
typename
arithmetic_sequence_gen
<
0
,
NumDimN
,
1
>::
type
{};
// dimension Ids for K0, K1, ...
constexpr
auto
kDimIds
=
typename
arithmetic_sequence_gen
<
NumDimN
,
NumDimN
+
NumDimK
,
1
>::
type
{};
// lengths for K0, K1, ...
const
auto
kLengths
=
get_container_subset
(
b_ns_ks_lengths
,
kDimIds
);
// lengths for N0, N1, ...
const
auto
nLengths
=
get_container_subset
(
b_ns_ks_lengths
,
nDimIds
);
// naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
const
auto
b_grid_desc_ns_ks
=
make_naive_tensor_descriptor
(
b_ns_ks_lengths
,
b_ns_ks_strides
);
// transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
const
auto
b_grid_desc_nraw_kraw
=
transform_tensor_descriptor
(
b_grid_desc_ns_ks
,
make_tuple
(
make_merge_transform
(
nLengths
),
make_merge_transform
(
kLengths
)),
make_tuple
(
nDimIds
,
kDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
}
__host__
__device__
static
auto
MakeBsGridDescriptor_N_K
(
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumBTensor
>&
bs_ns_ks_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumBTensor
>&
bs_ns_ks_strides
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
return
MakeBGridDescriptor_N_K
(
bs_ns_ks_lengths
[
i
],
bs_ns_ks_strides
[
i
]);
},
Number
<
NumBTensor
>
{});
}
// assume E[M0, M1, M2, ..., N0, N1, N2...]
static
auto
MakeEGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
e_ms_ns_lengths_
,
const
std
::
vector
<
index_t
>&
e_ms_ns_strides_
)
{
assert
(
e_ms_ns_lengths_
.
size
()
==
NumDimM
+
NumDimN
&&
e_ms_ns_strides_
.
size
()
==
NumDimM
+
NumDimN
);
const
auto
to_tuple
=
[
&
](
auto
&
vec
,
auto
num
)
{
return
generate_tuple
([
&
](
auto
i
)
{
return
vec
[
i
];
},
num
);
};
const
auto
e_ms_ns_lengths
=
to_tuple
(
e_ms_ns_lengths_
,
Number
<
NumDimM
+
NumDimN
>
{});
const
auto
e_ms_ns_strides
=
to_tuple
(
e_ms_ns_strides_
,
Number
<
NumDimM
+
NumDimN
>
{});
// dimension Ids for M0, M1, ...
constexpr
auto
mDimIds
=
typename
arithmetic_sequence_gen
<
0
,
NumDimM
,
1
>::
type
{};
// dimension Ids for N0, N1, ...
constexpr
auto
nDimIds
=
typename
arithmetic_sequence_gen
<
NumDimM
,
NumDimM
+
NumDimN
,
1
>::
type
{};
// lengths for M0, M1, ...
const
auto
mLengths
=
get_container_subset
(
e_ms_ns_lengths
,
mDimIds
);
// lengths for K0, K1, ...
const
auto
nLengths
=
get_container_subset
(
e_ms_ns_lengths
,
nDimIds
);
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
const
auto
e_grid_desc_ms_ns
=
make_naive_tensor_descriptor
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
// transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
const
auto
e_grid_desc_mraw_nraw
=
transform_tensor_descriptor
(
e_grid_desc_ms_ns
,
make_tuple
(
make_merge_transform
(
mLengths
),
make_merge_transform
(
nLengths
)),
make_tuple
(
mDimIds
,
nDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
matrix_padder
.
PadCDescriptor_M_N
(
e_grid_desc_mraw_nraw
);
}
static
auto
MakeDsGridDescriptor_M_N
(
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_ms_ns_strides
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
return
MakeEGridDescriptor_M_N
(
ds_ms_ns_lengths
[
i
],
ds_ms_ns_strides
[
i
]);
},
Number
<
NumDTensor
>
{});
}
// desc for problem definition
using
AsGridDesc_M_K
=
remove_cvref_t
<
decltype
(
MakeAsGridDescriptor_M_K
({},
{}))
>
;
using
BsGridDesc_N_K
=
remove_cvref_t
<
decltype
(
MakeBsGridDescriptor_N_K
({},
{}))
>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{}))
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
({},
{}))
>
;
// desc for blockwise copy
using
AsGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultAsGridDescriptor_AK0_M_AK1
(
AsGridDesc_M_K
{}))
>
;
using
BsGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBsGridDescriptor_BK0_N_BK1
(
BsGridDesc_N_K
{}))
>
;
using
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
using
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
// block-to-e-tile map
using
Block2ETileMap
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
EGridDesc_M_N
{}))
>
;
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
std
::
array
<
const
void
*
,
NumATensor
>
p_as_grid
,
std
::
array
<
const
void
*
,
NumBTensor
>
p_bs_grid
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid
,
void
*
p_e_grid
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumATensor
>&
a_ms_ks_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumATensor
>&
a_ms_ks_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumBTensor
>&
b_ns_ks_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumBTensor
>&
b_ns_ks_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
d_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
d_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
e_ms_ns_length
,
const
std
::
vector
<
index_t
>&
e_ms_ns_stride
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
:
p_as_grid_
{},
p_bs_grid_
{},
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
as_grid_desc_m_k_
{},
bs_grid_desc_n_k_
{},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{
MakeEGridDescriptor_M_N
(
e_ms_ns_length
,
e_ms_ns_stride
)},
as_grid_desc_ak0_m_ak1_
{},
bs_grid_desc_bk0_n_bk1_
{},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_etile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
}
{
// populate pointer, desc for As
static_for
<
0
,
NumATensor
,
1
>
{}([
&
](
auto
i
)
{
// using ALayout = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
using
ADataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
AsDataType
>>
;
// A pointer
p_as_grid_
(
i
)
=
static_cast
<
const
ADataType
*>
(
p_as_grid
[
i
]);
// A desc
as_grid_desc_m_k_
(
i
)
=
MakeAGridDescriptor_M_K
(
a_ms_ks_lengths
[
i
],
a_ms_ks_strides
[
i
]);
});
// populate pointer, desc for Bs
static_for
<
0
,
NumBTensor
,
1
>
{}([
&
](
auto
i
)
{
// using BLayout = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
using
BDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
BsDataType
>>
;
// B pointer
p_bs_grid_
(
i
)
=
static_cast
<
const
BDataType
*>
(
p_bs_grid
[
i
]);
// B desc
bs_grid_desc_n_k_
(
i
)
=
MakeBGridDescriptor_N_K
(
b_ns_ks_lengths
[
i
],
b_ns_ks_strides
[
i
]);
});
// populate pointer, desc for Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
// using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
// D pointer
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
p_ds_grid
[
i
]);
// D desc
ds_grid_desc_m_n_
(
i
)
=
MakeEGridDescriptor_M_N
(
d_ms_ns_lengths
[
i
],
d_ms_ns_strides
[
i
]);
});
// populate desc for Ds/E
if
(
GridwiseGemm
::
CheckValidity
(
as_grid_desc_m_k_
,
bs_grid_desc_n_k_
,
ds_grid_desc_m_n_
,
e_grid_desc_m_n_
,
block_2_etile_map_
))
{
as_grid_desc_ak0_m_ak1_
=
GridwiseGemm
::
MakeDefaultAsGridDescriptor_AK0_M_AK1
(
as_grid_desc_m_k_
);
bs_grid_desc_bk0_n_bk1_
=
GridwiseGemm
::
MakeDefaultBsGridDescriptor_BK0_N_BK1
(
bs_grid_desc_n_k_
);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n_
);
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
}
// for sanity check of vector memory access
for
(
index_t
i
=
0
;
i
<
NumATensor
;
++
i
)
{
as_mz_consecutive_
[
i
]
=
a_ms_ks_strides
[
i
][
NumDimM
-
1
]
==
1
;
as_kz_consecutive_
[
i
]
=
a_ms_ks_strides
[
i
][
NumDimM
+
NumDimK
-
1
]
==
1
;
as_max_read_elems_
[
i
]
=
CalculateMaxRead
<
NumDimM
,
NumDimK
>
(
a_ms_ks_lengths
[
i
],
a_ms_ks_strides
[
i
]);
}
for
(
index_t
i
=
0
;
i
<
NumBTensor
;
++
i
)
{
bs_nz_consecutive_
[
i
]
=
b_ns_ks_strides
[
i
][
NumDimN
-
1
]
==
1
;
bs_kz_consecutive_
[
i
]
=
b_ns_ks_strides
[
i
][
NumDimN
+
NumDimK
-
1
]
==
1
;
bs_max_read_elems_
[
i
]
=
CalculateMaxRead
<
NumDimN
,
NumDimK
>
(
b_ns_ks_lengths
[
i
],
b_ns_ks_strides
[
i
]);
}
for
(
index_t
i
=
0
;
i
<
NumDTensor
;
++
i
)
{
ds_nz_consecutive_
[
i
]
=
d_ms_ns_strides
[
i
][
NumDimM
+
NumDimN
-
1
]
==
1
;
ds_max_read_elems_
[
i
]
=
CalculateMaxRead
<
NumDimM
,
NumDimN
>
(
d_ms_ns_lengths
[
i
],
d_ms_ns_strides
[
i
]);
}
e_nz_consecutive_
=
e_ms_ns_stride
[
NumDimM
+
NumDimN
-
1
]
==
1
;
e_max_write_elems_
=
CalculateMaxRead
<
NumDimM
,
NumDimN
>
(
e_ms_ns_length
,
e_ms_ns_stride
);
}
// pointers
typename
GridwiseGemm
::
AsGridPointer
p_as_grid_
;
typename
GridwiseGemm
::
BsGridPointer
p_bs_grid_
;
typename
GridwiseGemm
::
DsGridPointer
p_ds_grid_
;
EDataType
*
p_e_grid_
;
// tensor descriptors for problem definiton
AsGridDesc_M_K
as_grid_desc_m_k_
;
BsGridDesc_N_K
bs_grid_desc_n_k_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
// tensor descriptors for block/thread-wise copy
AsGridDesc_AK0_M_AK1
as_grid_desc_ak0_m_ak1_
;
BsGridDesc_BK0_N_BK1
bs_grid_desc_bk0_n_bk1_
;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
// block-to-e-tile map
Block2ETileMap
block_2_etile_map_
;
// element-wise op
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
// Describe whether the last part of a given dimension of A/B/D/E is consecutive
// in the memory or not.
std
::
array
<
bool
,
NumATensor
>
as_mz_consecutive_
;
std
::
array
<
bool
,
NumATensor
>
as_kz_consecutive_
;
std
::
array
<
bool
,
NumBTensor
>
bs_nz_consecutive_
;
std
::
array
<
bool
,
NumBTensor
>
bs_kz_consecutive_
;
std
::
array
<
bool
,
NumDTensor
>
ds_nz_consecutive_
;
bool
e_nz_consecutive_
;
std
::
array
<
index_t
,
NumATensor
>
as_max_read_elems_
;
std
::
array
<
index_t
,
NumBTensor
>
bs_max_read_elems_
;
std
::
array
<
index_t
,
NumDTensor
>
ds_max_read_elems_
;
index_t
e_max_write_elems_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
as_grid_desc_m_k_
,
arg
.
bs_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
const
index_t
grid_size
=
arg
.
block_2_etile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
);
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_contraction_multiple_abd_xdl_cshuffle
<
GridwiseGemm
,
typename
GridwiseGemm
::
AsGridPointer
,
typename
GridwiseGemm
::
BsGridPointer
,
typename
GridwiseGemm
::
DsGridPointer
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
DeviceOp
::
AsGridDesc_AK0_M_AK1
,
DeviceOp
::
BsGridDesc_BK0_N_BK1
,
DeviceOp
::
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
DeviceOp
::
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
DeviceOp
::
Block2ETileMap
,
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_as_grid_
,
arg
.
p_bs_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
as_grid_desc_ak0_m_ak1_
,
arg
.
bs_grid_desc_bk0_n_bk1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_etile_map_
);
};
const
auto
K
=
arg
.
as_grid_desc_m_k_
[
I0
].
GetLength
(
I1
);
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
}
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
// check vector load/store
{
bool
valid_as_access
=
true
;
static_for
<
0
,
NumATensor
,
1
>
{}([
&
](
auto
i
)
{
const
bool
valid_a_vector_size
=
arg
.
as_max_read_elems_
[
i
]
%
ABlockTransferSrcScalarPerVector
==
0
;
const
bool
valid_a_access_dim_m
=
ABlockTransferSrcVectorDim
==
1
&&
arg
.
as_mz_consecutive_
[
i
];
const
bool
valid_a_access_dim_k
=
ABlockTransferSrcVectorDim
==
2
&&
arg
.
as_kz_consecutive_
[
i
];
const
bool
valid_a_access_dim
=
valid_a_access_dim_m
||
valid_a_access_dim_k
;
if
(
!
(
valid_a_vector_size
&&
valid_a_access_dim
))
{
valid_as_access
=
false
;
}
});
if
(
!
valid_as_access
)
{
return
false
;
}
bool
valid_bs_access
=
true
;
static_for
<
0
,
NumBTensor
,
1
>
{}([
&
](
auto
i
)
{
const
bool
valid_b_vector_size
=
arg
.
bs_max_read_elems_
[
i
]
%
BBlockTransferSrcScalarPerVector
==
0
;
const
bool
valid_b_access_dim_n
=
BBlockTransferSrcVectorDim
==
1
&&
arg
.
bs_nz_consecutive_
[
i
];
const
bool
valid_b_access_dim_k
=
BBlockTransferSrcVectorDim
==
2
&&
arg
.
bs_kz_consecutive_
[
i
];
const
bool
valid_b_access_dim
=
valid_b_access_dim_n
||
valid_b_access_dim_k
;
if
(
!
(
valid_b_vector_size
&&
valid_b_access_dim
))
{
valid_bs_access
=
false
;
}
});
if
(
!
valid_bs_access
)
{
return
false
;
}
bool
valid_ds_access
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
const
bool
valid_d_vector_size
=
arg
.
ds_max_read_elems_
[
i
]
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
;
// Vector read of Ds is always on N dimension.
const
bool
valid_d_access_dim
=
arg
.
ds_nz_consecutive_
[
i
];
if
(
!
(
valid_d_vector_size
&&
valid_d_access_dim
))
{
valid_ds_access
=
false
;
}
});
if
(
!
valid_ds_access
)
{
return
false
;
}
const
bool
valid_e_vector_size
=
arg
.
e_max_write_elems_
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
;
// Vector write of E is always on N dimension.
const
bool
valid_e_access_dim
=
arg
.
e_nz_consecutive_
;
if
(
!
(
valid_e_vector_size
&&
valid_e_access_dim
))
{
return
false
;
}
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
as_grid_desc_m_k_
,
arg
.
bs_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
);
}
// polymorphic
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
std
::
array
<
const
void
*
,
NumATensor
>
p_as
,
std
::
array
<
const
void
*
,
NumBTensor
>
p_bs
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumATensor
>&
a_ms_ks_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumATensor
>&
a_ms_ks_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumBTensor
>&
b_ns_ks_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumBTensor
>&
b_ns_ks_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
d_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
d_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
e_ms_ns_length
,
const
std
::
vector
<
index_t
>&
e_ms_ns_stride
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
{
return
Argument
{
p_as
,
p_bs
,
p_ds
,
p_e
,
a_ms_ks_lengths
,
a_ms_ks_strides
,
b_ns_ks_lengths
,
b_ns_ks_strides
,
d_ms_ns_lengths
,
d_ms_ns_strides
,
e_ms_ns_length
,
e_ms_ns_stride
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
array
<
const
void
*
,
NumATensor
>
p_as
,
std
::
array
<
const
void
*
,
NumBTensor
>
p_bs
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumATensor
>&
as_ms_ks_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumATensor
>&
as_ms_ks_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumBTensor
>&
bs_ns_ks_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumBTensor
>&
bs_ns_ks_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
e_ms_ns_length
,
const
std
::
vector
<
index_t
>&
e_ms_ns_stride
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
p_as
,
p_bs
,
p_ds
,
p_e
,
as_ms_ks_lengths
,
as_ms_ks_strides
,
bs_ns_ks_lengths
,
bs_ns_ks_strides
,
ds_ms_ns_lengths
,
ds_ms_ns_strides
,
e_ms_ns_length
,
e_ms_ns_stride
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
// polymorphic
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
std
::
map
<
LoopScheduler
,
std
::
string
>
LoopSchedToString
{
{
LoopScheduler
::
Default
,
"Default"
},
{
LoopScheduler
::
Interwave
,
"Interwave"
}};
std
::
map
<
PipelineVersion
,
std
::
string
>
PipelineVersionToString
{{
PipelineVersion
::
v1
,
"v1"
},
{
PipelineVersion
::
v2
,
"v2"
}};
// clang-format off
str
<<
"DeviceContractionMultipleABD_Xdl_CShuffle"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
", "
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleNXdlPerWavePerShuffle
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
<<
">"
<<
" LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", "
<<
"PipelineVersion: "
<<
PipelineVersionToString
[
PipelineVer
];
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
View file @
2724c519
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
@@ -53,7 +54,7 @@ __global__ void
...
@@ -53,7 +54,7 @@ __global__ void
const
Block2ETileMap
block_2_etile_map
)
const
Block2ETileMap
block_2_etile_map
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94
0
__))
defined(__gfx94__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
...
@@ -145,7 +146,8 @@ template <index_t NumDimM,
...
@@ -145,7 +146,8 @@ template <index_t NumDimM,
index_t
CShuffleNXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
typename
ComputeDataType
=
ADataType
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceContractionMultipleD_Xdl_CShuffle
struct
DeviceContractionMultipleD_Xdl_CShuffle
:
public
DeviceContractionMultipleD
<
NumDimM
,
:
public
DeviceContractionMultipleD
<
NumDimM
,
NumDimN
,
NumDimN
,
...
@@ -156,7 +158,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -156,7 +158,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
EDataType
,
EDataType
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
CDEElementwiseOperation
,
ComputeDataType
>
{
{
using
DeviceOp
=
DeviceContractionMultipleD_Xdl_CShuffle
;
using
DeviceOp
=
DeviceContractionMultipleD_Xdl_CShuffle
;
...
@@ -181,7 +184,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -181,7 +184,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
return
generate_tuple
([
&
](
auto
i
)
{
return
vec
[
i
];
},
num
);
return
generate_tuple
([
&
](
auto
i
)
{
return
vec
[
i
];
},
num
);
};
};
const
auto
a_ms_
n
s_lengths
=
to_tuple
(
a_ms_ks_lengths_vec
,
Number
<
NumDimM
+
NumDimK
>
{});
const
auto
a_ms_
k
s_lengths
=
to_tuple
(
a_ms_ks_lengths_vec
,
Number
<
NumDimM
+
NumDimK
>
{});
const
auto
a_ms_ks_strides
=
to_tuple
(
a_ms_ks_strides_vec
,
Number
<
NumDimM
+
NumDimK
>
{});
const
auto
a_ms_ks_strides
=
to_tuple
(
a_ms_ks_strides_vec
,
Number
<
NumDimM
+
NumDimK
>
{});
// dimension Ids for M0, M1, ...
// dimension Ids for M0, M1, ...
...
@@ -192,14 +195,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -192,14 +195,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
typename
arithmetic_sequence_gen
<
NumDimM
,
NumDimM
+
NumDimK
,
1
>::
type
{};
typename
arithmetic_sequence_gen
<
NumDimM
,
NumDimM
+
NumDimK
,
1
>::
type
{};
// lengths for M0, M1, ...
// lengths for M0, M1, ...
const
auto
mLengths
=
get_container_subset
(
a_ms_
n
s_lengths
,
mDimIds
);
const
auto
mLengths
=
get_container_subset
(
a_ms_
k
s_lengths
,
mDimIds
);
// lengths for K0, K1, ...
// lengths for K0, K1, ...
const
auto
kLengths
=
get_container_subset
(
a_ms_
n
s_lengths
,
kDimIds
);
const
auto
kLengths
=
get_container_subset
(
a_ms_
k
s_lengths
,
kDimIds
);
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
const
auto
a_grid_desc_ms_ks
=
const
auto
a_grid_desc_ms_ks
=
make_naive_tensor_descriptor
(
a_ms_
n
s_lengths
,
a_ms_ks_strides
);
make_naive_tensor_descriptor
(
a_ms_
k
s_lengths
,
a_ms_ks_strides
);
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
const
auto
a_grid_desc_mraw_kraw
=
transform_tensor_descriptor
(
const
auto
a_grid_desc_mraw_kraw
=
transform_tensor_descriptor
(
...
@@ -313,6 +316,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -313,6 +316,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
BDataType
,
ComputeDataType
,
AccDataType
,
AccDataType
,
CShuffleDataType
,
CShuffleDataType
,
DsDataType
,
DsDataType
,
...
@@ -355,14 +360,18 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -355,14 +360,18 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
LoopSched
>
;
LoopSched
>
;
// desc for blockwise copy
// desc for blockwise copy
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
using
AGridDesc_AK0_M_AK1
=
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}))
>
;
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
AGridDesc_M_K
{}))
>
;
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
using
BGridDesc_BK0_N_BK1
=
using
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
BGridDesc_N_K
{}))
>
;
using
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
using
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
decltype
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
using
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
// block-to-e-tile map
// block-to-e-tile map
using
Block2ETileMap
=
using
Block2ETileMap
=
...
@@ -375,7 +384,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -375,7 +384,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const
void
*
p_b_grid
,
const
void
*
p_b_grid
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid
,
void
*
p_e_grid
,
void
*
p_e_grid
,
const
std
::
vector
<
index_t
>&
a_ms_
n
s_lengths
,
const
std
::
vector
<
index_t
>&
a_ms_
k
s_lengths
,
const
std
::
vector
<
index_t
>&
a_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
a_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_ns_ks_strides
,
const
std
::
vector
<
index_t
>&
b_ns_ks_strides
,
...
@@ -390,7 +399,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -390,7 +399,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_ds_grid_
{},
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
a_grid_desc_m_k_
{
DeviceOp
::
MakeAGridDescriptor_M_K
(
a_ms_
n
s_lengths
,
a_ms_ks_strides
)},
a_grid_desc_m_k_
{
DeviceOp
::
MakeAGridDescriptor_M_K
(
a_ms_
k
s_lengths
,
a_ms_ks_strides
)},
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
(
b_ns_ks_lengths
,
b_ns_ks_strides
)},
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
(
b_ns_ks_lengths
,
b_ns_ks_strides
)},
ds_grid_desc_m_n_
{},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
(
e_ms_ns_lengths
,
e_ms_ns_strides
)},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
(
e_ms_ns_lengths
,
e_ms_ns_strides
)},
...
@@ -403,13 +412,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -403,13 +412,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
block_2_etile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n_
)},
block_2_etile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
cde_element_op_
{
cde_element_op
}
a_mz_stride_
{},
a_kz_stride_
{},
b_nz_stride_
{},
b_kz_stride_
{},
ds_nz_stride_
{},
e_nz_stride_
{}
{
{
// populate pointer, batch stride, desc for Ds
// populate pointer, batch stride, desc for Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
...
@@ -440,18 +443,26 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -440,18 +443,26 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
}
}
// for sanity check of vector memory access
// for sanity check of vector memory access
a_mz_stride_
=
a_ms_ks_strides
[
NumDimM
-
1
];
a_mz_consecutive_
=
a_ms_ks_strides
[
NumDimM
-
1
]
==
1
;
a_kz_stride_
=
a_ms_ks_strides
[
NumDimM
+
NumDimK
-
1
];
a_kz_consecutive_
=
a_ms_ks_strides
[
NumDimM
+
NumDimK
-
1
]
==
1
;
a_max_read_elems_
=
CalculateMaxRead
<
NumDimM
,
NumDimK
>
(
a_ms_ks_lengths
,
a_ms_ks_strides
);
b_nz_stride_
=
b_ns_ks_strides
[
NumDimN
-
1
];
b_nz_consecutive_
=
b_ns_ks_strides
[
NumDimN
-
1
]
==
1
;
b_kz_stride_
=
b_ns_ks_strides
[
NumDimN
+
NumDimK
-
1
];
b_kz_consecutive_
=
b_ns_ks_strides
[
NumDimN
+
NumDimK
-
1
]
==
1
;
b_max_read_elems_
=
CalculateMaxRead
<
NumDimN
,
NumDimK
>
(
b_ns_ks_lengths
,
b_ns_ks_strides
);
for
(
index_t
i
=
0
;
i
<
NumDTensor
;
++
i
)
for
(
index_t
i
=
0
;
i
<
NumDTensor
;
++
i
)
{
{
ds_nz_stride_
[
i
]
=
ds_ms_ns_strides
[
i
][
NumDimM
+
NumDimN
-
1
];
ds_nz_consecutive_
[
i
]
=
ds_ms_ns_strides
[
i
][
NumDimM
+
NumDimN
-
1
]
==
1
;
ds_max_read_elems_
[
i
]
=
CalculateMaxRead
<
NumDimM
,
NumDimN
>
(
ds_ms_ns_lengths
[
i
],
ds_ms_ns_strides
[
i
]);
}
}
e_nz_stride_
=
e_ms_ns_strides
[
NumDimM
+
NumDimN
-
1
];
e_nz_consecutive_
=
e_ms_ns_strides
[
NumDimM
+
NumDimN
-
1
]
==
1
;
e_max_write_elems_
=
CalculateMaxRead
<
NumDimM
,
NumDimN
>
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
}
}
void
Print
()
const
void
Print
()
const
...
@@ -491,15 +502,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -491,15 +502,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BElementwiseOperation
b_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
// Strides for the last M/N/K dimensions of A/B/Ds/E
// Describe whether the last part of a given dimension of A/B/D/E is consecutive
// for sanity check of vector load/store
// in the memory or not.
index_t
a_mz_stride_
;
bool
a_mz_consecutive_
;
index_t
a_kz_stride_
;
bool
a_kz_consecutive_
;
index_t
b_nz_stride_
;
bool
b_nz_consecutive_
;
index_t
b_kz_stride_
;
bool
b_kz_consecutive_
;
std
::
array
<
index_t
,
NumDTensor
>
ds_nz_stride_
;
std
::
array
<
bool
,
NumDTensor
>
ds_nz_consecutive_
;
index_t
e_mz_stride_
;
bool
e_nz_consecutive_
;
index_t
e_nz_stride_
;
index_t
a_max_read_elems_
;
index_t
b_max_read_elems_
;
std
::
array
<
index_t
,
NumDTensor
>
ds_max_read_elems_
;
index_t
e_max_write_elems_
;
};
};
// Invoker
// Invoker
...
@@ -582,13 +597,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -582,13 +597,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
if
(
!
ck
::
is_xdl_supported
())
ck
::
get_device_name
()
==
"gfx940"
))
{
{
return
false
;
return
false
;
}
}
if
(
ck
::
get_device_name
()
!=
"gfx90a"
&&
std
::
is_same
<
ADataType
,
double
>::
value
)
if
(
ck
::
get_device_name
()
!=
"gfx90a"
&&
ck
::
get_device_name
()
!=
"gfx940"
&&
ck
::
get_device_name
()
!=
"gfx941"
&&
ck
::
get_device_name
()
!=
"gfx942"
&&
std
::
is_same
<
ADataType
,
double
>::
value
)
{
{
return
false
;
return
false
;
}
}
...
@@ -607,65 +623,47 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -607,65 +623,47 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
(
BBlockTransferSrcVectorDim
==
1
||
BBlockTransferSrcVectorDim
==
2
),
(
BBlockTransferSrcVectorDim
==
1
||
BBlockTransferSrcVectorDim
==
2
),
"wrong!"
);
"wrong!"
);
// vector memory access of A: could be on M or AK1 dimension
const
bool
valid_a_vector_size
=
if
constexpr
(
ABlockTransferSrcVectorDim
==
1
)
arg
.
a_max_read_elems_
%
ABlockTransferSrcScalarPerVector
==
0
;
const
bool
valid_a_access_dim_m
=
ABlockTransferSrcVectorDim
==
1
&&
arg
.
a_mz_consecutive_
;
const
bool
valid_a_access_dim_k
=
ABlockTransferSrcVectorDim
==
2
&&
arg
.
a_kz_consecutive_
;
const
bool
valid_a_access_dim
=
valid_a_access_dim_m
||
valid_a_access_dim_k
;
if
(
!
(
valid_a_vector_size
&&
valid_a_access_dim
))
{
{
if
(
!
(
arg
.
a_mz_stride_
==
1
&&
return
false
;
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
)
%
ABlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
else
{
if
(
!
(
arg
.
a_kz_stride_
==
1
&&
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
)
%
ABlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
}
// vector memory access of B: could be on N or BK1 dimension
const
bool
valid_b_vector_size
=
if
constexpr
(
BBlockTransferSrcVectorDim
==
1
)
arg
.
b_max_read_elems_
%
BBlockTransferSrcScalarPerVector
==
0
;
{
const
bool
valid_b_access_dim_n
=
BBlockTransferSrcVectorDim
==
1
&&
arg
.
b_nz_consecutive_
;
if
(
!
(
arg
.
b_nz_stride_
==
1
&&
const
bool
valid_b_access_dim_k
=
BBlockTransferSrcVectorDim
==
2
&&
arg
.
b_kz_consecutive_
;
arg
.
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
)
%
BBlockTransferSrcScalarPerVector
==
0
))
const
bool
valid_b_access_dim
=
valid_b_access_dim_n
||
valid_b_access_dim_k
;
{
if
(
!
(
valid_b_vector_size
&&
valid_b_access_dim
))
return
false
;
}
}
else
{
{
if
(
!
(
arg
.
b_kz_stride_
==
1
&&
return
false
;
arg
.
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
)
%
BBlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
}
// vector memory access of Ds: always on NPerBlock dimension
bool
valid_ds_access
=
true
;
bool
valid_d_access
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
if
(
!
(
arg
.
ds_nz_stride_
[
i
]
==
1
&&
const
bool
valid_d_vector_size
=
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
[
i
].
GetLength
(
I3
)
%
arg
.
ds_max_read_elems_
[
i
]
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
;
CDEBlockTransferScalarPerVector_NPerBlock
==
// Vector read of Ds is always on N dimension.
0
))
const
bool
valid_d_access_dim
=
arg
.
ds_nz_consecutive_
[
i
];
if
(
!
(
valid_d_vector_size
&&
valid_d_access_dim
))
{
{
valid_d_access
=
false
;
valid_d
s
_access
=
false
;
}
}
});
});
if
(
!
valid_ds_access
)
if
(
valid_d_access
==
false
)
{
{
return
false
;
return
false
;
}
}
// vector memory access of E: always on NPerBlock dimension
const
bool
valid_e_vector_size
=
if
(
!
(
arg
.
e_
nz_stride_
==
1
&&
arg
.
e_
max_write_elems_
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
;
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
.
GetLength
(
I3
)
%
// Vector write of E is always on N dimension.
CDEBlockTransferScalarPerVector_NPerBlock
==
const
bool
valid_e_access_dim
=
arg
.
e_nz_consecutive_
;
0
))
if
(
!
(
valid_e_vector_size
&&
valid_e_access_dim
))
{
{
return
false
;
return
false
;
}
}
...
@@ -683,7 +681,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -683,7 +681,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const
void
*
p_b
,
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
void
*
p_e
,
const
std
::
vector
<
index_t
>&
a_ms_
n
s_lengths
,
const
std
::
vector
<
index_t
>&
a_ms_
k
s_lengths
,
const
std
::
vector
<
index_t
>&
a_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
a_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_ns_ks_strides
,
const
std
::
vector
<
index_t
>&
b_ns_ks_strides
,
...
@@ -699,7 +697,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -699,7 +697,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
p_b
,
p_b
,
p_ds
,
p_ds
,
p_e
,
p_e
,
a_ms_
n
s_lengths
,
a_ms_
k
s_lengths
,
a_ms_ks_strides
,
a_ms_ks_strides
,
b_ns_ks_lengths
,
b_ns_ks_lengths
,
b_ns_ks_strides
,
b_ns_ks_strides
,
...
@@ -720,7 +718,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -720,7 +718,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const
void
*
p_b
,
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
void
*
p_e
,
const
std
::
vector
<
index_t
>&
a_ms_
n
s_lengths
,
const
std
::
vector
<
index_t
>&
a_ms_
k
s_lengths
,
const
std
::
vector
<
index_t
>&
a_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
a_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_ns_ks_strides
,
const
std
::
vector
<
index_t
>&
b_ns_ks_strides
,
...
@@ -736,7 +734,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -736,7 +734,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
p_b
,
p_b
,
p_ds
,
p_ds
,
p_e
,
p_e
,
a_ms_
n
s_lengths
,
a_ms_
k
s_lengths
,
a_ms_ks_strides
,
a_ms_ks_strides
,
b_ns_ks_lengths
,
b_ns_ks_lengths
,
b_ns_ks_strides
,
b_ns_ks_strides
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp
0 → 100644
View file @
2724c519
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cassert>
#include <sstream>
#include <vector>
#include "ck/ck.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
/**
* Calculates the maximum number of subsequent elements of the fast changing dimension
* that are consecutive in memory.
*
* Example:
* NumDimM = 2, NumDimK = 3
* A shape = [ 2, 3, 4, 5, 6]
* A strides = [360, 120, 30, 6, 1]
* | M | | K |
* It follows from strides that K is FCD and all the subsequent elements of K are consecutive
* in memory.
* But if strides were [360, 120, 6, 24, 1], then only 6 subsequent elements of K would be
* consecutive in memory.
*
* Assumes that the dimensions are split into two groups of `NumDim1` and `NumDim2` dimensions.
*/
template
<
index_t
NumDim1
,
index_t
NumDim2
>
auto
CalculateMaxRead
(
const
std
::
vector
<
index_t
>&
lengths
,
const
std
::
vector
<
index_t
>&
strides
)
{
if
(
lengths
.
size
()
!=
NumDim1
+
NumDim2
)
{
std
::
ostringstream
err
;
err
<<
"Incorrect number of lengths in "
<<
"device_contraction_utils.hpp"
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
;
throw
std
::
runtime_error
(
err
.
str
());
}
if
(
strides
.
size
()
!=
NumDim1
+
NumDim2
)
{
std
::
ostringstream
err
;
err
<<
"Incorrect number of strides in "
<<
"device_contraction_utils.hpp"
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
;
throw
std
::
runtime_error
(
err
.
str
());
}
// Determine the beginning and end idx of the group representing the FCD.
index_t
begin_idx
,
end_idx
;
if
(
strides
[
NumDim1
-
1
]
==
1
)
{
begin_idx
=
0
;
end_idx
=
NumDim1
-
1
;
}
else
if
(
strides
[
NumDim1
+
NumDim2
-
1
]
==
1
)
{
begin_idx
=
NumDim1
;
end_idx
=
NumDim1
+
NumDim2
-
1
;
}
else
{
// The dimension consecutive in memory is not the last dimension of any group, so only
// one element can be read/written at once.
return
1
;
}
index_t
consecutive_stride
=
1
;
for
(
index_t
dim_idx
=
end_idx
;
dim_idx
>=
begin_idx
;
--
dim_idx
)
{
if
(
strides
[
dim_idx
]
==
consecutive_stride
)
{
consecutive_stride
*=
lengths
[
dim_idx
];
}
else
{
break
;
}
}
const
index_t
max_subsequent_elems
=
consecutive_stride
;
return
max_subsequent_elems
;
}
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
2724c519
...
@@ -532,11 +532,12 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -532,11 +532,12 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
float
ave_time
=
0
;
float
ave_time
=
0
;
const
auto
Run
=
[
&
](
const
auto
&
kernel
)
{
const
auto
Run
=
[
&
](
const
auto
&
kernel
)
{
hipGetErrorString
(
hipMemset
(
hipGetErrorString
(
hipMemset
Async
(
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
0
,
0
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
.
GetElementSpaceSize
()
*
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
.
GetElementSpaceSize
()
*
sizeof
(
CDataType
)));
sizeof
(
CDataType
),
stream_config
.
stream_id_
));
ave_time
=
ave_time
=
launch_and_time_kernel
(
stream_config
,
launch_and_time_kernel
(
stream_config
,
...
@@ -649,6 +650,11 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -649,6 +650,11 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
// vector load A/B matrix from global memory
// vector load A/B matrix from global memory
if
(
!
(
ABlockTransferSrcVectorDim
==
2
&&
BBlockTransferSrcVectorDim
==
2
&&
if
(
!
(
ABlockTransferSrcVectorDim
==
2
&&
BBlockTransferSrcVectorDim
==
2
&&
arg
.
Conv_K_
%
ABlockTransferSrcScalarPerVector
==
0
&&
arg
.
Conv_K_
%
ABlockTransferSrcScalarPerVector
==
0
&&
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
View file @
2724c519
...
@@ -616,6 +616,11 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -616,6 +616,11 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
if
constexpr
(
ConvBackwardDataSpecialization
==
if
constexpr
(
ConvBackwardDataSpecialization
==
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
{
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
View file @
2724c519
...
@@ -810,6 +810,11 @@ struct
...
@@ -810,6 +810,11 @@ struct
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp
View file @
2724c519
...
@@ -767,6 +767,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
...
@@ -767,6 +767,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
2724c519
...
@@ -741,6 +741,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -741,6 +741,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
View file @
2724c519
...
@@ -524,6 +524,11 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -524,6 +524,11 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
View file @
2724c519
...
@@ -56,7 +56,7 @@ __global__ void
...
@@ -56,7 +56,7 @@ __global__ void
const
Block2CTileMap
block_2_ctile_map
)
const
Block2CTileMap
block_2_ctile_map
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94
0
__))
defined(__gfx94__))
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
num_batches
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
num_batches
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
...
@@ -524,6 +524,11 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
...
@@ -524,6 +524,11 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
c_grid_desc_m_n_
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
View file @
2724c519
...
@@ -1393,9 +1393,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
...
@@ -1393,9 +1393,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
// check device
// check device
if
(
!
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
get_device_name
()
==
"gfx1030"
||
if
(
!
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
is_navi2_supported
()
||
ck
::
get_device_name
()
==
"gfx1100"
||
ck
::
get_device_name
()
==
"gfx1101"
||
ck
::
is_navi3_supported
()))
ck
::
get_device_name
()
==
"gfx1102"
))
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp
View file @
2724c519
...
@@ -1320,6 +1320,11 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
...
@@ -1320,6 +1320,11 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
if
constexpr
(
ConvBackwardDataSpecialization
==
if
constexpr
(
ConvBackwardDataSpecialization
==
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
{
{
...
...
Prev
1
…
17
18
19
20
21
22
23
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment