Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7830272f
Commit
7830272f
authored
Oct 17, 2023
by
Artur Wojcik
Browse files
Merge branch 'develop' into uif2-initial
parents
2b8a9941
16d7c4d2
Changes
71
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1382 additions
and
131 deletions
+1382
-131
CHANGELOG.md
CHANGELOG.md
+1
-1
example/20_grouped_conv_bwd_weight/CMakeLists.txt
example/20_grouped_conv_bwd_weight/CMakeLists.txt
+11
-2
example/20_grouped_conv_bwd_weight/common.hpp
example/20_grouped_conv_bwd_weight/common.hpp
+18
-22
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp
...ouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp
+20
-1
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_wmma_fp16.cpp
...ped_conv_bwd_weight/grouped_conv_bwd_weight_wmma_fp16.cpp
+88
-0
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp
...uped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp
+20
-1
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp
...uped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp
+20
-1
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp
..._weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp
+20
-1
example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc
...d_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc
+1
-21
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
...device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
...ion/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
+3
-28
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
...ice/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
+877
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
...vice/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
+7
-28
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
...impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
...r_operation/gpu/device/impl/device_grouped_conv_utils.hpp
+8
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
...ation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
+27
-14
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+14
-6
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp
...bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp
+0
-2
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp
...d_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp
+119
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
...tion_instance/gpu/grouped_convolution_backward_weight.hpp
+126
-0
No files found.
CHANGELOG.md
View file @
7830272f
...
@@ -14,7 +14,7 @@ None
...
@@ -14,7 +14,7 @@ None
### Additions
### Additions
-
Added an image to a column kernel (#867)
-
Added an image to a column kernel (#867)
-
Added a column to an image kernel (#930)
-
Added a column to an image kernel (#930)
-
Support for 3D grouped convolution
forward
on RDNA 3 GPUs (#935)
-
Support for 3D grouped convolution on RDNA 3 GPUs (#935
, #950, #985
)
-
Grouped convolution support for small K and C (#822 #879 #897)
-
Grouped convolution support for small K and C (#822 #879 #897)
-
Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804)
-
Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804)
-
Support for bf16/f32/f16 and NHWGC (2D and 3d) grouped convolution backward data (#757 #799)
-
Support for bf16/f32/f16 and NHWGC (2D and 3d) grouped convolution backward data (#757 #799)
...
...
example/20_grouped_conv_bwd_weight/CMakeLists.txt
View file @
7830272f
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
list
(
APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942
)
list
(
APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102
)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
if
(
gpu IN_LIST gpu_list
_xdl
AND target EQUAL 0
)
add_custom_target
(
example_grouped_conv_bwd_weight
)
add_custom_target
(
example_grouped_conv_bwd_weight
)
add_example_executable
(
example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp
)
add_example_executable
(
example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp
)
if
(
result EQUAL 0
)
if
(
result EQUAL 0
)
...
@@ -18,6 +19,14 @@ foreach(gpu IN LISTS GPU_TARGETS)
...
@@ -18,6 +19,14 @@ foreach(gpu IN LISTS GPU_TARGETS)
endif
()
endif
()
endif
()
endif
()
set
(
target 1
)
set
(
target 1
)
endif
()
if
(
gpu IN_LIST gpu_list_wmma AND target EQUAL 0
)
add_custom_target
(
example_grouped_conv_bwd_weight
)
add_example_executable
(
example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16
)
endif
()
set
(
target 1
)
endif
()
endif
()
endforeach
()
endforeach
()
...
...
example/20_grouped_conv_bwd_weight/common.hpp
View file @
7830272f
...
@@ -46,25 +46,21 @@ struct CommonLayoutSetting
...
@@ -46,25 +46,21 @@ struct CommonLayoutSetting
using
OutputLayout
=
OutputLay
;
using
OutputLayout
=
OutputLay
;
};
};
template
<
ck
::
index_t
NDimSpatial
>
struct
CommonLayoutSettingSelector
;
namespace
ctl
=
ck
::
tensor_layout
::
convolution
;
namespace
ctl
=
ck
::
tensor_layout
::
convolution
;
template
<
ck
::
index_t
NDimSpatial
>
template
<
>
struct
CommonLayoutSettingSelector
struct
CommonLayoutSettingSelector
<
1
>
final
:
CommonLayoutSetting
<
ctl
::
GNWC
,
ctl
::
GKXC
,
ctl
::
GNWK
>
:
CommonLayoutSetting
<
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
{
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
GNWC
,
};
ck
::
tensor_layout
::
convolution
::
GNHWC
,
ck
::
tensor_layout
::
convolution
::
GNDHWC
>>
,
template
<
>
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
struct
CommonLayoutSettingSelector
<
2
>
final
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
GKXC
,
:
CommonLayoutSetting
<
ctl
::
GNHWC
,
ctl
::
GKYXC
,
ctl
::
GNHWK
>
ck
::
tensor_layout
::
convolution
::
GKYXC
,
{
ck
::
tensor_layout
::
convolution
::
GKZYXC
>>
,
};
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
GNWK
,
template
<
>
ck
::
tensor_layout
::
convolution
::
GNHWK
,
struct
CommonLayoutSettingSelector
<
3
>
final
ck
::
tensor_layout
::
convolution
::
GNDHWK
>>>
:
CommonLayoutSetting
<
ctl
::
GNDHWC
,
ctl
::
GKZYXC
,
ctl
::
GNDHWK
>
{
{
};
};
...
@@ -84,10 +80,10 @@ struct ExecutionConfig final
...
@@ -84,10 +80,10 @@ struct ExecutionConfig final
bool
time_kernel
=
false
;
bool
time_kernel
=
false
;
};
};
#define DefaultConvParam \
#define DefaultConvParam
\
ck::utils::conv::ConvParam \
ck::utils::conv::ConvParam
\
{ \
{
\
2
, 4, 1, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, { 1, 1 } \
3
, 4, 1, 128, 256, {3,
3,
3}, {14,
14,
14}, {1,
1,
1}, {1,
1,
1}, {1,
1,
1}, { 1,
1,
1 } \
}
}
inline
void
print_help_msg
()
inline
void
print_help_msg
()
...
...
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp
View file @
7830272f
...
@@ -76,4 +76,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
...
@@ -76,4 +76,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
#include "run_grouped_conv_bwd_weight_example.inc"
#include "run_grouped_conv_bwd_weight_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_bwd_weight_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
ExecutionConfig
config
;
ck
::
utils
::
conv
::
ConvParam
conv_param
=
DefaultConvParam
;
if
(
!
parse_cmd_args
(
argc
,
argv
,
config
,
conv_param
))
{
return
1
;
}
switch
(
conv_param
.
num_dim_spatial_
)
{
case
1
:
return
!
run_grouped_conv_bwd_weight
<
1
>
(
config
,
conv_param
);
case
2
:
return
!
run_grouped_conv_bwd_weight
<
2
>
(
config
,
conv_param
);
case
3
:
return
!
run_grouped_conv_bwd_weight
<
3
>
(
config
,
conv_param
);
default:
break
;
}
return
1
;
}
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_wmma_fp16.cpp
0 → 100644
View file @
7830272f
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
using
InDataType
=
F16
;
using
WeiDataType
=
F16
;
using
OutDataType
=
F16
;
using
AccDataType
=
F32
;
using
InElementOp
=
PassThrough
;
using
WeiElementOp
=
PassThrough
;
using
OutElementOp
=
PassThrough
;
template
<
ck
::
index_t
NDimSpatial
>
using
DeviceConvBwdWeightInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvBwdWeight_Wmma_CShuffle
<
NDimSpatial
,
ck
::
tensor_layout
::
convolution
::
GNDHWC
,
ck
::
tensor_layout
::
convolution
::
GKZYXC
,
ck
::
tensor_layout
::
convolution
::
GNDHWK
,
InDataType
,
// InDataType
WeiDataType
,
// WeiDataType
OutDataType
,
// OutDataType
AccDataType
,
// AccDataType
InElementOp
,
// InElementwiseOperation
WeiElementOp
,
// WeiElementwiseOperation
OutElementOp
,
// OutElementwiseOperation
ConvBwdWeightDefault
,
// ConvolutionBackwardWeightSpecialization
256
,
// BlockSize
128
,
// MPerBlock
128
,
// NPerBlock
4
,
// K0PerBlock
8
,
// K1
16
,
// MPerWMMA
16
,
// NPerWMMA
4
,
// MRepeat
2
,
// NRepeat
S
<
4
,
64
,
1
>
,
// ABlockTransferThreadClusterLengths_AK0_M_AK1
S
<
0
,
2
,
1
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
0
,
2
,
1
>
,
// ABlockTransferSrcAccessOrder
1
,
// ABlockTransferSrcVectorDim
1
,
// ABlockTransferSrcScalarPerVector
8
,
// ABlockTransferDstScalarPerVector_AK1
true
,
// ABlockLdsExtraM
S
<
4
,
64
,
1
>
,
// BBlockTransferThreadClusterLengths_BK0_N_BK1
S
<
0
,
2
,
1
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
0
,
2
,
1
>
,
// BBlockTransferSrcAccessOrder
1
,
// BBlockTransferSrcVectorDim
1
,
// BBlockTransferSrcScalarPerVector
8
,
// BBlockTransferDstScalarPerVector_BK1
true
,
// BBlockLdsExtraN
4
,
2
,
S
<
1
,
32
,
1
,
8
>
,
1
>
;
template
<
ck
::
index_t
NDimSpatial
>
using
HostConvBwdWeightInstance
=
ck
::
tensor_operation
::
host
::
ReferenceConvBwdWeight
<
NDimSpatial
,
InDataType
,
WeiDataType
,
OutDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
>
;
#include "run_grouped_conv_bwd_weight_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
ExecutionConfig
config
;
ck
::
utils
::
conv
::
ConvParam
conv_param
=
DefaultConvParam
;
if
(
!
parse_cmd_args
(
argc
,
argv
,
config
,
conv_param
))
{
return
1
;
}
switch
(
conv_param
.
num_dim_spatial_
)
{
case
3
:
return
!
run_grouped_conv_bwd_weight
<
3
>
(
config
,
conv_param
);
default:
break
;
}
return
1
;
}
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp
View file @
7830272f
...
@@ -78,4 +78,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
...
@@ -78,4 +78,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
#include "run_grouped_conv_bwd_weight_example.inc"
#include "run_grouped_conv_bwd_weight_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_bwd_weight_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
ExecutionConfig
config
;
ck
::
utils
::
conv
::
ConvParam
conv_param
=
DefaultConvParam
;
if
(
!
parse_cmd_args
(
argc
,
argv
,
config
,
conv_param
))
{
return
1
;
}
switch
(
conv_param
.
num_dim_spatial_
)
{
case
1
:
return
!
run_grouped_conv_bwd_weight
<
1
>
(
config
,
conv_param
);
case
2
:
return
!
run_grouped_conv_bwd_weight
<
2
>
(
config
,
conv_param
);
case
3
:
return
!
run_grouped_conv_bwd_weight
<
3
>
(
config
,
conv_param
);
default:
break
;
}
return
1
;
}
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp
View file @
7830272f
...
@@ -77,4 +77,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
...
@@ -77,4 +77,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
#include "run_grouped_conv_bwd_weight_example.inc"
#include "run_grouped_conv_bwd_weight_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_bwd_weight_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
ExecutionConfig
config
;
ck
::
utils
::
conv
::
ConvParam
conv_param
=
DefaultConvParam
;
if
(
!
parse_cmd_args
(
argc
,
argv
,
config
,
conv_param
))
{
return
1
;
}
switch
(
conv_param
.
num_dim_spatial_
)
{
case
1
:
return
!
run_grouped_conv_bwd_weight
<
1
>
(
config
,
conv_param
);
case
2
:
return
!
run_grouped_conv_bwd_weight
<
2
>
(
config
,
conv_param
);
case
3
:
return
!
run_grouped_conv_bwd_weight
<
3
>
(
config
,
conv_param
);
default:
break
;
}
return
1
;
}
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp
View file @
7830272f
...
@@ -83,4 +83,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
...
@@ -83,4 +83,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
#include "run_grouped_conv_bwd_weight_example.inc"
#include "run_grouped_conv_bwd_weight_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_bwd_weight_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
ExecutionConfig
config
;
ck
::
utils
::
conv
::
ConvParam
conv_param
=
DefaultConvParam
;
if
(
!
parse_cmd_args
(
argc
,
argv
,
config
,
conv_param
))
{
return
1
;
}
switch
(
conv_param
.
num_dim_spatial_
)
{
case
1
:
return
!
run_grouped_conv_bwd_weight
<
1
>
(
config
,
conv_param
);
case
2
:
return
!
run_grouped_conv_bwd_weight
<
2
>
(
config
,
conv_param
);
case
3
:
return
!
run_grouped_conv_bwd_weight
<
3
>
(
config
,
conv_param
);
default:
break
;
}
return
1
;
}
example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc
View file @
7830272f
...
@@ -5,7 +5,7 @@ template <ck::index_t NDimSpatial>
...
@@ -5,7 +5,7 @@ template <ck::index_t NDimSpatial>
bool
run_grouped_conv_bwd_weight
(
const
ExecutionConfig
&
config
,
bool
run_grouped_conv_bwd_weight
(
const
ExecutionConfig
&
config
,
const
ck
::
utils
::
conv
::
ConvParam
&
conv_param
)
const
ck
::
utils
::
conv
::
ConvParam
&
conv_param
)
{
{
// Dl op do
es
n't support split_k > 1
// Dl
and WMMA
op
s
don't support split_k > 1
constexpr
ck
::
index_t
split_k
=
1
;
constexpr
ck
::
index_t
split_k
=
1
;
const
auto
in_g_n_c_wis_desc
=
const
auto
in_g_n_c_wis_desc
=
...
@@ -143,23 +143,3 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
...
@@ -143,23 +143,3 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
return
true
;
return
true
;
}
}
bool
run_grouped_conv_bwd_weight_example
(
int
argc
,
char
*
argv
[])
{
ExecutionConfig
config
;
ck
::
utils
::
conv
::
ConvParam
conv_param
=
DefaultConvParam
;
if
(
!
parse_cmd_args
(
argc
,
argv
,
config
,
conv_param
))
{
return
false
;
}
switch
(
conv_param
.
num_dim_spatial_
)
{
case
1
:
return
run_grouped_conv_bwd_weight
<
1
>
(
config
,
conv_param
);
case
2
:
return
run_grouped_conv_bwd_weight
<
2
>
(
config
,
conv_param
);
case
3
:
return
run_grouped_conv_bwd_weight
<
3
>
(
config
,
conv_param
);
}
return
false
;
}
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
View file @
7830272f
...
@@ -565,7 +565,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
...
@@ -565,7 +565,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_grouped_conv_
fwd_
multiple_d_wmma_cshuffle
<
const
auto
kernel
=
kernel_grouped_conv_multiple_d_wmma_cshuffle
<
GridwiseGemm
,
GridwiseGemm
,
ADataType
,
ADataType
,
BDataType
,
BDataType
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
View file @
7830272f
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
...
@@ -22,32 +23,6 @@ namespace ck {
...
@@ -22,32 +23,6 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
{
struct
ComputePtrOffsetOfStridedBatch
{
__host__
__device__
constexpr
long_index_t
GetAPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideA_
);
}
__host__
__device__
constexpr
long_index_t
GetBPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
);
}
__host__
__device__
constexpr
long_index_t
GetCPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideC_
);
}
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
index_t
BatchStrideC_
;
};
}
// namespace
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatAB
,
typename
FloatC
,
typename
FloatC
,
...
@@ -952,7 +927,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
...
@@ -952,7 +927,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
Block2CTileMap
block_2_ctile_map_
;
Block2CTileMap
block_2_ctile_map_
;
// for computing batch offset
// for computing batch offset
ComputePtrOffsetOfStridedBatch
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<
I0
>
compute_ptr_offset_of_batch_
;
// element-wise op
// element-wise op
OutElementwiseOperation
a_element_op_
;
OutElementwiseOperation
a_element_op_
;
...
@@ -1024,7 +999,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
...
@@ -1024,7 +999,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
remove_reference_t
<
DeviceOp
::
BGridDesc_B_K0_N0_N1_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_B_K0_N0_N1_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_M0_M10_M11_N0_N10_N11
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_M0_M10_M11_N0_N10_N11
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
,
ComputePtrOffsetOfStridedBatch
<
I0
>
,
has_main_loop
,
has_main_loop
,
has_double_loop
>
;
has_double_loop
>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
0 → 100644
View file @
7830272f
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
View file @
7830272f
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
@@ -21,32 +22,6 @@ namespace ck {
...
@@ -21,32 +22,6 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
{
struct
ComputePtrOffsetOfStridedBatch
{
__host__
__device__
constexpr
long_index_t
GetAPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideA_
);
}
__host__
__device__
constexpr
long_index_t
GetBPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
);
}
__host__
__device__
constexpr
long_index_t
GetCPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideC_
);
}
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
index_t
BatchStrideC_
;
};
}
// namespace
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
FloatA
,
typename
FloatA
,
typename
FloatB
,
typename
FloatB
,
...
@@ -1222,7 +1197,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
...
@@ -1222,7 +1197,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
Block2CTileMap
block_2_ctile_map_
;
Block2CTileMap
block_2_ctile_map_
;
// for computing batch offset
// for computing batch offset
ComputePtrOffsetOfStridedBatch
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<
I0
>
compute_ptr_offset_of_batch_
;
index_t
M01_
;
index_t
M01_
;
index_t
N01_
;
index_t
N01_
;
...
@@ -1301,7 +1276,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
...
@@ -1301,7 +1276,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
,
ComputePtrOffsetOfStridedBatch
<
I0
>
,
has_main_loop
>
;
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
return
launch_and_time_kernel
(
stream_config
,
...
@@ -1348,6 +1323,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
...
@@ -1348,6 +1323,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
if
constexpr
(
NDimSpatial
==
1
)
if
constexpr
(
NDimSpatial
==
1
)
{
{
if
constexpr
(
!
is_GNWK_GKXC_GNWC
)
if
constexpr
(
!
is_GNWK_GKXC_GNWC
)
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
View file @
7830272f
...
@@ -471,7 +471,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -471,7 +471,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_grouped_conv_
fwd_
multiple_d_wmma_cshuffle
<
const
auto
kernel
=
kernel_grouped_conv_multiple_d_wmma_cshuffle
<
GridwiseOp
,
GridwiseOp
,
ADataType
,
ADataType
,
BDataType
,
BDataType
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
View file @
7830272f
...
@@ -43,7 +43,13 @@ struct ComputePtrOffsetOfStridedBatch
...
@@ -43,7 +43,13 @@ struct ComputePtrOffsetOfStridedBatch
return
ds_offset
;
return
ds_offset
;
}
}
__host__
__device__
constexpr
long_index_t
GetEPtrOffset
(
index_t
g_idx
)
const
[[
maybe_unused
]]
__host__
__device__
constexpr
long_index_t
GetEPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
}
// alias for kernels without multiple D
[[
maybe_unused
]]
__host__
__device__
constexpr
long_index_t
GetCPtrOffset
(
index_t
g_idx
)
const
{
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
}
}
...
@@ -52,6 +58,7 @@ struct ComputePtrOffsetOfStridedBatch
...
@@ -52,6 +58,7 @@ struct ComputePtrOffsetOfStridedBatch
index_t
BatchStrideB_
;
index_t
BatchStrideB_
;
Array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs_
;
Array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs_
;
index_t
BatchStrideE_
;
index_t
BatchStrideE_
;
index_t
&
BatchStrideC_
=
BatchStrideE_
;
// alias for kernels without multiple D
};
};
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
View file @
7830272f
...
@@ -36,7 +36,7 @@ __global__ void
...
@@ -36,7 +36,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_grouped_conv_
fwd_
multiple_d_wmma_cshuffle
(
kernel_grouped_conv_multiple_d_wmma_cshuffle
(
const
ADataType
*
__restrict__
p_a_grid
,
const
ADataType
*
__restrict__
p_a_grid
,
const
BDataType
*
__restrict__
p_b_grid
,
const
BDataType
*
__restrict__
p_b_grid
,
DsPointer
p_ds_grid
,
DsPointer
p_ds_grid
,
...
@@ -452,11 +452,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
...
@@ -452,11 +452,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
}
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// CheckValidity for kernels without multi D
template
<
typename
Block2CTileMap
>
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
DsGridDesc_M_N
&
ds_grid_desc_m_n
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
Block2CTileMap
&
block_2_ctile_map
)
const
Block2CTileMap
&
block_2_ctile_map
)
{
{
...
@@ -471,18 +471,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
...
@@ -471,18 +471,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
bool
valid
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
valid
=
valid
&&
(
M
==
ds_grid_desc_m_n
[
i
].
GetLength
(
I0
)
&&
N
==
ds_grid_desc_m_n
[
i
].
GetLength
(
I1
));
});
if
(
!
valid
)
{
return
false
;
}
if
(
!
(
M
==
e_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
e_grid_desc_m_n
.
GetLength
(
I1
)
&&
if
(
!
(
M
==
e_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
e_grid_desc_m_n
.
GetLength
(
I1
)
&&
K0
==
b_grid_desc_k0_n_k1
.
GetLength
(
I0
)
&&
K1
==
a_grid_desc_k0_m_k1
.
GetLength
(
I2
)
&&
K0
==
b_grid_desc_k0_n_k1
.
GetLength
(
I0
)
&&
K1
==
a_grid_desc_k0_m_k1
.
GetLength
(
I2
)
&&
K1
==
b_grid_desc_k0_n_k1
.
GetLength
(
I2
)))
K1
==
b_grid_desc_k0_n_k1
.
GetLength
(
I2
)))
...
@@ -517,6 +505,31 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
...
@@ -517,6 +505,31 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
return
true
;
return
true
;
}
}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
DsGridDesc_M_N
&
ds_grid_desc_m_n
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
Block2CTileMap
&
block_2_ctile_map
)
{
const
auto
M
=
a_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
bool
valid
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
valid
=
valid
&&
(
M
==
ds_grid_desc_m_n
[
i
].
GetLength
(
I0
)
&&
N
==
ds_grid_desc_m_n
[
i
].
GetLength
(
I1
));
});
if
(
!
valid
)
{
return
false
;
}
return
CheckValidity
(
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
e_grid_desc_m_n
,
block_2_ctile_map
);
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
{
const
index_t
num_loop
=
K
/
(
K0PerBlock
*
K1
);
const
index_t
num_loop
=
K
/
(
K0PerBlock
*
K1
);
...
...
include/ck/utility/type_convert.hpp
View file @
7830272f
...
@@ -146,7 +146,7 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
...
@@ -146,7 +146,7 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
// convert to float and use native converion
return
type_convert
<
f8_t
>
(
type_convert
<
float
>
(
x
));
return
type_convert
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#el
se
#el
if 0
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
...
@@ -154,6 +154,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
...
@@ -154,6 +154,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
x
,
rng
);
#else
return
type_convert
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#endif
#endif
}
}
...
@@ -164,9 +166,11 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
...
@@ -164,9 +166,11 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
// use native conversion to float and convert to fp16
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#el
se
#el
if 0
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
f8_t
,
half_t
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
f8_t
,
half_t
,
negative_zero_nan
>
(
x
);
#else
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#endif
#endif
}
}
#endif
#endif
...
@@ -222,7 +226,7 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
...
@@ -222,7 +226,7 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
// convert to float and use native converion
return
type_convert
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
return
type_convert
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
#el
se
#el
if 0
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
...
@@ -230,6 +234,8 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
...
@@ -230,6 +234,8 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
x
,
rng
);
#else
return
type_convert
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
#endif
#endif
}
}
...
@@ -240,9 +246,11 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
...
@@ -240,9 +246,11 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
// use native conversion to float and convert to fp16
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#el
se
#el
if 0
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
bf8_t
,
half_t
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
bf8_t
,
half_t
,
negative_zero_nan
>
(
x
);
#else
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#endif
#endif
}
}
#endif
#endif
...
@@ -354,7 +362,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
...
@@ -354,7 +362,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
x
,
rng
);
#else
#else
return
type
_convert
<
f8_t
>
(
type_convert
<
float
>
(
x
));
return
f8
_convert
_sr
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#endif
#endif
}
}
#endif
#endif
...
@@ -406,7 +414,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
...
@@ -406,7 +414,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
x
,
rng
);
#else
#else
return
type
_convert
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
return
f8
_convert
_sr
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
#endif
#endif
}
}
#endif
#endif
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp
View file @
7830272f
...
@@ -6,8 +6,6 @@
...
@@ -6,8 +6,6 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp
0 → 100644
View file @
7830272f
This diff is collapsed.
Click to expand it.
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
View file @
7830272f
...
@@ -163,6 +163,30 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances
...
@@ -163,6 +163,30 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_FP32
#ifdef CK_ENABLE_FP32
void
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances
(
void
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances
(
...
@@ -177,6 +201,31 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances
...
@@ -177,6 +201,31 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
#ifdef CK_ENABLE_BF16
void
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances
(
void
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
...
@@ -202,6 +251,30 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
...
@@ -202,6 +251,30 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_FP32
#ifdef CK_ENABLE_FP32
void
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances
(
void
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances
(
...
@@ -231,6 +304,31 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_
...
@@ -231,6 +304,31 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_
BF8
,
BF8
,
F8
>>>&
instances
);
F8
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef DL_KERNELS
#ifdef DL_KERNELS
// dl
// dl
...
@@ -694,6 +792,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -694,6 +792,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif
#endif
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances
(
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances
(
op_ptrs
);
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances
(
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances
(
op_ptrs
);
}
}
#endif
#endif
#ifdef CK_ENABLE_BF16
#ifdef CK_ENABLE_BF16
...
@@ -708,6 +810,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -708,6 +810,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances
(
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances
(
op_ptrs
);
op_ptrs
);
}
}
#endif
#ifdef CK_ENABLE_INT8
else
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
{
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances
(
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances
(
op_ptrs
);
}
#endif
#endif
}
}
else
if
constexpr
(
is_same_v
<
InLayout
,
NDHWGC
>
&&
is_same_v
<
WeiLayout
,
GKZYXC
>
&&
else
if
constexpr
(
is_same_v
<
InLayout
,
NDHWGC
>
&&
is_same_v
<
WeiLayout
,
GKZYXC
>
&&
...
@@ -737,6 +849,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -737,6 +849,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif
#endif
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
(
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
(
op_ptrs
);
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances
(
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances
(
op_ptrs
);
}
}
#endif
#endif
#ifdef CK_ENABLE_BF16
#ifdef CK_ENABLE_BF16
...
@@ -752,6 +868,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -752,6 +868,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs
);
op_ptrs
);
}
}
#endif
#endif
#ifdef CK_ENABLE_INT8
else
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
{
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances
(
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances
(
op_ptrs
);
}
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
else
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
else
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
&&
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment