gaoqiong / composable_kernel

Commit 581d244c authored Jan 27, 2023 by Rosty Geyyer

Add gridwise gemm supporting batched input

parent a768dea5

Showing 3 changed files with 632 additions and 84 deletions (+632 -84)
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp  +14 -14
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp  +74 -70
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp  +544 -0
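The commit title says the grid-wise GEMM now supports batched input. As a reminder of the semantics the batched kernel has to reproduce, here is a plain C++ reference sketch of my own (not CK code; the row-major layouts and names are assumptions for illustration): every batch g is an independent C[g] = A[g] * B[g].

#include <cstdio>
#include <vector>

// Reference batched GEMM: one independent GEMM per batch index g.
void batched_gemm_ref(const std::vector<float>& a, // G x M x K, row-major
                      const std::vector<float>& b, // G x K x N, row-major
                      std::vector<float>& c,       // G x M x N, row-major
                      int G, int M, int N, int K)
{
    for(int g = 0; g < G; ++g)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(int k = 0; k < K; ++k)
                    acc += a[(g * M + m) * K + k] * b[(g * K + k) * N + n];
                c[(g * M + m) * N + n] = acc;
            }
}

int main()
{
    const int G = 2, M = 2, N = 2, K = 2;
    std::vector<float> a(G * M * K, 1.f), b(G * K * N, 1.f), c(G * M * N, 0.f);
    batched_gemm_ref(a, b, c, G, M, N, K);
    std::printf("c[0] = %f (expect %d)\n", c[0], K); // all-ones inputs: every output equals K
    return 0;
}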
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp
@@ -56,20 +56,20 @@ using DeviceConvBwdWeightInstance =
         1,                    // KPerThread
         S<8, 2>,              // M1N1ThreadClusterM1Xs
         S<8, 2>,              // M1N1ThreadClusterN1Xs
-        S<8, 1, 1, 2>,        // ABlockTransferThreadSliceLengths_K0_M0_M1_K1
-        S<2, 1, 128, 1>,      // ABlockTransferThreadClusterLengths_K0_M0_M1_K1
-        S<1, 2, 0, 3>,        // ABlockTransferThreadClusterArrangeOrder
-        S<1, 2, 0, 3>,        // ABlockTransferSrcAccessOrder
-        S<4, 1, 1, 2>,        // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
-        S<1, 2, 0, 3>,        // ABlockTransferSrcVectorTensorContiguousDimOrder
-        S<1, 1, 1, 2>,        // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
-        S<1, 1, 8, 2>,        // BBlockTransferThreadSliceLengths_K0_N0_N1_K1
-        S<16, 1, 16, 1>,      // BBlockTransferThreadClusterLengths_K0_N0_N1_K1
-        S<0, 3, 1, 2>,        // BBlockTransferThreadClusterArrangeOrder
-        S<0, 3, 1, 2>,        // BBlockTransferSrcAccessOrder
-        S<1, 1, 8, 1>,        // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
-        S<0, 3, 1, 2>,        // BBlockTransferSrcVectorTensorContiguousDimOrder
-        S<1, 1, 1, 2>,        // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
+        S<1, 8, 1, 1, 2>,     // ABlockTransferThreadSliceLengths_K0_M0_M1_K1
+        S<1, 2, 1, 128, 1>,   // ABlockTransferThreadClusterLengths_K0_M0_M1_K1
+        S<0, 2, 3, 1, 4>,     // ABlockTransferThreadClusterArrangeOrder
+        S<0, 2, 3, 1, 4>,     // ABlockTransferSrcAccessOrder
+        S<1, 4, 1, 1, 2>,     // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
+        S<0, 2, 3, 1, 4>,     // ABlockTransferSrcVectorTensorContiguousDimOrder
+        S<1, 1, 1, 1, 2>,     // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
+        S<1, 1, 1, 8, 2>,     // BBlockTransferThreadSliceLengths_K0_N0_N1_K1
+        S<1, 16, 1, 16, 1>,   // BBlockTransferThreadClusterLengths_K0_N0_N1_K1
+        S<0, 1, 4, 2, 3>,     // BBlockTransferThreadClusterArrangeOrder
+        S<0, 1, 4, 2, 3>,     // BBlockTransferSrcAccessOrder
+        S<1, 1, 1, 8, 1>,     // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
+        S<0, 1, 4, 2, 3>,     // BBlockTransferSrcVectorTensorContiguousDimOrder
+        S<1, 1, 1, 1, 2>,     // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
         S<0, 1, 2, 3, 4, 5>,  // CThreadTransferSrcDstAccessOrder
         5,                    // CThreadTransferSrcDstVectorDim
         4>;                   // CThreadTransferDstScalarPerVector
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
@@ -50,10 +50,10 @@ struct ComputePtrOffsetOfStridedBatch
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
-          typename AGridDesc_K0_M0_M1_K1,
-          typename BGridDesc_K0_N0_N1_K1,
+          typename AGridDesc_B_K0_M0_M1_K1,
+          typename BGridDesc_B_K0_N0_N1_K1,
           typename CGridDesc_M0_M10_M11_N0_N10_N11,
-          typename DefaultBlock2CTileMap,
+          typename Block2CTileMap,
           typename ComputePtrOffsetOfBatch,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop>
@@ -66,10 +66,10 @@ __global__ void
             const FloatAB* __restrict__ p_b_grid,
             FloatC* __restrict__ p_c_grid,
             const index_t batch_count,
-            const AGridDesc_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1,
-            const BGridDesc_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1,
+            const AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1,
+            const BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1,
             const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
-            const DefaultBlock2CTileMap block_2_ctile_map,
+            const Block2CTileMap block_2_ctile_map,
             const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
     const index_t num_blocks_per_batch =
@@ -85,7 +85,7 @@ __global__ void
     __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)];

-    GridwiseGemm::template Run<HasMainKBlockLoop>(
+    GridwiseGemm::template Run<HasMainKBlockLoop, HasDoubleTailKBlockLoop>(
         p_a_grid + a_batch_offset,
         p_b_grid + b_batch_offset,
         p_c_grid + c_batch_offset,
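The hunks above show the shape of the new batched wrapper: the kernel receives a batch_count plus a ComputePtrOffsetOfBatch helper, derives num_blocks_per_batch, and offsets p_a_grid / p_b_grid / p_c_grid before handing one tile to GridwiseGemm::Run. Below is a host-side toy model of that block-to-batch mapping, a sketch of my own with made-up strides and grid sizes, not the CK implementation:

#include <cstdio>

// Toy model: a 1-D block id is divided by the number of blocks one GEMM
// needs; the resulting batch index is turned into per-tensor pointer
// offsets before the non-batched GEMM body runs on that batch's data.
int main()
{
    const int batch_count         = 4;
    const int grid_size_per_batch = 6;    // blocks needed by one GEMM (assumed)
    const int batch_stride_a      = 1024; // elements between batches of A (assumed)
    const int batch_stride_b      = 2048; // elements between batches of B (assumed)
    const int batch_stride_c      = 512;  // elements between batches of C (assumed)

    for(int block_id = 0; block_id < batch_count * grid_size_per_batch; ++block_id)
    {
        const int g_idx           = block_id / grid_size_per_batch; // batch index
        const long a_batch_offset = static_cast<long>(g_idx) * batch_stride_a;
        const long b_batch_offset = static_cast<long>(g_idx) * batch_stride_b;
        const long c_batch_offset = static_cast<long>(g_idx) * batch_stride_c;

        // A real kernel would now compute one tile of GEMM on
        // p_a_grid + a_batch_offset, p_b_grid + b_batch_offset,
        // p_c_grid + c_batch_offset.
        std::printf("block %2d -> batch %d (A+%ld, B+%ld, C+%ld)\n",
                    block_id, g_idx, a_batch_offset, b_batch_offset, c_batch_offset);
    }
    return 0;
}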
@@ -729,55 +729,55 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
     using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());

-    using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
-    using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
+    using AGridDesc_B_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
+    using BGridDesc_B_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
     using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;

     using GridwiseGemm =
-        GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
+        GridwiseGemmDl_bkm_bkn_mn_v1r3<BlockSize,
                                      ADataType,
                                      AccDataType,
                                      CDataType,
                                      InMemoryDataOperationEnum::Set,
-                                     AGridDesc_K0_M_K1,
-                                     BGridDesc_K0_N_K1,
+                                     AGridDesc_B_K0_M_K1,
+                                     BGridDesc_B_K0_N_K1,
                                      CGridDesc_M_N,
                                      MPerBlock,
                                      NPerBlock,
                                      K0PerBlock,
                                      K1,
                                      M1PerThread,
                                      N1PerThread,
                                      KPerThread,
                                      M1N1ThreadClusterM1Xs,
                                      M1N1ThreadClusterN1Xs,
                                      ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
                                      ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
                                      ABlockTransferThreadClusterArrangeOrder,
                                      ABlockTransferSrcAccessOrder,
                                      ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
                                      ABlockTransferSrcVectorTensorContiguousDimOrder,
                                      ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
                                      BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
                                      BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
                                      BBlockTransferThreadClusterArrangeOrder,
                                      BBlockTransferSrcAccessOrder,
                                      BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
                                      BBlockTransferSrcVectorTensorContiguousDimOrder,
                                      BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
                                      CThreadTransferSrcDstAccessOrder,
                                      CThreadTransferSrcDstVectorDim,
                                      CThreadTransferDstScalarPerVector>;

     // Argument
-    using AGridDesc_K0_M0_M1_K1 =
-        decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
-    using BGridDesc_K0_N0_N1_K1 =
-        decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
+    using AGridDesc_B_K0_M0_M1_K1 =
+        decltype(GridwiseGemm::MakeAGridDescriptor_B_K0_M0_M1_K1(AGridDesc_B_K0_M_K1{}));
+    using BGridDesc_B_K0_N0_N1_K1 =
+        decltype(GridwiseGemm::MakeBGridDescriptor_B_K0_N0_N1_K1(BGridDesc_B_K0_N_K1{}));
     using CGridDesc_M0_M10_M11_N0_N10_N11 =
         decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
-    using DefaultBlock2CTileMap =
-        decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
+    using Block2CTileMap =
+        decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));

     struct Argument : public BaseArgument
     {
@@ -842,12 +842,15 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
             c_grid_desc_m_n_ = descs[I2];

             a_grid_desc_kbatch_k0_m0_m1_k1_ =
-                GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1_);
+                GridwiseGemm::MakeAGridDescriptor_B_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1_);
             b_grid_desc_kbatch_k0_n0_n1_k1_ =
-                GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_kbatch_k0_n_k1_);
+                GridwiseGemm::MakeBGridDescriptor_B_K0_N0_N1_K1(b_grid_desc_kbatch_k0_n_k1_);
             c_grid_desc_m0_m10_m11_n0_n10_n11_ =
                 GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_);

-            block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
+            ck::index_t M01 = 1;
+            ck::index_t N01 = 1;
+            block_2_ctile_map_ =
+                GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);

             // A/B/C Batch Stride
             compute_ptr_offset_of_batch_.BatchStrideA_ =
@@ -874,15 +877,16 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
         const BDataType* p_b_grid_;
         CDataType* p_c_grid_;

-        AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
-        BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
+        AGridDesc_B_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
+        BGridDesc_B_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
         CGridDesc_M_N c_grid_desc_m_n_;

-        AGridDesc_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1_;
-        BGridDesc_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1_;
+        AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1_;
+        BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1_;
         CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_;

-        DefaultBlock2CTileMap block_2_ctile_map_;
+        // DefaultBlock2CTileMap block_2_ctile_map_;
+        Block2CTileMap block_2_ctile_map_;

         // for computing batch offset
         ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
@@ -941,7 +945,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
                                            arg.c_grid_desc_m_n_))
             {
                 throw std::runtime_error(
-                    "wrong! GridwiseGemm GridwiseGemmDl_km_kn_mn_v1r3 has invalid setting");
+                    "wrong! GridwiseGemm GridwiseGemmDl_bkm_bkn_mn_v1r3 has invalid setting");
             }

             const index_t grid_size =
@@ -950,16 +954,16 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
             auto launch_kernel = [&](auto has_main_k_block_loop,
                                      auto has_double_tail_k_block_loop) {
                 constexpr bool has_main_loop   = has_main_k_block_loop.value;
-                constexpr bool has_double_loop = has_double_tail_k_block_loop;
+                constexpr bool has_double_loop = has_double_tail_k_block_loop.value;

                 const auto kernel = kernel_batched_gemm_dlops_bwd_weight<
                     GridwiseGemm,
                     ADataType, // TODO: distiguish A/B datatype
                     CDataType,
-                    remove_reference_t<DeviceOp::AGridDesc_K0_M0_M1_K1>,
-                    remove_reference_t<DeviceOp::BGridDesc_K0_N0_N1_K1>,
+                    remove_reference_t<DeviceOp::AGridDesc_B_K0_M0_M1_K1>,
+                    remove_reference_t<DeviceOp::BGridDesc_B_K0_N0_N1_K1>,
                     remove_reference_t<DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11>,
-                    remove_reference_t<DeviceOp::DefaultBlock2CTileMap>,
+                    remove_reference_t<DeviceOp::Block2CTileMap>,
                     ComputePtrOffsetOfStridedBatch,
                     has_main_loop,
                     has_double_loop>;
@@ -1045,18 +1049,18 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
         // matrix A
         {
             auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{};
-            if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1)
+            if(srcVectorLengths[I2] != 1 || srcVectorLengths[I3] != 1)
             {
                 return false;
             }
-            if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0)
+            if(K1 % srcVectorLengths[I4] != 0 || K0PerBlock % srcVectorLengths[I1] != 0)
             {
                 return false;
             }

             const index_t K = arg.Conv_K_;

-            if(K % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0)
+            if(K % (srcVectorLengths[I1] * srcVectorLengths[I4]) != 0)
             {
                 return false;
             }
@@ -1066,19 +1070,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
         {
             auto srcLoadLenghts   = BBlockTransferThreadSliceLengths_K0_N0_N1_K1{};
             auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{};
-            if(srcVectorLengths[I0] != 1 || srcVectorLengths[I3] != 1)
+            if(srcVectorLengths[I1] != 1 || srcVectorLengths[I4] != 1)
             {
                 return false;
             }
-            if(srcLoadLenghts[I1] % srcVectorLengths[I1] != 0 ||
-               srcLoadLenghts[I2] % srcVectorLengths[I2] != 0)
+            if(srcLoadLenghts[I2] % srcVectorLengths[I2] != 0 ||
+               srcLoadLenghts[I3] % srcVectorLengths[I3] != 0)
             {
                 return false;
             }

             const index_t C = arg.Conv_K_;

-            if(C % (srcVectorLengths[I1] * srcVectorLengths[I2]) != 0)
+            if(C % (srcVectorLengths[I2] * srcVectorLengths[I3]) != 0)
             {
                 return false;
             }
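The index shifts in the checks above (I0..I3 becoming I1..I4) line up with the extra leading entry added to every S<...> tuple in the example file: each transfer-length tuple now carries what is presumably the new batch dimension in slot 0, pushing the K0/M0/M1/K1 (or K0/N0/N1/K1) components one slot to the right. A tiny standalone illustration with my own toy arrays, not CK types:

#include <array>

// Example values taken from the diff: ABlockTransferSrcVectorTensorLengths
// changes from S<4, 1, 1, 2> (K0, M0, M1, K1) to S<1, 4, 1, 1, 2>
// (leading batch slot assumed), so each component moves one index later.
int main()
{
    constexpr std::array<int, 4> old_lengths{4, 1, 1, 2};    // K0, M0, M1, K1
    constexpr std::array<int, 5> new_lengths{1, 4, 1, 1, 2}; // B, K0, M0, M1, K1

    static_assert(old_lengths[0] == new_lengths[1], "K0 moves from I0 to I1");
    static_assert(old_lengths[3] == new_lengths[4], "K1 moves from I3 to I4");
    return 0;
}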
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp

This diff is collapsed.