Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
1a0f2c35
Commit
1a0f2c35
authored
Oct 12, 2023
by
Bartlomiej Kocot
Committed by
Bartłomiej Kocot
Oct 12, 2023
Browse files
Add grouped conv bwd weight wmma
parent
f3b02ecf
Changes
67
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1379 additions
and
130 deletions
+1379
-130
example/20_grouped_conv_bwd_weight/CMakeLists.txt
example/20_grouped_conv_bwd_weight/CMakeLists.txt
+11
-2
example/20_grouped_conv_bwd_weight/common.hpp
example/20_grouped_conv_bwd_weight/common.hpp
+18
-22
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp
...ouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp
+19
-1
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_wmma_fp16.cpp
...ped_conv_bwd_weight/grouped_conv_bwd_weight_wmma_fp16.cpp
+87
-0
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp
...uped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp
+19
-1
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp
...uped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp
+19
-1
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp
..._weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp
+19
-1
example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc
...d_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc
+1
-21
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
...device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
...ion/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
+1
-26
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
...ice/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
+884
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
...vice/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
+7
-28
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
...impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
...r_operation/gpu/device/impl/device_grouped_conv_utils.hpp
+8
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
...ation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
+27
-14
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp
...bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp
+0
-2
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp
...d_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp
+122
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
...tion_instance/gpu/grouped_convolution_backward_weight.hpp
+126
-0
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt
...ion_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt
+9
-9
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/dl/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_instance.cpp
...ped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_instance.cpp
+0
-0
No files found.
example/20_grouped_conv_bwd_weight/CMakeLists.txt
View file @
1a0f2c35
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
list
(
APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942
)
list
(
APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
if
(
gpu IN_LIST gpu_list
_xdl
AND target EQUAL 0
)
add_custom_target
(
example_grouped_conv_bwd_weight
)
add_example_executable
(
example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp
)
if
(
result EQUAL 0
)
...
...
@@ -18,6 +19,14 @@ foreach(gpu IN LISTS GPU_TARGETS)
endif
()
endif
()
set
(
target 1
)
endif
()
if
(
gpu IN_LIST gpu_list_wmma AND target EQUAL 0
)
add_custom_target
(
example_grouped_conv_bwd_weight
)
add_example_executable
(
example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16
)
endif
()
set
(
target 1
)
endif
()
endforeach
()
...
...
example/20_grouped_conv_bwd_weight/common.hpp
View file @
1a0f2c35
...
...
@@ -46,25 +46,21 @@ struct CommonLayoutSetting
using
OutputLayout
=
OutputLay
;
};
template
<
ck
::
index_t
NDimSpatial
>
struct
CommonLayoutSettingSelector
;
namespace
ctl
=
ck
::
tensor_layout
::
convolution
;
template
<
>
struct
CommonLayoutSettingSelector
<
1
>
final
:
CommonLayoutSetting
<
ctl
::
GNWC
,
ctl
::
GKXC
,
ctl
::
GNWK
>
{
};
template
<
>
struct
CommonLayoutSettingSelector
<
2
>
final
:
CommonLayoutSetting
<
ctl
::
GNHWC
,
ctl
::
GKYXC
,
ctl
::
GNHWK
>
{
};
template
<
>
struct
CommonLayoutSettingSelector
<
3
>
final
:
CommonLayoutSetting
<
ctl
::
GNDHWC
,
ctl
::
GKZYXC
,
ctl
::
GNDHWK
>
template
<
ck
::
index_t
NDimSpatial
>
struct
CommonLayoutSettingSelector
:
CommonLayoutSetting
<
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
GNWC
,
ck
::
tensor_layout
::
convolution
::
GNHWC
,
ck
::
tensor_layout
::
convolution
::
GNDHWC
>>
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
GKXC
,
ck
::
tensor_layout
::
convolution
::
GKYXC
,
ck
::
tensor_layout
::
convolution
::
GKZYXC
>>
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
GNWK
,
ck
::
tensor_layout
::
convolution
::
GNHWK
,
ck
::
tensor_layout
::
convolution
::
GNDHWK
>>>
{
};
...
...
@@ -84,10 +80,10 @@ struct ExecutionConfig final
bool
time_kernel
=
false
;
};
#define DefaultConvParam \
ck::utils::conv::ConvParam \
{ \
2
, 4, 1, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, { 1, 1 } \
#define DefaultConvParam
\
ck::utils::conv::ConvParam
\
{
\
3
, 4, 1, 128, 256, {3,
3,
3}, {14,
14,
14}, {1,
1,
1}, {1,
1,
1}, {1,
1,
1}, { 1,
1,
1 } \
}
inline
void
print_help_msg
()
...
...
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp
View file @
1a0f2c35
...
...
@@ -76,4 +76,22 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
#include "run_grouped_conv_bwd_weight_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_bwd_weight_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
ExecutionConfig
config
;
ck
::
utils
::
conv
::
ConvParam
conv_param
=
DefaultConvParam
;
if
(
!
parse_cmd_args
(
argc
,
argv
,
config
,
conv_param
))
{
return
1
;
}
switch
(
conv_param
.
num_dim_spatial_
)
{
case
1
:
return
!
run_grouped_conv_bwd_weight
<
1
>
(
config
,
conv_param
);
case
2
:
return
!
run_grouped_conv_bwd_weight
<
2
>
(
config
,
conv_param
);
case
3
:
return
!
run_grouped_conv_bwd_weight
<
3
>
(
config
,
conv_param
);
}
return
1
;
}
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_wmma_fp16.cpp
0 → 100644
View file @
1a0f2c35
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
using
InDataType
=
F16
;
using
WeiDataType
=
F16
;
using
OutDataType
=
F16
;
using
AccDataType
=
F32
;
using
InElementOp
=
PassThrough
;
using
WeiElementOp
=
PassThrough
;
using
OutElementOp
=
PassThrough
;
template
<
ck
::
index_t
NDimSpatial
>
using
DeviceConvBwdWeightInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvBwdWeight_Wmma_CShuffle
<
NDimSpatial
,
ck
::
tensor_layout
::
convolution
::
GNDHWC
,
ck
::
tensor_layout
::
convolution
::
GKZYXC
,
ck
::
tensor_layout
::
convolution
::
GNDHWK
,
InDataType
,
// InDataType
WeiDataType
,
// WeiDataType
OutDataType
,
// OutDataType
AccDataType
,
// AccDataType
InElementOp
,
// InElementwiseOperation
WeiElementOp
,
// WeiElementwiseOperation
OutElementOp
,
// OutElementwiseOperation
ConvBwdWeightDefault
,
// ConvolutionBackwardWeightSpecialization
256
,
// BlockSize
128
,
// MPerBlock
128
,
// NPerBlock
4
,
// K0PerBlock
8
,
// K1
16
,
// MPerWMMA
16
,
// NPerWMMA
4
,
// MRepeat
2
,
// NRepeat
S
<
4
,
64
,
1
>
,
// ABlockTransferThreadClusterLengths_AK0_M_AK1
S
<
0
,
2
,
1
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
0
,
2
,
1
>
,
// ABlockTransferSrcAccessOrder
1
,
// ABlockTransferSrcVectorDim
1
,
// ABlockTransferSrcScalarPerVector
8
,
// ABlockTransferDstScalarPerVector_AK1
true
,
// ABlockLdsExtraM
S
<
4
,
64
,
1
>
,
// BBlockTransferThreadClusterLengths_BK0_N_BK1
S
<
0
,
2
,
1
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
0
,
2
,
1
>
,
// BBlockTransferSrcAccessOrder
1
,
// BBlockTransferSrcVectorDim
1
,
// BBlockTransferSrcScalarPerVector
8
,
// BBlockTransferDstScalarPerVector_BK1
true
,
// BBlockLdsExtraN
4
,
2
,
S
<
1
,
32
,
1
,
8
>
,
1
>
;
template
<
ck
::
index_t
NDimSpatial
>
using
HostConvBwdWeightInstance
=
ck
::
tensor_operation
::
host
::
ReferenceConvBwdWeight
<
NDimSpatial
,
InDataType
,
WeiDataType
,
OutDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
>
;
#include "run_grouped_conv_bwd_weight_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
ExecutionConfig
config
;
ck
::
utils
::
conv
::
ConvParam
conv_param
=
DefaultConvParam
;
if
(
!
parse_cmd_args
(
argc
,
argv
,
config
,
conv_param
))
{
return
1
;
}
switch
(
conv_param
.
num_dim_spatial_
)
{
case
3
:
return
!
run_grouped_conv_bwd_weight
<
3
>
(
config
,
conv_param
);
}
return
1
;
}
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp
View file @
1a0f2c35
...
...
@@ -78,4 +78,22 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
#include "run_grouped_conv_bwd_weight_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_bwd_weight_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
ExecutionConfig
config
;
ck
::
utils
::
conv
::
ConvParam
conv_param
=
DefaultConvParam
;
if
(
!
parse_cmd_args
(
argc
,
argv
,
config
,
conv_param
))
{
return
1
;
}
switch
(
conv_param
.
num_dim_spatial_
)
{
case
1
:
return
!
run_grouped_conv_bwd_weight
<
1
>
(
config
,
conv_param
);
case
2
:
return
!
run_grouped_conv_bwd_weight
<
2
>
(
config
,
conv_param
);
case
3
:
return
!
run_grouped_conv_bwd_weight
<
3
>
(
config
,
conv_param
);
}
return
1
;
}
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp
View file @
1a0f2c35
...
...
@@ -77,4 +77,22 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
#include "run_grouped_conv_bwd_weight_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_bwd_weight_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
ExecutionConfig
config
;
ck
::
utils
::
conv
::
ConvParam
conv_param
=
DefaultConvParam
;
if
(
!
parse_cmd_args
(
argc
,
argv
,
config
,
conv_param
))
{
return
1
;
}
switch
(
conv_param
.
num_dim_spatial_
)
{
case
1
:
return
!
run_grouped_conv_bwd_weight
<
1
>
(
config
,
conv_param
);
case
2
:
return
!
run_grouped_conv_bwd_weight
<
2
>
(
config
,
conv_param
);
case
3
:
return
!
run_grouped_conv_bwd_weight
<
3
>
(
config
,
conv_param
);
}
return
1
;
}
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp
View file @
1a0f2c35
...
...
@@ -83,4 +83,22 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
#include "run_grouped_conv_bwd_weight_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_bwd_weight_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
ExecutionConfig
config
;
ck
::
utils
::
conv
::
ConvParam
conv_param
=
DefaultConvParam
;
if
(
!
parse_cmd_args
(
argc
,
argv
,
config
,
conv_param
))
{
return
1
;
}
switch
(
conv_param
.
num_dim_spatial_
)
{
case
1
:
return
!
run_grouped_conv_bwd_weight
<
1
>
(
config
,
conv_param
);
case
2
:
return
!
run_grouped_conv_bwd_weight
<
2
>
(
config
,
conv_param
);
case
3
:
return
!
run_grouped_conv_bwd_weight
<
3
>
(
config
,
conv_param
);
}
return
1
;
}
example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc
View file @
1a0f2c35
...
...
@@ -5,7 +5,7 @@ template <ck::index_t NDimSpatial>
bool
run_grouped_conv_bwd_weight
(
const
ExecutionConfig
&
config
,
const
ck
::
utils
::
conv
::
ConvParam
&
conv_param
)
{
// Dl op doesn't support split_k > 1
// Dl
and WMMA
op
s
doesn't support split_k > 1
constexpr
ck
::
index_t
split_k
=
1
;
const
auto
in_g_n_c_wis_desc
=
...
...
@@ -143,23 +143,3 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
return
true
;
}
bool
run_grouped_conv_bwd_weight_example
(
int
argc
,
char
*
argv
[])
{
ExecutionConfig
config
;
ck
::
utils
::
conv
::
ConvParam
conv_param
=
DefaultConvParam
;
if
(
!
parse_cmd_args
(
argc
,
argv
,
config
,
conv_param
))
{
return
false
;
}
switch
(
conv_param
.
num_dim_spatial_
)
{
case
1
:
return
run_grouped_conv_bwd_weight
<
1
>
(
config
,
conv_param
);
case
2
:
return
run_grouped_conv_bwd_weight
<
2
>
(
config
,
conv_param
);
case
3
:
return
run_grouped_conv_bwd_weight
<
3
>
(
config
,
conv_param
);
}
return
false
;
}
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
View file @
1a0f2c35
...
...
@@ -565,7 +565,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_grouped_conv_
fwd_
multiple_d_wmma_cshuffle
<
const
auto
kernel
=
kernel_grouped_conv_multiple_d_wmma_cshuffle
<
GridwiseGemm
,
ADataType
,
BDataType
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
View file @
1a0f2c35
...
...
@@ -12,6 +12,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
...
...
@@ -22,32 +23,6 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
namespace
{
struct
ComputePtrOffsetOfStridedBatch
{
__host__
__device__
constexpr
long_index_t
GetAPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideA_
);
}
__host__
__device__
constexpr
long_index_t
GetBPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
);
}
__host__
__device__
constexpr
long_index_t
GetCPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideC_
);
}
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
index_t
BatchStrideC_
;
};
}
// namespace
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
0 → 100644
View file @
1a0f2c35
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
View file @
1a0f2c35
...
...
@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
@@ -21,32 +22,6 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
namespace
{
struct
ComputePtrOffsetOfStridedBatch
{
__host__
__device__
constexpr
long_index_t
GetAPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideA_
);
}
__host__
__device__
constexpr
long_index_t
GetBPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
);
}
__host__
__device__
constexpr
long_index_t
GetCPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideC_
);
}
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
index_t
BatchStrideC_
;
};
}
// namespace
template
<
typename
GridwiseGemm
,
typename
FloatA
,
typename
FloatB
,
...
...
@@ -1222,7 +1197,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
Block2CTileMap
block_2_ctile_map_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<
I0
>
compute_ptr_offset_of_batch_
;
index_t
M01_
;
index_t
N01_
;
...
...
@@ -1301,7 +1276,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
,
ComputePtrOffsetOfStridedBatch
<
I0
>
,
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
...
...
@@ -1348,6 +1323,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
if
constexpr
(
NDimSpatial
==
1
)
{
if
constexpr
(
!
is_GNWK_GKXC_GNWC
)
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
View file @
1a0f2c35
...
...
@@ -471,7 +471,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_grouped_conv_
fwd_
multiple_d_wmma_cshuffle
<
const
auto
kernel
=
kernel_grouped_conv_multiple_d_wmma_cshuffle
<
GridwiseOp
,
ADataType
,
BDataType
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
View file @
1a0f2c35
...
...
@@ -48,10 +48,18 @@ struct ComputePtrOffsetOfStridedBatch
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
}
// alias for kernels without multiple D
template
<
typename
enable_if
<
NumDTensor
==
0
,
bool
>
::
type
=
false
>
__host__
__device__
constexpr
long_index_t
GetCPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
}
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
Array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs_
;
index_t
BatchStrideE_
;
index_t
&
BatchStrideC_
=
BatchStrideE_
;
// alias for kernels without multiple D
};
}
// namespace device
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
View file @
1a0f2c35
...
...
@@ -36,7 +36,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_grouped_conv_
fwd_
multiple_d_wmma_cshuffle
(
kernel_grouped_conv_multiple_d_wmma_cshuffle
(
const
ADataType
*
__restrict__
p_a_grid
,
const
BDataType
*
__restrict__
p_b_grid
,
DsPointer
p_ds_grid
,
...
...
@@ -452,11 +452,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// CheckValidity for kernels without multi D
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
DsGridDesc_M_N
&
ds_grid_desc_m_n
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
Block2CTileMap
&
block_2_ctile_map
)
{
...
...
@@ -471,18 +471,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
bool
valid
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
valid
=
valid
&&
(
M
==
ds_grid_desc_m_n
[
i
].
GetLength
(
I0
)
&&
N
==
ds_grid_desc_m_n
[
i
].
GetLength
(
I1
));
});
if
(
!
valid
)
{
return
false
;
}
if
(
!
(
M
==
e_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
e_grid_desc_m_n
.
GetLength
(
I1
)
&&
K0
==
b_grid_desc_k0_n_k1
.
GetLength
(
I0
)
&&
K1
==
a_grid_desc_k0_m_k1
.
GetLength
(
I2
)
&&
K1
==
b_grid_desc_k0_n_k1
.
GetLength
(
I2
)))
...
...
@@ -517,6 +505,31 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
return
true
;
}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
DsGridDesc_M_N
&
ds_grid_desc_m_n
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
Block2CTileMap
&
block_2_ctile_map
)
{
const
auto
M
=
a_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
bool
valid
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
valid
=
valid
&&
(
M
==
ds_grid_desc_m_n
[
i
].
GetLength
(
I0
)
&&
N
==
ds_grid_desc_m_n
[
i
].
GetLength
(
I1
));
});
if
(
!
valid
)
{
return
false
;
}
return
CheckValidity
(
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
e_grid_desc_m_n
,
block_2_ctile_map
);
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
const
index_t
num_loop
=
K
/
(
K0PerBlock
*
K1
);
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp
View file @
1a0f2c35
...
...
@@ -6,8 +6,6 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp
0 → 100644
View file @
1a0f2c35
This diff is collapsed.
Click to expand it.
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
View file @
1a0f2c35
...
...
@@ -163,6 +163,30 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_FP32
void
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances
(
...
...
@@ -177,6 +201,31 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
void
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
...
...
@@ -202,6 +251,30 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_FP32
void
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances
(
...
...
@@ -231,6 +304,31 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_
BF8
,
F8
>>>&
instances
);
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef DL_KERNELS
// dl
...
...
@@ -694,6 +792,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances
(
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances
(
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances
(
op_ptrs
);
}
#endif
#ifdef CK_ENABLE_BF16
...
...
@@ -708,6 +810,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances
(
op_ptrs
);
}
#endif
#ifdef CK_ENABLE_INT8
else
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
{
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances
(
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances
(
op_ptrs
);
}
#endif
}
else
if
constexpr
(
is_same_v
<
InLayout
,
NDHWGC
>
&&
is_same_v
<
WeiLayout
,
GKZYXC
>
&&
...
...
@@ -737,6 +849,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
(
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances
(
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances
(
op_ptrs
);
}
#endif
#ifdef CK_ENABLE_BF16
...
...
@@ -752,6 +868,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs
);
}
#endif
#ifdef CK_ENABLE_INT8
else
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
{
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances
(
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances
(
op_ptrs
);
}
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
else
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
&&
...
...
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt
View file @
1a0f2c35
set
(
GROUPED_CONV1D_BWD_WEIGHT
device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp
device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp
device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp
)
xdl/
device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp
xdl/
device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp
xdl/
device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp
)
if
(
DL_KERNELS
)
list
(
APPEND GROUPED_CONV1D_BWD_WEIGHT
device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instance.cpp
device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instance.cpp
device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_instance.cpp
device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instance.cpp
device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instance.cpp
device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_instance.cpp
)
dl/
device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instance.cpp
dl/
device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instance.cpp
dl/
device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_instance.cpp
dl/
device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instance.cpp
dl/
device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instance.cpp
dl/
device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_instance.cpp
)
endif
()
add_instance_library
(
device_grouped_conv1d_bwd_weight_instance
${
GROUPED_CONV1D_BWD_WEIGHT
}
)
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_instance.cpp
→
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/
dl/
device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_instance.cpp
View file @
1a0f2c35
File moved
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment