Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3f976dd0
"...resnet50_tensorflow.git" did not exist on "a9d5da287f2d8ad25ab19aa1674f89b39d5a119d"
Commit
3f976dd0
authored
Jan 10, 2023
by
Rosty Geyyer
Browse files
Update batch handling
parent
b9f23971
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
233 additions
and
129 deletions
+233
-129
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp
...ouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp
+36
-37
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
...impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
+197
-92
No files found.
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp
View file @
3f976dd0
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp"
#include <algorithm>
#include <iostream>
#include <iostream>
#include <numeric>
#include <iterator>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
...
@@ -37,7 +36,7 @@ static constexpr auto ConvBwdWeightDefault =
...
@@ -37,7 +36,7 @@ static constexpr auto ConvBwdWeightDefault =
template
<
ck
::
index_t
NDimSpatial
>
template
<
ck
::
index_t
NDimSpatial
>
using
DeviceConvBwdWeightInstance
=
using
DeviceConvBwdWeightInstance
=
ck
::
tensor_operation
::
device
::
DeviceConv
Nd
BwdWeight
NwcKxcN
wk_Dl
<
ck
::
tensor_operation
::
device
::
Device
Grouped
ConvBwdWeight
GnwcGkxcGn
wk_Dl
<
NDimSpatial
,
// NDimSpatial
NDimSpatial
,
// NDimSpatial
InDataType
,
// InDataType
InDataType
,
// InDataType
WeiDataType
,
// WeiDataType
WeiDataType
,
// WeiDataType
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
View file @
3f976dd0
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
#pragma once
#pragma once
#include <iostream>
#include <iostream>
#include <numeric>
#include <sstream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/utility/common_header.hpp"
...
@@ -20,6 +21,83 @@ namespace ck {
...
@@ -20,6 +21,83 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
{
struct
ComputePtrOffsetOfStridedBatch
{
__host__
__device__
constexpr
long_index_t
GetAPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideA_
);
}
__host__
__device__
constexpr
long_index_t
GetBPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
);
}
__host__
__device__
constexpr
long_index_t
GetCPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideC_
);
}
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
index_t
BatchStrideC_
;
};
}
// namespace
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
typename
AGridDesc_K0_M0_M1_K1
,
typename
BGridDesc_K0_N0_N1_K1
,
typename
CGridDesc_M0_M10_M11_N0_N10_N11
,
typename
DefaultBlock2CTileMap
,
typename
ComputePtrOffsetOfBatch
,
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_batched_gemm_dlops_bwd_weight
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
index_t
batch_count
,
const
AGridDesc_K0_M0_M1_K1
a_grid_desc_kbatch_k0_m0_m1_k1
,
const
BGridDesc_K0_N0_N1_K1
b_grid_desc_kbatch_k0_n0_n1_k1
,
const
CGridDesc_M0_M10_M11_N0_N10_N11
c_grid_desc_m0_m10_m11_n0_n10_n11
,
const
DefaultBlock2CTileMap
block_2_ctile_map
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
)
{
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetBPtrOffset
(
g_idx
)));
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetCPtrOffset
(
g_idx
)));
__shared__
FloatAB
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
)];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_shared
,
a_grid_desc_kbatch_k0_m0_m1_k1
,
b_grid_desc_kbatch_k0_n0_n1_k1
,
c_grid_desc_m0_m10_m11_n0_n10_n11
,
block_2_ctile_map
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
}
template
<
ck
::
index_t
NDimSpatial
,
template
<
ck
::
index_t
NDimSpatial
,
typename
InDataType
,
typename
InDataType
,
typename
WeiDataType
,
typename
WeiDataType
,
...
@@ -56,7 +134,7 @@ template <ck::index_t NDimSpatial,
...
@@ -56,7 +134,7 @@ template <ck::index_t NDimSpatial,
typename
CThreadTransferSrcDstAccessOrder
,
typename
CThreadTransferSrcDstAccessOrder
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferDstScalarPerVector
>
index_t
CThreadTransferDstScalarPerVector
>
struct
DeviceConv
Nd
BwdWeight
NwcKxcN
wk_Dl
struct
Device
Grouped
ConvBwdWeight
GnwcGkxcGn
wk_Dl
:
public
DeviceGroupedConvBwdWeight
<
:
public
DeviceGroupedConvBwdWeight
<
NDimSpatial
,
NDimSpatial
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
...
@@ -78,7 +156,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
...
@@ -78,7 +156,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
WeiElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
OutElementwiseOperation
>
{
{
using
DeviceOp
=
DeviceConv
Nd
BwdWeight
NwcKxcN
wk_Dl
;
using
DeviceOp
=
Device
Grouped
ConvBwdWeight
GnwcGkxcGn
wk_Dl
;
using
ADataType
=
OutDataType
;
using
ADataType
=
OutDataType
;
using
BDataType
=
InDataType
;
using
BDataType
=
InDataType
;
...
@@ -116,8 +194,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
...
@@ -116,8 +194,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
static
constexpr
auto
BBlockLdsN1Padding
=
4
;
static
constexpr
auto
BBlockLdsN1Padding
=
4
;
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
static
auto
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
...
@@ -268,8 +346,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
...
@@ -268,8 +346,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
}
// function end
}
// function end
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
auto
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
...
@@ -436,8 +514,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
...
@@ -436,8 +514,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
}
// function end
}
// function end
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
...
@@ -727,6 +805,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
...
@@ -727,6 +805,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
a_grid_desc_kbatch_k0_m_k1_
{},
a_grid_desc_kbatch_k0_m_k1_
{},
b_grid_desc_kbatch_k0_n_k1_
{},
b_grid_desc_kbatch_k0_n_k1_
{},
c_grid_desc_m_n_
{},
c_grid_desc_m_n_
{},
block_2_ctile_map_
{},
compute_ptr_offset_of_batch_
{},
a_element_op_
{
out_element_op
},
a_element_op_
{
out_element_op
},
b_element_op_
{
wei_element_op
},
b_element_op_
{
wei_element_op
},
c_element_op_
{
in_element_op
},
c_element_op_
{
in_element_op
},
...
@@ -761,10 +841,33 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
...
@@ -761,10 +841,33 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
b_grid_desc_kbatch_k0_n_k1_
=
descs
[
I1
];
b_grid_desc_kbatch_k0_n_k1_
=
descs
[
I1
];
c_grid_desc_m_n_
=
descs
[
I2
];
c_grid_desc_m_n_
=
descs
[
I2
];
a_grid_desc_kbatch_k0_m0_m1_k1_
=
GridwiseGemm
::
MakeAGridDescriptor_K0_M0_M1_K1
(
a_grid_desc_kbatch_k0_m_k1_
);
a_grid_desc_kbatch_k0_m0_m1_k1_
=
b_grid_desc_kbatch_k0_n0_n1_k1_
=
GridwiseGemm
::
MakeBGridDescriptor_K0_N0_N1_K1
(
b_grid_desc_kbatch_k0_n_k1_
);
GridwiseGemm
::
MakeAGridDescriptor_K0_M0_M1_K1
(
a_grid_desc_kbatch_k0_m_k1_
);
c_grid_desc_m0_m10_m11_n0_n10_n11_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11
(
c_grid_desc_m_n_
);
b_grid_desc_kbatch_k0_n0_n1_k1_
=
GridwiseGemm
::
MakeBGridDescriptor_K0_N0_N1_K1
(
b_grid_desc_kbatch_k0_n_k1_
);
c_grid_desc_m0_m10_m11_n0_n10_n11_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11
(
c_grid_desc_m_n_
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
);
// A/B/C Batch Stride
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
N
*
K
*
std
::
accumulate
(
begin
(
output_spatial_lengths
),
end
(
output_spatial_lengths
),
index_t
{
1
},
std
::
multiplies
<>
{});
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
N
*
C
*
std
::
accumulate
(
begin
(
input_spatial_lengths
),
end
(
input_spatial_lengths
),
index_t
{
1
},
std
::
multiplies
<>
{});
compute_ptr_offset_of_batch_
.
BatchStrideC_
=
K
*
C
*
std
::
accumulate
(
begin
(
filter_spatial_lengths
),
end
(
filter_spatial_lengths
),
index_t
{
1
},
std
::
multiplies
<>
{});
}
}
const
ADataType
*
p_a_grid_
;
const
ADataType
*
p_a_grid_
;
...
@@ -781,6 +884,9 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
...
@@ -781,6 +884,9 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
DefaultBlock2CTileMap
block_2_ctile_map_
;
DefaultBlock2CTileMap
block_2_ctile_map_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
compute_ptr_offset_of_batch_
;
// element-wise op
// element-wise op
OutElementwiseOperation
a_element_op_
;
OutElementwiseOperation
a_element_op_
;
WeiElementwiseOperation
b_element_op_
;
WeiElementwiseOperation
b_element_op_
;
...
@@ -813,20 +919,16 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
...
@@ -813,20 +919,16 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I2
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I2
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I3
)
<<
"}"
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
<<
std
::
endl
;
std
::
cout
<<
"arg.b_grid_desc_kbatch_k0_n_k1_{"
std
::
cout
<<
"arg.b_grid_desc_kbatch_k0_n_k1_{"
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I2
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I2
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I3
)
<<
"}"
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
}
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
...
@@ -850,7 +952,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
...
@@ -850,7 +952,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
constexpr
bool
has_double_loop
=
has_double_tail_k_block_loop
;
constexpr
bool
has_double_loop
=
has_double_tail_k_block_loop
;
const
auto
kernel
=
kernel_
gemm_dl_v1r3
<
const
auto
kernel
=
kernel_
batched_gemm_dlops_bwd_weight
<
GridwiseGemm
,
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
CDataType
,
...
@@ -858,6 +960,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
...
@@ -858,6 +960,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N0_N1_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N0_N1_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_M0_M10_M11_N0_N10_N11
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_M0_M10_M11_N0_N10_N11
>
,
remove_reference_t
<
DeviceOp
::
DefaultBlock2CTileMap
>
,
remove_reference_t
<
DeviceOp
::
DefaultBlock2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
,
has_main_loop
,
has_main_loop
,
has_double_loop
>
;
has_double_loop
>
;
...
@@ -869,10 +972,12 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
...
@@ -869,10 +972,12 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
Conv_G_
,
arg
.
a_grid_desc_kbatch_k0_m0_m1_k1_
,
arg
.
a_grid_desc_kbatch_k0_m0_m1_k1_
,
arg
.
b_grid_desc_kbatch_k0_n0_n1_k1_
,
arg
.
b_grid_desc_kbatch_k0_n0_n1_k1_
,
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_
,
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_
,
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
,
arg
.
compute_ptr_offset_of_batch_
);
};
};
const
auto
K0
=
arg
.
a_grid_desc_kbatch_k0_m0_m1_k1_
.
GetLength
(
I1
);
const
auto
K0
=
arg
.
a_grid_desc_kbatch_k0_m0_m1_k1_
.
GetLength
(
I1
);
...
@@ -882,7 +987,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
...
@@ -882,7 +987,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
{
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
}
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
{
...
@@ -987,9 +1093,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
...
@@ -987,9 +1093,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
}
}
// Gridwise GEMM size
// Gridwise GEMM size
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
return
GridwiseGemm
::
CheckValidity
(
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
);
arg
.
c_grid_desc_m_n_
);
}
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
...
@@ -1088,7 +1193,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
...
@@ -1088,7 +1193,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
auto
str
=
std
::
stringstream
();
auto
str
=
std
::
stringstream
();
// clang-format off
// clang-format off
str
<<
"DeviceConv
Nd
BwdWeight
NwcKxcN
wk_Dl"
str
<<
"Device
Grouped
ConvBwdWeight
GnwcGkxcGn
wk_Dl"
<<
"<"
<<
"<"
<<
BlockSize
<<
", "
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
MPerBlock
<<
", "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment