Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9b3365e1
Unverified
Commit
9b3365e1
authored
Nov 12, 2022
by
Po Yen Chen
Committed by
GitHub
Nov 12, 2022
Browse files
Merge branch 'develop' into gridwise_2d
parents
9608beee
b79bbbc2
Changes
187
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
617 additions
and
590 deletions
+617
-590
example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp
..._grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp
+1
-0
example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int8.cpp
..._grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int8.cpp
+1
-0
example/41_grouped_conv_conv_fwd/run_grouped_conv_conv_fwd_example.inc
...ouped_conv_conv_fwd/run_grouped_conv_conv_fwd_example.inc
+2
-2
example/42_groupnorm/groupnorm_sigmoid_fp16.cpp
example/42_groupnorm/groupnorm_sigmoid_fp16.cpp
+4
-4
example/CMakeLists.txt
example/CMakeLists.txt
+2
-0
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data.hpp
...sor_operation/gpu/device/device_grouped_conv_bwd_data.hpp
+0
-49
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp
...on/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp
+1
-96
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp
...r_operation/gpu/device/device_grouped_conv_bwd_weight.hpp
+11
-10
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp
...k/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp
+21
-22
include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+40
-38
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp
...e_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp
+225
-90
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
+14
-3
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp
.../grid/gridwise_elementwise_layernorm_welford_variance.hpp
+1
-1
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp
...erence_tensor_operation/cpu/reference_conv_bwd_weight.hpp
+8
-3
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp
...ference_tensor_operation/cpu/reference_gemm_layernorm.hpp
+2
-2
library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp
...ry/reference_tensor_operation/cpu/reference_layernorm.hpp
+4
-3
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
..._operation_instance/device_operation_instance_factory.hpp
+1
-1
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_weight.hpp
...or_operation_instance/gpu/convolution_backward_weight.hpp
+0
-230
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp
...ration_instance/gpu/grouped_convolution_backward_data.hpp
+44
-36
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
...tion_instance/gpu/grouped_convolution_backward_weight.hpp
+235
-0
No files found.
example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp
View file @
9b3365e1
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
...
...
example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int8.cpp
View file @
9b3365e1
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
...
...
example/41_grouped_conv_conv_fwd/run_grouped_conv_conv_fwd_example.inc
View file @
9b3365e1
...
@@ -97,7 +97,7 @@ bool run_grouped_conv_conv_fwd(bool do_verification,
...
@@ -97,7 +97,7 @@ bool run_grouped_conv_conv_fwd(bool do_verification,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input1_left_pads
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input1_left_pads
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input1_right_pads
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input1_right_pads
{};
auto
copy
=
[](
auto
&
x
,
auto
&
y
)
{
std
::
copy
(
x
.
begin
(),
x
.
end
()
,
y
.
begin
());
};
auto
copy
=
[](
const
auto
&
x
,
auto
&
y
)
{
ck
::
ranges
::
copy
(
x
,
y
.
begin
());
};
copy
(
in0_g_n_c_wis_desc
.
GetLengths
(),
a0_g_n_c_wis_lengths
);
copy
(
in0_g_n_c_wis_desc
.
GetLengths
(),
a0_g_n_c_wis_lengths
);
copy
(
in0_g_n_c_wis_desc
.
GetStrides
(),
a0_g_n_c_wis_strides
);
copy
(
in0_g_n_c_wis_desc
.
GetStrides
(),
a0_g_n_c_wis_strides
);
...
@@ -261,7 +261,7 @@ bool run_grouped_conv_conv_fwd(bool do_verification,
...
@@ -261,7 +261,7 @@ bool run_grouped_conv_conv_fwd(bool do_verification,
#endif
#endif
return
ck
::
utils
::
check_err
(
return
ck
::
utils
::
check_err
(
out1_device
.
mData
,
out1_host
.
mData
,
"Error: incorrect results!"
,
1
e
-
5
f
,
1
e
-
4
f
);
out1_device
,
out1_host
,
"Error: incorrect results!"
,
1
e
-
5
f
,
1
e
-
4
f
);
}
}
return
true
;
return
true
;
...
...
example/42_groupnorm/groupnorm_sigmoid_fp16.cpp
View file @
9b3365e1
...
@@ -100,9 +100,9 @@ int main(int argc, char* argv[])
...
@@ -100,9 +100,9 @@ int main(int argc, char* argv[])
Tensor
<
GammaDataType
>
gamma
({
G
,
C
});
Tensor
<
GammaDataType
>
gamma
({
G
,
C
});
Tensor
<
BetaDataType
>
beta
({
G
,
C
});
Tensor
<
BetaDataType
>
beta
({
G
,
C
});
ck
::
utils
::
FillUniformDistribution
<
XDataType
>
{
0.
f
,
1.
f
}(
x
.
begin
(),
x
.
end
()
);
ck
::
utils
::
FillUniformDistribution
<
XDataType
>
{
0.
f
,
1.
f
}(
x
);
ck
::
utils
::
FillUniformDistribution
<
GammaDataType
>
{
0.
f
,
1.
f
}(
gamma
.
begin
(),
gamma
.
end
()
);
ck
::
utils
::
FillUniformDistribution
<
GammaDataType
>
{
0.
f
,
1.
f
}(
gamma
);
ck
::
utils
::
FillUniformDistribution
<
BetaDataType
>
{
0.
f
,
1.
f
}(
beta
.
begin
(),
beta
.
end
()
);
ck
::
utils
::
FillUniformDistribution
<
BetaDataType
>
{
0.
f
,
1.
f
}(
beta
);
DeviceMem
x_dev
(
sizeof
(
XDataType
)
*
x
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
x_dev
(
sizeof
(
XDataType
)
*
x
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
gamma_dev
(
sizeof
(
GammaDataType
)
*
gamma
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
gamma_dev
(
sizeof
(
GammaDataType
)
*
gamma
.
mDesc
.
GetElementSpaceSize
());
...
@@ -167,7 +167,7 @@ int main(int argc, char* argv[])
...
@@ -167,7 +167,7 @@ int main(int argc, char* argv[])
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
y_dev
.
FromDevice
(
y
.
mData
.
data
());
y_dev
.
FromDevice
(
y
.
mData
.
data
());
pass
&=
ck
::
utils
::
check_err
(
y
.
mData
,
host_y
.
mData
,
"Error: Incorrect results"
,
1e-3
,
1e-3
);
pass
&=
ck
::
utils
::
check_err
(
y
,
host_y
,
"Error: Incorrect results"
,
1e-3
,
1e-3
);
}
}
return
(
pass
?
0
:
1
);
return
(
pass
?
0
:
1
);
...
...
example/CMakeLists.txt
View file @
9b3365e1
...
@@ -12,6 +12,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
...
@@ -12,6 +12,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
add_test
(
NAME
${
EXAMPLE_NAME
}
COMMAND $<TARGET_FILE:
${
EXAMPLE_NAME
}
>
${
ARGN
}
)
add_test
(
NAME
${
EXAMPLE_NAME
}
COMMAND $<TARGET_FILE:
${
EXAMPLE_NAME
}
>
${
ARGN
}
)
add_dependencies
(
examples
${
EXAMPLE_NAME
}
)
add_dependencies
(
examples
${
EXAMPLE_NAME
}
)
add_dependencies
(
check
${
EXAMPLE_NAME
}
)
add_dependencies
(
check
${
EXAMPLE_NAME
}
)
rocm_install
(
TARGETS
${
EXAMPLE_NAME
}
COMPONENT examples
)
endfunction
(
add_example_executable EXAMPLE_NAME
)
endfunction
(
add_example_executable EXAMPLE_NAME
)
function
(
add_example_executable_no_testing EXAMPLE_NAME FILE_NAME
)
function
(
add_example_executable_no_testing EXAMPLE_NAME FILE_NAME
)
...
@@ -19,6 +20,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
...
@@ -19,6 +20,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
add_executable
(
${
EXAMPLE_NAME
}
${
FILE_NAME
}
)
add_executable
(
${
EXAMPLE_NAME
}
${
FILE_NAME
}
)
target_link_libraries
(
${
EXAMPLE_NAME
}
PRIVATE utility
)
target_link_libraries
(
${
EXAMPLE_NAME
}
PRIVATE utility
)
add_dependencies
(
examples
${
EXAMPLE_NAME
}
)
add_dependencies
(
examples
${
EXAMPLE_NAME
}
)
rocm_install
(
TARGETS
${
EXAMPLE_NAME
}
COMPONENT examples
)
endfunction
(
add_example_executable_no_testing EXAMPLE_NAME
)
endfunction
(
add_example_executable_no_testing EXAMPLE_NAME
)
# add all example subdir
# add all example subdir
...
...
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data.hpp
deleted
100644 → 0
View file @
9608beee
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
ck
::
index_t
NDimSpatial
,
typename
InputLayout
,
typename
WeightLayout
,
typename
OutputLayout
,
typename
InputDataType
,
typename
WeightDataType
,
typename
OutputDataType
,
typename
InputElementwiseOperation
,
typename
WeightElementwiseOperation
,
typename
OutputElementwiseOperation
>
struct
DeviceGroupedConvBwdData
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
void
*
p_input
,
const
void
*
p_weight
,
const
void
*
p_output
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
input_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
input_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
weight_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
weight_g_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
output_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
output_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
InputElementwiseOperation
&
input_element_op
,
const
WeightElementwiseOperation
&
weight_element_op
,
const
OutputElementwiseOperation
&
output_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp
View file @
9b3365e1
...
@@ -3,10 +3,9 @@
...
@@ -3,10 +3,9 @@
#pragma once
#pragma once
#include <
vector
>
#include <
array
>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -63,100 +62,6 @@ struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator
...
@@ -63,100 +62,6 @@ struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
};
template
<
ck
::
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
>
struct
DeviceGroupedConvBwdDataMultipleD
<
NDimSpatial
,
ALayout
,
BLayout
,
Tuple
<>
,
ELayout
,
ADataType
,
BDataType
,
Tuple
<>
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
:
public
DeviceGroupedConvBwdData
<
NDimSpatial
,
ELayout
,
BLayout
,
ALayout
,
EDataType
,
BDataType
,
ADataType
,
CDEElementwiseOperation
,
BElementwiseOperation
,
AElementwiseOperation
>
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
// output image
const
void
*
p_b
,
// weight
const
std
::
array
<
const
void
*
,
0
>&
,
// bias
void
*
p_e
,
// input image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_k_wos_lengths
,
// output image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_k_wos_strides
,
// output image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
// weight
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
0
>&
,
// bias
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
0
>&
,
// bias
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_c_wis_lengths
,
// input image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_c_wis_strides
,
// input image
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
=
0
;
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
void
*
p_input
,
const
void
*
p_weight
,
const
void
*
p_output
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
input_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
input_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
weight_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
weight_g_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
output_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
output_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
CDEElementwiseOperation
&
input_element_op
,
const
BElementwiseOperation
&
weight_element_op
,
const
AElementwiseOperation
&
output_element_op
)
override
final
{
return
MakeArgumentPointer
(
p_output
,
p_weight
,
std
::
array
<
const
void
*
,
0
>
{},
p_input
,
output_g_n_k_wos_lengths
,
output_g_n_k_wos_strides
,
weight_g_k_c_xs_lengths
,
weight_g_k_c_xs_strides
,
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
0
>
{},
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
0
>
{},
input_g_n_c_wis_lengths
,
input_g_n_c_wis_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
output_element_op
,
weight_element_op
,
input_element_op
);
}
};
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp
→
include/ck/tensor_operation/gpu/device/device_
grouped_
conv_bwd_weight.hpp
View file @
9b3365e1
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
#pragma once
#pragma once
#include <
vector
>
#include <
array
>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
...
@@ -11,7 +11,7 @@ namespace ck {
...
@@ -11,7 +11,7 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
template
<
ck
::
index_t
N
um
DimSpatial
,
template
<
ck
::
index_t
NDimSpatial
,
typename
InLayout
,
typename
InLayout
,
typename
WeiLayout
,
typename
WeiLayout
,
typename
OutLayout
,
typename
OutLayout
,
...
@@ -21,22 +21,23 @@ template <ck::index_t NumDimSpatial,
...
@@ -21,22 +21,23 @@ template <ck::index_t NumDimSpatial,
typename
InElementwiseOperation
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
>
typename
OutElementwiseOperation
>
struct
DeviceConvBwdWeight
:
public
BaseOperator
struct
Device
Grouped
ConvBwdWeight
:
public
BaseOperator
{
{
virtual
std
::
unique_ptr
<
BaseArgument
>
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in
,
MakeArgumentPointer
(
const
void
*
p_in
,
void
*
p_wei
,
void
*
p_wei
,
const
void
*
p_out
,
const
void
*
p_out
,
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
OutElementwiseOperation
out_element_op
,
...
...
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp
View file @
9b3365e1
...
@@ -14,39 +14,38 @@ namespace device {
...
@@ -14,39 +14,38 @@ namespace device {
// Convolution Forward:
// Convolution Forward:
// input : input image A[G, N, C, Hi, Wi],
// input : input image A[G, N, C, Hi, Wi],
// input : weight B[G, K, C, Y, X],
// input : weight B[G, K, C, Y, X],
// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
// output : output image E[G, N, K, Ho, Wo]
// output : output image E[G, N, K, Ho, Wo]
// C = a_op(A) * b_op(B)
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// E = cde_op(C, D0, D1, ...)
template
<
index_t
NDimSpatial
,
template
<
index_t
NDimSpatial
,
typename
A
Layout
,
typename
In
Layout
,
typename
B
Layout
,
typename
Wei
Layout
,
typename
C
Layout
,
typename
Out
Layout
,
typename
A
DataType
,
typename
In
DataType
,
typename
B
DataType
,
typename
Wei
DataType
,
typename
C
DataType
,
typename
Out
DataType
,
typename
A
ElementwiseOperation
,
typename
In
ElementwiseOperation
,
typename
B
ElementwiseOperation
,
typename
Wei
ElementwiseOperation
,
typename
C
ElementwiseOperation
>
typename
Out
ElementwiseOperation
>
struct
DeviceGroupedConvFwd
:
public
BaseOperator
struct
DeviceGroupedConvFwd
:
public
BaseOperator
{
{
virtual
std
::
unique_ptr
<
BaseArgument
>
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_
a
,
// input image
MakeArgumentPointer
(
const
void
*
p_
in
,
// input image
const
void
*
p_
b
,
// weight
const
void
*
p_
wei
,
// weight
void
*
p_
c
,
// output image
void
*
p_
out
,
// output image
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a
_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
in
_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a
_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
in
_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b
_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
wei
_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b
_g_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
wei
_g_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
c
_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
out
_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
c
_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
out
_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
A
ElementwiseOperation
&
a
_element_op
,
const
In
ElementwiseOperation
&
in
_element_op
,
const
B
ElementwiseOperation
&
b
_element_op
,
const
Wei
ElementwiseOperation
&
wei
_element_op
,
const
C
ElementwiseOperation
&
c
_element_op
)
=
0
;
const
Out
ElementwiseOperation
&
out
_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
};
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
9b3365e1
...
@@ -67,6 +67,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -67,6 +67,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
WeiElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
OutElementwiseOperation
>
{
{
static
constexpr
ck
::
index_t
NDimSpatial
=
2
;
using
DeviceOp
=
using
DeviceOp
=
DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
;
DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
;
...
@@ -107,18 +109,18 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -107,18 +109,18 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
static
constexpr
auto
BBlockLdsN0PerBlock
=
NPerBlock
/
BBlockLdsN1PerBlock
;
static
constexpr
auto
BBlockLdsN0PerBlock
=
NPerBlock
/
BBlockLdsN1PerBlock
;
static
constexpr
auto
BBlockLdsN1Padding
=
4
;
static
constexpr
auto
BBlockLdsN1Padding
=
4
;
static
auto
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
ck
::
index_t
batch_k
)
ck
::
index_t
batch_k
)
{
{
using
namespace
ck
;
using
namespace
ck
;
...
@@ -390,13 +392,13 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -390,13 +392,13 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
ck
::
index_t
M01
,
ck
::
index_t
M01
,
ck
::
index_t
N01
,
ck
::
index_t
N01
,
InElementwiseOperation
in_element_op
,
InElementwiseOperation
in_element_op
,
...
@@ -473,11 +475,11 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -473,11 +475,11 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
index_t
Conv_N_
;
index_t
Conv_N_
;
index_t
Conv_K_
;
index_t
Conv_K_
;
index_t
Conv_C_
;
index_t
Conv_C_
;
std
::
vector
<
index_t
>
output_spatial_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
>
output_spatial_lengths_
;
std
::
vector
<
index_t
>
filter_spatial_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
>
filter_spatial_lengths_
;
std
::
vector
<
index_t
>
conv_filter_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_
;
std
::
vector
<
index_t
>
input_left_pads_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_
;
std
::
vector
<
index_t
>
input_right_pads_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_
;
index_t
k_batch_
;
index_t
k_batch_
;
};
};
...
@@ -682,13 +684,13 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -682,13 +684,13 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
OutElementwiseOperation
out_element_op
,
...
@@ -724,13 +726,13 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -724,13 +726,13 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
OutElementwiseOperation
out_element_op
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv
nd
_bwd_weight_nwc_kxc_nwk_xdl_cshuffle.hpp
→
include/ck/tensor_operation/gpu/device/impl/device_
grouped_
conv_bwd_weight_
g
nwc_
g
kxc_
g
nwk_xdl_cshuffle.hpp
View file @
9b3365e1
...
@@ -4,13 +4,14 @@
...
@@ -4,13 +4,14 @@
#pragma once
#pragma once
#include <iostream>
#include <iostream>
#include <numeric>
#include <sstream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/device_
grouped_
conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
...
@@ -20,6 +21,104 @@ namespace ck {
...
@@ -20,6 +21,104 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
{
struct
ComputePtrOffsetOfStridedBatch
{
__host__
__device__
constexpr
long_index_t
GetAPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideA_
);
}
__host__
__device__
constexpr
long_index_t
GetBPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
);
}
__host__
__device__
constexpr
long_index_t
GetCPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideC_
);
}
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
index_t
BatchStrideC_
;
};
}
// namespace
// Batched (grouped) backward-weight GEMM kernel.
//
// The launch grid is divided evenly into `batch_count` consecutive ranges of
// workgroups. Each workgroup derives its group index from its 1-D block id,
// asks `compute_ptr_offset_of_batch` for the A/B/C element offsets of that
// group, and then runs GridwiseGemm::Run on the offset pointers.
//
// Template parameters:
//   GridwiseGemm            - provides GetSharedMemoryNumberOfByte() and Run<>().
//   FloatAB / FloatC        - element types of the A/B inputs and the C output.
//   *ElementwiseOperation   - element-wise functors forwarded to GridwiseGemm.
//   *GridDesc*              - tensor descriptors forwarded to GridwiseGemm.
//   Block2CTileMap          - block-id to C-tile mapping forwarded to GridwiseGemm.
//   ComputePtrOffsetOfBatch - provides Get{A,B,C}PtrOffset(index_t).
//   HasMainKBlockLoop       - compile-time switch forwarded to GridwiseGemm::Run.
//
// The body is only compiled for the host pass and for gfx908 / gfx90a device
// passes; on any other device target every argument is simply ignored.
template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename AGridDesc_B_K0_M_K1,
          typename BGridDesc_B_K0_N_K1,
          typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
          typename Block2CTileMap,
          typename ComputePtrOffsetOfBatch,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_batched_gemm_xdlops_bwd_weight(
        const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
        const AElementwiseOperation a_element_op,
        const BElementwiseOperation b_element_op,
        const CElementwiseOperation c_element_op,
        const index_t batch_count,
        const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
        const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
        const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
            c_grid_desc_mblock_mperblock_nblock_nperblock,
        const Block2CTileMap block_2_ctile_map,
        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
    // Number of workgroups assigned to each group, and the group this
    // workgroup belongs to. Both are 32-bit values, so routing them through
    // readfirstlane is lossless.
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

    // The offsets are long_index_t. __builtin_amdgcn_readfirstlane takes and
    // returns a 32-bit int, so passing the offsets through it would silently
    // drop their upper bits. They are computed directly instead; g_idx is
    // already the readfirstlane result, so the offsets stay wave-uniform.
    const long_index_t a_batch_offset =
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
    const long_index_t b_batch_offset =
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
    const long_index_t c_batch_offset =
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));

    // Scratch buffer sized by GridwiseGemm, expressed in FloatAB elements.
    __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)];

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                  p_b_grid + b_batch_offset,
                                                  p_c_grid + c_batch_offset,
                                                  p_shared,
                                                  a_b_k0_m_k1_grid_desc,
                                                  b_b_k0_n_k1_grid_desc,
                                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  a_element_op,
                                                  b_element_op,
                                                  c_element_op,
                                                  block_2_ctile_map);
#else
    // Unsupported device target: consume every argument so the empty kernel
    // builds without unused-parameter warnings.
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_c_grid;
    ignore = a_b_k0_m_k1_grid_desc;
    ignore = b_b_k0_n_k1_grid_desc;
    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = c_element_op;
    ignore = batch_count;
    ignore = block_2_ctile_map;
    ignore = compute_ptr_offset_of_batch;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template
<
ck
::
index_t
NDimSpatial
,
template
<
ck
::
index_t
NDimSpatial
,
typename
InDataType
,
typename
InDataType
,
...
@@ -57,21 +156,21 @@ template <ck::index_t NDimSpatial,
...
@@ -57,21 +156,21 @@ template <ck::index_t NDimSpatial,
index_t
CShuffleNXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
>
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
>
struct
DeviceConv
Nd
BwdWeight
NwcKxcN
wk_Xdl_CShuffle
struct
Device
Grouped
ConvBwdWeight
GnwcGkxcGn
wk_Xdl_CShuffle
:
public
DeviceConvBwdWeight
<
:
public
Device
Grouped
ConvBwdWeight
<
NDimSpatial
,
NDimSpatial
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
NWC
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
G
NWC
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
G
NHWC
,
ck
::
tensor_layout
::
convolution
::
NDHWC
>>
,
ck
::
tensor_layout
::
convolution
::
G
NDHWC
>>
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
KXC
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
G
KXC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
G
KYXC
,
ck
::
tensor_layout
::
convolution
::
KZYXC
>>
,
ck
::
tensor_layout
::
convolution
::
G
KZYXC
>>
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
NWK
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
G
NWK
,
ck
::
tensor_layout
::
convolution
::
NHWK
,
ck
::
tensor_layout
::
convolution
::
G
NHWK
,
ck
::
tensor_layout
::
convolution
::
NDHWK
>>
,
ck
::
tensor_layout
::
convolution
::
G
NDHWK
>>
,
InDataType
,
InDataType
,
WeiDataType
,
WeiDataType
,
OutDataType
,
OutDataType
,
...
@@ -79,7 +178,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
...
@@ -79,7 +178,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
WeiElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
OutElementwiseOperation
>
{
{
using
DeviceOp
=
DeviceConv
Nd
BwdWeight
NwcKxcN
wk_Xdl_CShuffle
;
using
DeviceOp
=
Device
Grouped
ConvBwdWeight
GnwcGkxcGn
wk_Xdl_CShuffle
;
using
ADataType
=
OutDataType
;
using
ADataType
=
OutDataType
;
using
BDataType
=
InDataType
;
using
BDataType
=
InDataType
;
...
@@ -117,18 +216,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
...
@@ -117,18 +216,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
static
constexpr
auto
BBlockLdsN1Padding
=
4
;
static
constexpr
auto
BBlockLdsN1Padding
=
4
;
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
static
auto
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
ck
::
index_t
batch_k
)
ck
::
index_t
batch_k
)
{
{
using
namespace
ck
;
using
namespace
ck
;
...
@@ -269,18 +368,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
...
@@ -269,18 +368,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
}
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
auto
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
ck
::
index_t
batch_k
)
ck
::
index_t
batch_k
)
{
{
using
namespace
ck
;
using
namespace
ck
;
...
@@ -436,18 +535,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
...
@@ -436,18 +535,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
}
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
ck
::
index_t
batch_k
)
ck
::
index_t
batch_k
)
{
{
using
namespace
ck
;
using
namespace
ck
;
...
@@ -664,8 +763,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
...
@@ -664,8 +763,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
}
}
template
<
index_t
Dim
>
template
<
index_t
Dim
>
static
auto
MakeDescriptor_M0
(
const
std
::
vector
<
index_t
>&
shape
,
static
auto
MakeDescriptor_M0
(
const
std
::
array
<
index_t
,
Dim
>&
shape
,
const
std
::
vector
<
index_t
>&
stride
,
const
std
::
array
<
index_t
,
Dim
>&
stride
,
index_t
gridSize
,
index_t
gridSize
,
index_t
blockSize
)
index_t
blockSize
)
{
{
...
@@ -759,16 +858,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
...
@@ -759,16 +858,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
Argument
(
const
InDataType
*
p_in_grid
,
Argument
(
const
InDataType
*
p_in_grid
,
WeiDataType
*
p_wei_grid
,
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
const
OutDataType
*
p_out_grid
,
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
ck
::
index_t
M01
,
ck
::
index_t
M01
,
ck
::
index_t
N01
,
ck
::
index_t
N01
,
InElementwiseOperation
in_element_op
,
InElementwiseOperation
in_element_op
,
...
@@ -783,11 +883,13 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
...
@@ -783,11 +883,13 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
c_grid_desc_m_n_
{},
c_grid_desc_m_n_
{},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_ctile_map_
{},
block_2_ctile_map_
{},
compute_ptr_offset_of_batch_
{},
M01_
{
M01
},
M01_
{
M01
},
N01_
{
N01
},
N01_
{
N01
},
a_element_op_
{
out_element_op
},
a_element_op_
{
out_element_op
},
b_element_op_
{
in_element_op
},
b_element_op_
{
in_element_op
},
c_element_op_
{
wei_element_op
},
c_element_op_
{
wei_element_op
},
Conv_G_
{
G
},
Conv_N_
{
N
},
Conv_N_
{
N
},
Conv_K_
{
K
},
Conv_K_
{
K
},
Conv_C_
{
C
},
Conv_C_
{
C
},
...
@@ -819,6 +921,26 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
...
@@ -819,6 +921,26 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
block_2_ctile_map_
=
block_2_ctile_map_
=
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_grid_desc_m_n_
,
M01
,
N01
,
k_batch_
);
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_grid_desc_m_n_
,
M01
,
N01
,
k_batch_
);
// A/B/C Batch Stride
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
N
*
K
*
std
::
accumulate
(
begin
(
output_spatial_lengths
),
end
(
output_spatial_lengths
),
index_t
{
1
},
std
::
multiplies
<>
{});
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
N
*
C
*
std
::
accumulate
(
begin
(
input_spatial_lengths
),
end
(
input_spatial_lengths
),
index_t
{
1
},
std
::
multiplies
<>
{});
compute_ptr_offset_of_batch_
.
BatchStrideC_
=
K
*
C
*
std
::
accumulate
(
begin
(
filter_spatial_lengths
),
end
(
filter_spatial_lengths
),
index_t
{
1
},
std
::
multiplies
<>
{});
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_kbatch_k0_m_k1_
,
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_kbatch_k0_m_k1_
,
b_grid_desc_kbatch_k0_n_k1_
,
b_grid_desc_kbatch_k0_n_k1_
,
c_grid_desc_m_n_
,
c_grid_desc_m_n_
,
...
@@ -836,21 +958,29 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
...
@@ -836,21 +958,29 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
BGridDesc_K0_N_K1
b_grid_desc_kbatch_k0_n_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_kbatch_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
Block2CTileMap
block_2_ctile_map_
;
Block2CTileMap
block_2_ctile_map_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
compute_ptr_offset_of_batch_
;
index_t
M01_
;
index_t
M01_
;
index_t
N01_
;
index_t
N01_
;
InElementwiseOperation
a_element_op_
;
InElementwiseOperation
a_element_op_
;
OutElementwiseOperation
b_element_op_
;
OutElementwiseOperation
b_element_op_
;
WeiElementwiseOperation
c_element_op_
;
WeiElementwiseOperation
c_element_op_
;
// for checking IsSupportedArgument()
// for checking IsSupportedArgument()
index_t
Conv_G_
;
index_t
Conv_N_
;
index_t
Conv_N_
;
index_t
Conv_K_
;
index_t
Conv_K_
;
index_t
Conv_C_
;
index_t
Conv_C_
;
std
::
vector
<
index_t
>
output_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths_
;
std
::
vector
<
index_t
>
filter_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths_
;
std
::
vector
<
index_t
>
conv_filter_strides_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides_
;
std
::
vector
<
index_t
>
input_left_pads_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads_
;
std
::
vector
<
index_t
>
input_right_pads_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads_
;
index_t
k_batch_
;
index_t
k_batch_
;
};
};
...
@@ -873,14 +1003,12 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
...
@@ -873,14 +1003,12 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I2
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I2
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_{
"
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
"arg.c_grid_desc_m_n_{"
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
}
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
ShowInfo
(
arg
);
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
c_grid_desc_m_n_
,
...
@@ -891,7 +1019,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
...
@@ -891,7 +1019,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
}
}
const
index_t
grid_size
=
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
);
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
Conv_G_
;
const
auto
K0
=
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I1
);
const
auto
K0
=
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I1
);
...
@@ -900,17 +1028,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
...
@@ -900,17 +1028,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_gemm_xdlops_bwd_weight
<
const
auto
kernel
=
kernel_
batched_
gemm_xdlops_bwd_weight
<
GridwiseGemm
,
GridwiseGemm
,
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
CDataType
,
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
OutElementwiseOperation
,
OutElementwiseOperation
,
InElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
,
has_main_loop
>
;
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
return
launch_and_time_kernel
(
stream_config
,
...
@@ -921,13 +1050,15 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
...
@@ -921,13 +1050,15 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
arg
.
Conv_G_
,
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_ctile_map_
,
arg
.
compute_ptr_offset_of_batch_
);
};
};
if
(
has_main_k0_block_loop
)
if
(
has_main_k0_block_loop
)
...
@@ -998,16 +1129,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
...
@@ -998,16 +1129,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
static
auto
MakeArgument
(
const
InDataType
*
p_in_grid
,
static
auto
MakeArgument
(
const
InDataType
*
p_in_grid
,
WeiDataType
*
p_wei_grid
,
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
const
OutDataType
*
p_out_grid
,
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
OutElementwiseOperation
out_element_op
,
...
@@ -1016,6 +1148,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
...
@@ -1016,6 +1148,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
return
Argument
{
p_in_grid
,
return
Argument
{
p_in_grid
,
p_wei_grid
,
p_wei_grid
,
p_out_grid
,
p_out_grid
,
G
,
N
,
N
,
K
,
K
,
C
,
C
,
...
@@ -1040,16 +1173,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
...
@@ -1040,16 +1173,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
MakeArgumentPointer
(
const
void
*
p_in_grid
,
MakeArgumentPointer
(
const
void
*
p_in_grid
,
void
*
p_wei_grid
,
void
*
p_wei_grid
,
const
void
*
p_out_grid
,
const
void
*
p_out_grid
,
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
OutElementwiseOperation
out_element_op
,
...
@@ -1058,6 +1192,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
...
@@ -1058,6 +1192,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_grid
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_grid
),
static_cast
<
WeiDataType
*>
(
p_wei_grid
),
static_cast
<
WeiDataType
*>
(
p_wei_grid
),
static_cast
<
const
OutDataType
*>
(
p_out_grid
),
static_cast
<
const
OutDataType
*>
(
p_out_grid
),
G
,
N
,
N
,
K
,
K
,
C
,
C
,
...
@@ -1086,7 +1221,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
...
@@ -1086,7 +1221,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
auto
str
=
std
::
stringstream
();
auto
str
=
std
::
stringstream
();
// clang-format off
// clang-format off
str
<<
"DeviceConv
Nd
BwdWeight
NwcKxcN
wk_Xdl_CShuffle"
str
<<
"Device
Grouped
ConvBwdWeight
GnwcGkxcGn
wk_Xdl_CShuffle"
<<
"<"
<<
"<"
<<
BlockSize
<<
", "
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
MPerBlock
<<
", "
...
...
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
View file @
9b3365e1
...
@@ -364,14 +364,16 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
...
@@ -364,14 +364,16 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
index_t
M01
=
1
,
index_t
M01
=
1
,
index_t
N01
=
1
,
index_t
N01
=
1
,
index_t
KSplit
=
1
)
index_t
KSplit
=
1
)
:
M01_
(
M01
),
:
c_grid_desc_m_n_
(
c_grid_desc_m_n
),
M01_
(
M01
),
N01_
(
N01
),
N01_
(
N01
),
KSplit_
(
KSplit
),
KSplit_
(
KSplit
),
underlying_map_
(
GetBlockToCTileMap
(
c_grid_desc_m_n
,
M01
,
N01
,
KSplit
))
underlying_map_
(
GetBlockToCTileMap
(
c_grid_desc_m_n
,
M01
,
N01
,
KSplit
))
{
{
}
}
__host__
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
const
__host__
__device__
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
const
{
{
const
auto
M0
=
math
::
integer_divide_ceil
(
c_grid_desc_m_n
.
GetLength
(
I0
),
MPerBlock
);
const
auto
M0
=
math
::
integer_divide_ceil
(
c_grid_desc_m_n
.
GetLength
(
I0
),
MPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
c_grid_desc_m_n
.
GetLength
(
I1
),
NPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
c_grid_desc_m_n
.
GetLength
(
I1
),
NPerBlock
);
...
@@ -387,7 +389,10 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
...
@@ -387,7 +389,10 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
template
<
typename
TopIdx
>
template
<
typename
TopIdx
>
__host__
__device__
constexpr
auto
CalculateBottomIndex
(
const
TopIdx
&
idx_top
)
const
__host__
__device__
constexpr
auto
CalculateBottomIndex
(
const
TopIdx
&
idx_top
)
const
{
{
return
underlying_map_
.
CalculateBottomIndex
(
idx_top
);
static_assert
(
TopIdx
::
Size
()
==
1
);
return
underlying_map_
.
CalculateBottomIndex
(
make_multi_index
(
idx_top
[
I0
]
%
CalculateGridSize
()));
}
}
template
<
typename
CTileIdx
,
typename
CTileDim
>
template
<
typename
CTileIdx
,
typename
CTileDim
>
...
@@ -418,6 +423,11 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
...
@@ -418,6 +423,11 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
}
}
private:
private:
// Convenience overload: returns the grid size for the descriptor stored in
// c_grid_desc_m_n_ by forwarding to CalculateGridSize(const CGridDesc_M_N&).
// NOTE(review): this overload is __device__ only, while CalculateBottomIndex,
// which calls it, is declared __host__ __device__ - confirm that no host-side
// call path is required.
__device__
constexpr
index_t
CalculateGridSize
()
const
{
return
CalculateGridSize
(
c_grid_desc_m_n_
);
}
__host__
static
constexpr
auto
GetBlockToCTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
__host__
static
constexpr
auto
GetBlockToCTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
,
index_t
M01
,
index_t
N01
,
index_t
N01
,
...
@@ -450,6 +460,7 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
...
@@ -450,6 +460,7 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
return
c_blockid_to_ksplit_m0_n0_block_cluster_adaptor
;
return
c_blockid_to_ksplit_m0_n0_block_cluster_adaptor
;
}
}
CGridDesc_M_N
c_grid_desc_m_n_
;
index_t
M01_
,
N01_
,
KSplit_
;
index_t
M01_
,
N01_
,
KSplit_
;
using
UnderlyingMap
=
decltype
(
GetBlockToCTileMap
(
CGridDesc_M_N
{},
1
,
1
,
1
));
using
UnderlyingMap
=
decltype
(
GetBlockToCTileMap
(
CGridDesc_M_N
{},
1
,
1
,
1
));
UnderlyingMap
underlying_map_
;
UnderlyingMap
underlying_map_
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp
View file @
9b3365e1
...
@@ -289,7 +289,7 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
...
@@ -289,7 +289,7 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
XDataType
,
XDataType
,
decltype
(
thread_buffer_desc_m_k
),
decltype
(
thread_buffer_desc_m_k
),
GridDesc_M_K
,
GridDesc_M_K
,
YElementwiseOperation
,
PassThrough
,
ThreadBufferLengths_M_K
,
ThreadBufferLengths_M_K
,
ThreadBufferDimAccessOrder
,
ThreadBufferDimAccessOrder
,
XSrcVectorDim
,
XSrcVectorDim
,
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp
View file @
9b3365e1
...
@@ -131,17 +131,22 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
...
@@ -131,17 +131,22 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
else
if
constexpr
(
NDimSpatial
==
2
)
else
if
constexpr
(
NDimSpatial
==
2
)
{
{
auto
f_kcyx
=
[
&
](
auto
g
,
auto
k
,
auto
c
,
auto
y
,
auto
x
)
{
auto
f_kcyx
=
[
&
](
auto
g
,
auto
k
,
auto
c
,
auto
y
,
auto
x
)
{
std
::
size_t
N
=
arg
.
output_
.
GetLengths
()[
1
];
std
::
size_t
Ho
=
arg
.
output_
.
GetLengths
()[
3
];
std
::
size_t
Wo
=
arg
.
output_
.
GetLengths
()[
4
];
float
v_acc
=
0
;
float
v_acc
=
0
;
for
(
std
::
size_t
n
=
0
;
n
<
arg
.
output_
.
GetLengths
()[
1
]
;
++
n
)
for
(
std
::
size_t
n
=
0
;
n
<
N
;
++
n
)
{
{
for
(
std
::
size_t
ho
=
0
;
ho
<
arg
.
output_
.
GetLengths
()[
3
]
;
++
ho
)
for
(
std
::
size_t
ho
=
0
;
ho
<
Ho
;
++
ho
)
{
{
auto
hi
=
static_cast
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
0
])
+
auto
hi
=
static_cast
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
0
])
+
static_cast
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
for
(
std
::
size_t
wo
=
0
;
wo
<
arg
.
output_
.
GetLengths
()[
4
]
;
++
wo
)
for
(
std
::
size_t
wo
=
0
;
wo
<
Wo
;
++
wo
)
{
{
auto
wi
=
auto
wi
=
static_cast
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
1
])
+
static_cast
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
1
])
+
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp
View file @
9b3365e1
...
@@ -44,8 +44,8 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
...
@@ -44,8 +44,8 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
size_t
M
=
acc
.
mDesc
.
GetLengths
()[
0
];
size_t
M
=
acc
.
mDesc
.
GetLengths
()[
0
];
size_t
N
=
acc
.
mDesc
.
GetLengths
()[
1
];
size_t
N
=
acc
.
mDesc
.
GetLengths
()[
1
];
Tensor
<
ComputeDataType
>
avg_acc_sq
(
HostTensorDescriptor
(
std
::
vector
<
size_t
>
(
{
M
})
))
;
Tensor
<
ComputeDataType
>
avg_acc_sq
({
M
});
Tensor
<
ComputeDataType
>
avg_acc
(
HostTensorDescriptor
(
std
::
vector
<
size_t
>
(
{
M
})
))
;
Tensor
<
ComputeDataType
>
avg_acc
({
M
});
Tensor
<
ComputeDataType
>
acc_layernorm
(
acc
);
Tensor
<
ComputeDataType
>
acc_layernorm
(
acc
);
// reduce N dim
// reduce N dim
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp
View file @
9b3365e1
...
@@ -92,9 +92,10 @@ struct ReferenceLayernorm : public device::BaseOperator
...
@@ -92,9 +92,10 @@ struct ReferenceLayernorm : public device::BaseOperator
{
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
{
auto
x_val
=
ck
::
type_convert
<
AccDataType
>
(
arg
.
x_m_n_
(
m
,
n
));
auto
x_val
=
ck
::
type_convert
<
AccDataType
>
(
arg
.
x_m_n_
(
m
,
n
));
auto
y_val
=
(
x_val
-
mean
(
m
))
/
sqrt
(
var
(
m
)
+
arg
.
epsilon_
);
auto
y_val
=
(
x_val
-
mean
(
m
))
/
sqrt
(
var
(
m
)
+
arg
.
epsilon_
);
y_val
=
(
y_val
*
arg
.
gamma_n_
(
n
))
+
arg
.
beta_n_
(
n
);
y_val
=
(
y_val
*
arg
.
gamma_n_
(
n
))
+
arg
.
beta_n_
(
n
);
arg
.
acc_elementwise_op_
(
y_val
,
y_val
);
arg
.
y_m_n_
(
m
,
n
)
=
ck
::
type_convert
<
YDataType
>
(
y_val
);
arg
.
y_m_n_
(
m
,
n
)
=
ck
::
type_convert
<
YDataType
>
(
y_val
);
}
}
}
}
...
...
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
View file @
9b3365e1
...
@@ -95,7 +95,7 @@ template <typename Activation>
...
@@ -95,7 +95,7 @@ template <typename Activation>
using
Add_Activation_Mul_Clamp
=
using
Add_Activation_Mul_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Add_Activation_Mul_Clamp
<
Activation
>
;
ck
::
tensor_operation
::
element_wise
::
Add_Activation_Mul_Clamp
<
Activation
>
;
template
<
typename
DeviceOp
>
template
<
typename
DeviceOp
,
typename
Tag
=
void
>
struct
DeviceOperationInstanceFactory
;
struct
DeviceOperationInstanceFactory
;
}
// namespace instance
}
// namespace instance
...
...
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_weight.hpp
deleted
100644 → 0
View file @
9608beee
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// conv1d backward weight
//
// Declarations only; the definitions are not part of this header. Each function
// receives the caller's list of DeviceConvBwdWeight instances by non-const
// reference. The template arguments follow the order used by the factory
// specialization in this file: spatial rank, input/weight/output layout,
// input/weight/output data type, then three element-wise operations.
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<1,
                                                    NWC,
                                                    KXC,
                                                    NWK,
                                                    BF16,
                                                    F32,
                                                    BF16,
                                                    PassThrough,
                                                    PassThrough,
                                                    PassThrough>>>& instances);

void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<1,
                                                    NWC,
                                                    KXC,
                                                    NWK,
                                                    F16,
                                                    F16,
                                                    F16,
                                                    PassThrough,
                                                    PassThrough,
                                                    PassThrough>>>& instances);

void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<1,
                                                    NWC,
                                                    KXC,
                                                    NWK,
                                                    F32,
                                                    F32,
                                                    F32,
                                                    PassThrough,
                                                    PassThrough,
                                                    PassThrough>>>& instances);
// conv2d backward weight
//
// Declarations only; the definitions are not part of this header. One function
// per data-type combination (bf16/f32/bf16, f16, f32), all for the
// NHWC / KYXC / NHWK layout tags and PassThrough element-wise operations.
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<2,
                                                    NHWC,
                                                    KYXC,
                                                    NHWK,
                                                    BF16,
                                                    F32,
                                                    BF16,
                                                    PassThrough,
                                                    PassThrough,
                                                    PassThrough>>>& instances);

void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<2,
                                                    NHWC,
                                                    KYXC,
                                                    NHWK,
                                                    F16,
                                                    F16,
                                                    F16,
                                                    PassThrough,
                                                    PassThrough,
                                                    PassThrough>>>& instances);

void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<2,
                                                    NHWC,
                                                    KYXC,
                                                    NHWK,
                                                    F32,
                                                    F32,
                                                    F32,
                                                    PassThrough,
                                                    PassThrough,
                                                    PassThrough>>>& instances);
// conv3d backward weight
void
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_f32_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdWeight
<
3
,
NDHWC
,
KZYXC
,
NDHWK
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdWeight
<
3
,
NDHWC
,
KZYXC
,
NDHWK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdWeight
<
3
,
NDHWC
,
KZYXC
,
NDHWK
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
/// Partial specialization of DeviceOperationInstanceFactory for
/// DeviceConvBwdWeight operations whose three element-wise operations are all
/// PassThrough.
///
/// GetInstances() picks, at compile time, the add_device_conv{1,2,3}d_bwd_weight_*
/// function declared in this header that matches the layout and data-type
/// template arguments, and hands it an initially empty instance list. If no
/// layout/data-type combination matches, the list is returned empty.
template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBwdWeight<
    NumDimSpatial,
    InLayout,
    WeiLayout,
    OutLayout,
    InDataType,
    WeiDataType,
    OutDataType,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough>>
{
    using DeviceOp = DeviceConvBwdWeight<NumDimSpatial,
                                         InLayout,
                                         WeiLayout,
                                         OutLayout,
                                         InDataType,
                                         WeiDataType,
                                         OutDataType,
                                         ck::tensor_operation::element_wise::PassThrough,
                                         ck::tensor_operation::element_wise::PassThrough,
                                         ck::tensor_operation::element_wise::PassThrough>;

    /// Returns the list of registered instances for this DeviceOp (possibly empty).
    static auto GetInstances()
    {
        // Supported (input, weight, output) data-type combinations.
        constexpr bool is_f32 = is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                                is_same_v<OutDataType, float>;
        constexpr bool is_f16 = is_same_v<InDataType, half_t> &&
                                is_same_v<WeiDataType, half_t> && is_same_v<OutDataType, half_t>;
        constexpr bool is_bf16_f32_bf16 = is_same_v<InDataType, ck::bhalf_t> &&
                                          is_same_v<WeiDataType, float> &&
                                          is_same_v<OutDataType, ck::bhalf_t>;

        // Supported (rank, input layout, weight layout, output layout) combinations.
        constexpr bool is_conv1d = NumDimSpatial == 1 && is_same_v<InLayout, NWC> &&
                                   is_same_v<WeiLayout, KXC> && is_same_v<OutLayout, NWK>;
        constexpr bool is_conv2d = NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
                                   is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>;
        constexpr bool is_conv3d = NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
                                   is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>;

        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        // `if constexpr` keeps only the branch whose function signature matches
        // DeviceOp; the other calls are discarded and never instantiated.
        if constexpr(is_conv1d)
        {
            if constexpr(is_f32)
            {
                add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_f16)
            {
                add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_bf16_f32_bf16)
            {
                add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_f32_bf16_instances(op_ptrs);
            }
        }
        else if constexpr(is_conv2d)
        {
            if constexpr(is_f32)
            {
                add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_f16)
            {
                add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_bf16_f32_bf16)
            {
                add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_f32_bf16_instances(op_ptrs);
            }
        }
        else if constexpr(is_conv3d)
        {
            if constexpr(is_f32)
            {
                add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_f16)
            {
                add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_bf16_f32_bf16)
            {
                add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_f32_bf16_instances(
                    op_ptrs);
            }
        }

        return op_ptrs;
    }
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp
View file @
9b3365e1
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data
_multiple_d
.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
...
@@ -17,46 +17,54 @@ namespace instance {
...
@@ -17,46 +17,54 @@ namespace instance {
// conv2d backward data
// conv2d backward data
void
add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances
(
void
add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdData
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdDataMultipleD
<
2
,
GNHWC
,
GNHWK
,
GKYXC
,
GKYXC
,
GNHWK
,
Empty_Tuple
,
F16
,
GNHWC
,
F16
,
F16
,
F16
,
F16
,
PassThrough
,
Empty_Tuple
,
PassThrough
,
F16
,
PassThrough
>>>&
instances
);
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
template
<
ck
::
index_t
NumDimSpatial
,
template
<
ck
::
index_t
NumDimSpatial
,
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
,
typename
OutLayout
,
typename
InDataType
,
typename
WeiLayout
,
typename
InLayout
,
typename
OutDataType
,
typename
WeiDataType
,
typename
WeiDataType
,
typename
OutDataType
>
typename
InDataType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGroupedConvBwdData
<
struct
DeviceOperationInstanceFactory
<
NumDimSpatial
,
ck
::
tensor_operation
::
device
::
DeviceGroupedConvBwdDataMultipleD
<
InLayout
,
NumDimSpatial
,
WeiLayout
,
OutLayout
,
OutLayout
,
WeiLayout
,
InDataType
,
Empty_Tuple
,
WeiDataType
,
InLayout
,
OutDataType
,
OutDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
WeiDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Empty_Tuple
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>>
InDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>>
{
{
using
DeviceOp
=
DeviceGroupedConvBwdData
<
NumDimSpatial
,
using
DeviceOp
=
InLayout
,
DeviceGroupedConvBwdDataMultipleD
<
NumDimSpatial
,
WeiLayout
,
OutLayout
,
OutLayout
,
WeiLayout
,
InDataType
,
Empty_Tuple
,
WeiDataType
,
InLayout
,
OutDataType
,
OutDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
WeiDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Empty_Tuple
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
InDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
static
auto
GetInstances
()
static
auto
GetInstances
()
{
{
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
0 → 100644
View file @
9b3365e1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// conv1d backward weight
//
// Declarations only; the definitions are not part of this header. Each function
// receives the caller's list of DeviceGroupedConvBwdWeight instances by
// non-const reference. The template arguments follow the order used by the
// factory specialization in this file: spatial rank, input/weight/output
// layout, input/weight/output data type, then three element-wise operations.
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
                                                           GNWC,
                                                           GKXC,
                                                           GNWK,
                                                           BF16,
                                                           F32,
                                                           BF16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances);

void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
                                                           GNWC,
                                                           GKXC,
                                                           GNWK,
                                                           F16,
                                                           F16,
                                                           F16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances);

void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
                                                           GNWC,
                                                           GKXC,
                                                           GNWK,
                                                           F32,
                                                           F32,
                                                           F32,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances);
// conv2d backward weight
//
// Declarations only; the definitions are not part of this header. One function
// per data-type combination (bf16/f32/bf16, f16, f32), all for the
// GNHWC / GKYXC / GNHWK layout tags and PassThrough element-wise operations.
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
                                                           GNHWC,
                                                           GKYXC,
                                                           GNHWK,
                                                           BF16,
                                                           F32,
                                                           BF16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances);

void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
                                                           GNHWC,
                                                           GKYXC,
                                                           GNHWK,
                                                           F16,
                                                           F16,
                                                           F16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances);

void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
                                                           GNHWC,
                                                           GKYXC,
                                                           GNHWK,
                                                           F32,
                                                           F32,
                                                           F32,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances);
// conv3d backward weight
//
// Declarations only; the definitions are not part of this header. One function
// per data-type combination (bf16/f32/bf16, f16, f32), all for the
// GNDHWC / GKZYXC / GNDHWK layout tags and PassThrough element-wise operations.
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
                                                           GNDHWC,
                                                           GKZYXC,
                                                           GNDHWK,
                                                           BF16,
                                                           F32,
                                                           BF16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances);

void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
                                                           GNDHWC,
                                                           GKZYXC,
                                                           GNDHWK,
                                                           F16,
                                                           F16,
                                                           F16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances);

void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
                                                           GNDHWC,
                                                           GKZYXC,
                                                           GNDHWK,
                                                           F32,
                                                           F32,
                                                           F32,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances);
/// Partial specialization of DeviceOperationInstanceFactory for
/// DeviceGroupedConvBwdWeight operations whose three element-wise operations
/// are all PassThrough.
///
/// GetInstances() picks, at compile time, the
/// add_device_grouped_conv{1,2,3}d_bwd_weight_* function declared in this header
/// that matches the layout and data-type template arguments, and hands it an
/// initially empty instance list. If no layout/data-type combination matches,
/// the list is returned empty.
template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvBwdWeight<
    NumDimSpatial,
    InLayout,
    WeiLayout,
    OutLayout,
    InDataType,
    WeiDataType,
    OutDataType,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough>>
{
    using DeviceOp =
        DeviceGroupedConvBwdWeight<NumDimSpatial,
                                   InLayout,
                                   WeiLayout,
                                   OutLayout,
                                   InDataType,
                                   WeiDataType,
                                   OutDataType,
                                   ck::tensor_operation::element_wise::PassThrough,
                                   ck::tensor_operation::element_wise::PassThrough,
                                   ck::tensor_operation::element_wise::PassThrough>;

    /// Returns the list of registered instances for this DeviceOp (possibly empty).
    static auto GetInstances()
    {
        // Supported (input, weight, output) data-type combinations.
        constexpr bool is_f32 = is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                                is_same_v<OutDataType, float>;
        constexpr bool is_f16 = is_same_v<InDataType, half_t> &&
                                is_same_v<WeiDataType, half_t> && is_same_v<OutDataType, half_t>;
        constexpr bool is_bf16_f32_bf16 = is_same_v<InDataType, ck::bhalf_t> &&
                                          is_same_v<WeiDataType, float> &&
                                          is_same_v<OutDataType, ck::bhalf_t>;

        // Supported (rank, input layout, weight layout, output layout) combinations.
        constexpr bool is_conv1d = NumDimSpatial == 1 && is_same_v<InLayout, GNWC> &&
                                   is_same_v<WeiLayout, GKXC> && is_same_v<OutLayout, GNWK>;
        constexpr bool is_conv2d = NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
                                   is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>;
        constexpr bool is_conv3d = NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
                                   is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>;

        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        // `if constexpr` keeps only the branch whose function signature matches
        // DeviceOp; the other calls are discarded and never instantiated.
        if constexpr(is_conv1d)
        {
            if constexpr(is_f32)
            {
                add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_f16)
            {
                add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_bf16_f32_bf16)
            {
                add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
                    op_ptrs);
            }
        }
        else if constexpr(is_conv2d)
        {
            if constexpr(is_f32)
            {
                add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_f16)
            {
                add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_bf16_f32_bf16)
            {
                add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
                    op_ptrs);
            }
        }
        else if constexpr(is_conv3d)
        {
            if constexpr(is_f32)
            {
                add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
                    op_ptrs);
            }
            else if constexpr(is_f16)
            {
                add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
                    op_ptrs);
            }
            else if constexpr(is_bf16_f32_bf16)
            {
                add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
                    op_ptrs);
            }
        }

        return op_ptrs;
    }
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
Prev
1
2
3
4
5
6
7
8
9
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment