Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
be58e518
Commit
be58e518
authored
Aug 06, 2024
by
Jun Liu
Browse files
Merge branch 'develop' into amd-develop
parents
94642acf
afbf6350
Changes
49
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2313 additions
and
186 deletions
+2313
-186
CMakeLists.txt
CMakeLists.txt
+22
-15
example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp
example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp
+2
-2
example/12_reduce/reduce_blockwise_impl.hpp
example/12_reduce/reduce_blockwise_impl.hpp
+12
-2
example/17_convnd_bwd_data/convnd_bwd_data_common.hpp
example/17_convnd_bwd_data/convnd_bwd_data_common.hpp
+34
-11
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp
...ation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp
+23
-0
include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
...gen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
+79
-8
include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp
...operation/gpu/device/impl/device_column_to_image_impl.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
.../device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
+154
-8
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
...device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
+6
-6
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
...mpl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
+159
-12
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
.../device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
+158
-6
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp
...e_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp
+7
-7
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
...impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
+158
-8
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
...grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
+1054
-0
include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp
...operation/gpu/device/impl/device_image_to_column_impl.hpp
+2
-2
include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp
...on/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp
+2
-2
include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp
...r_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp
+2
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp
...tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp
+10
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp
...de/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp
+10
-1
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp
...eration/operator_transform/transform_conv_fwd_to_gemm.hpp
+417
-91
No files found.
CMakeLists.txt
View file @
be58e518
...
@@ -106,21 +106,33 @@ list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/ll
...
@@ -106,21 +106,33 @@ list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/ll
message
(
"GPU_TARGETS=
${
GPU_TARGETS
}
"
)
message
(
"GPU_TARGETS=
${
GPU_TARGETS
}
"
)
find_package
(
hip
)
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
# SWDEV-413293 and https://reviews.llvm.org/D155213
math
(
EXPR hip_VERSION_FLAT
"(
${
hip_VERSION_MAJOR
}
* 1000 +
${
hip_VERSION_MINOR
}
) * 100000 +
${
hip_VERSION_PATCH
}
"
)
message
(
"hip_version_flat=
${
hip_VERSION_FLAT
}
"
)
message
(
"checking which targets are supported"
)
message
(
"checking which targets are supported"
)
#This is the list of targets to be used in case GPU_TARGETS is not set on command line
#This is the list of targets to be used in case GPU_TARGETS is not set on command line
#These targets will be filtered and only supported ones will be used
#These targets will be filtered and only supported ones will be used
#Setting GPU_TARGETS on command line will override this list
#Setting GPU_TARGETS on command line will override this list
if
(
NOT PROFILER_ONLY
)
if
(
NOT PROFILER_ONLY
)
if
(
NOT ENABLE_ASAN_PACKAGING
)
if
(
NOT ENABLE_ASAN_PACKAGING
)
#build CK for all supported targets
#build CK for all supported targets
rocm_check_target_ids
(
DEFAULT_GPU_TARGETS
if
(
NOT WIN32 AND
${
hip_VERSION_FLAT
}
LESS 600300000
)
TARGETS
"gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
)
# WORKAROUND: compiler does not yet fully support gfx12 targets, need to fix version above
else
()
rocm_check_target_ids
(
DEFAULT_GPU_TARGETS
#build CK only for xnack-supported targets
TARGETS
"gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102"
)
rocm_check_target_ids
(
DEFAULT_GPU_TARGETS
else
()
TARGETS
"gfx908:xnack+;gfx90a:xnack+;gfx940:xnack+;gfx941:xnack+;gfx942:xnack+"
)
rocm_check_target_ids
(
DEFAULT_GPU_TARGETS
set
(
GPU_TARGETS
"
${
DEFAULT_GPU_TARGETS
}
"
CACHE STRING
" "
FORCE
)
TARGETS
"gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
)
endif
()
endif
()
else
()
#build CK only for xnack-supported targets
rocm_check_target_ids
(
DEFAULT_GPU_TARGETS
TARGETS
"gfx908:xnack+;gfx90a:xnack+;gfx940:xnack+;gfx941:xnack+;gfx942:xnack+"
)
set
(
GPU_TARGETS
"
${
DEFAULT_GPU_TARGETS
}
"
CACHE STRING
" "
FORCE
)
endif
()
else
()
else
()
add_definitions
(
-DPROFILER_ONLY
)
add_definitions
(
-DPROFILER_ONLY
)
set
(
GPU_TARGETS
""
CACHE STRING
""
FORCE
)
set
(
GPU_TARGETS
""
CACHE STRING
""
FORCE
)
...
@@ -169,11 +181,6 @@ endif()
...
@@ -169,11 +181,6 @@ endif()
# CK config file to record supported datatypes, etc.
# CK config file to record supported datatypes, etc.
configure_file
(
include/ck/config.h.in
${
CMAKE_CURRENT_BINARY_DIR
}
/include/ck/config.h
)
configure_file
(
include/ck/config.h.in
${
CMAKE_CURRENT_BINARY_DIR
}
/include/ck/config.h
)
find_package
(
hip
)
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
# SWDEV-413293 and https://reviews.llvm.org/D155213
math
(
EXPR hip_VERSION_FLAT
"(
${
hip_VERSION_MAJOR
}
* 1000 +
${
hip_VERSION_MINOR
}
) * 100000 +
${
hip_VERSION_PATCH
}
"
)
message
(
"hip_version_flat=
${
hip_VERSION_FLAT
}
"
)
if
(
NOT WIN32 AND
${
hip_VERSION_FLAT
}
GREATER 500723302
)
if
(
NOT WIN32 AND
${
hip_VERSION_FLAT
}
GREATER 500723302
)
message
(
"Adding the fno-offload-uniform-block compiler flag"
)
message
(
"Adding the fno-offload-uniform-block compiler flag"
)
add_compile_options
(
-fno-offload-uniform-block
)
add_compile_options
(
-fno-offload-uniform-block
)
...
...
example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp
View file @
be58e518
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <algorithm>
#include <cassert>
#include <cassert>
...
@@ -139,7 +139,7 @@ inline bool parse_cmd_args(int argc,
...
@@ -139,7 +139,7 @@ inline bool parse_cmd_args(int argc,
inline
HostTensorDescriptor
inline
HostTensorDescriptor
make_r0_host_tensor_descriptor
(
const
ck
::
utils
::
conv
::
ConvParam
&
problem_size
)
make_r0_host_tensor_descriptor
(
const
ck
::
utils
::
conv
::
ConvParam
&
problem_size
)
{
{
std
::
vector
<
ck
::
index_t
>
dimensions
{
problem_size
.
G_
,
problem_size
.
N_
};
std
::
vector
<
ck
::
long_
index_t
>
dimensions
{
problem_size
.
G_
,
problem_size
.
N_
};
ck
::
ranges
::
copy
(
problem_size
.
output_spatial_lengths_
,
std
::
back_inserter
(
dimensions
));
ck
::
ranges
::
copy
(
problem_size
.
output_spatial_lengths_
,
std
::
back_inserter
(
dimensions
));
...
...
example/12_reduce/reduce_blockwise_impl.hpp
View file @
be58e518
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -316,7 +316,17 @@ int reduce_blockwise_impl(bool do_verification,
...
@@ -316,7 +316,17 @@ int reduce_blockwise_impl(bool do_verification,
auto
invoker_ptr
=
reduce
.
MakeInvokerPointer
();
auto
invoker_ptr
=
reduce
.
MakeInvokerPointer
();
float
avg_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
int
log_level
=
0
,
cold_niters
=
5
,
nrepeat
=
50
;
if
(
beta
!=
0.0
f
)
{
std
::
cerr
<<
"Warning: With beta != 0.0f there must be only one repeat for correct results "
"since out memory is being overwritten."
<<
std
::
endl
;
cold_niters
=
0
;
nrepeat
=
1
;
}
float
avg_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
,
log_level
,
cold_niters
,
nrepeat
});
std
::
size_t
num_bytes
=
invariant_total_length
*
reduce_total_length
*
sizeof
(
InOutDataType
)
+
std
::
size_t
num_bytes
=
invariant_total_length
*
reduce_total_length
*
sizeof
(
InOutDataType
)
+
invariant_total_length
*
sizeof
(
InOutDataType
);
invariant_total_length
*
sizeof
(
InOutDataType
);
...
...
example/17_convnd_bwd_data/convnd_bwd_data_common.hpp
View file @
be58e518
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <iostream>
#include <numeric>
#include <numeric>
...
@@ -80,6 +80,29 @@ int run_conv_bwd_data(bool do_verification,
...
@@ -80,6 +80,29 @@ int run_conv_bwd_data(bool do_verification,
// reset input to zero
// reset input to zero
in_device_buf
.
SetZero
();
in_device_buf
.
SetZero
();
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths_i32
(
NDimSpatial
);
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths_i32
(
NDimSpatial
);
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths_i32
(
NDimSpatial
);
std
::
vector
<
ck
::
index_t
>
conv_filter_strides_i32
(
NDimSpatial
);
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations_i32
(
NDimSpatial
);
std
::
vector
<
ck
::
index_t
>
input_left_pads_i32
(
NDimSpatial
);
std
::
vector
<
ck
::
index_t
>
input_right_pads_i32
(
NDimSpatial
);
for
(
ck
::
index_t
d
=
0
;
d
<
NDimSpatial
;
d
++
)
{
input_spatial_lengths_i32
[
d
]
=
static_cast
<
ck
::
index_t
>
(
conv_param
.
input_spatial_lengths_
[
d
]);
filter_spatial_lengths_i32
[
d
]
=
static_cast
<
ck
::
index_t
>
(
conv_param
.
filter_spatial_lengths_
[
d
]);
output_spatial_lengths_i32
[
d
]
=
static_cast
<
ck
::
index_t
>
(
conv_param
.
GetOutputSpatialLengths
()[
d
]);
conv_filter_strides_i32
[
d
]
=
static_cast
<
ck
::
index_t
>
(
conv_param
.
conv_filter_strides_
[
d
]);
conv_filter_dilations_i32
[
d
]
=
static_cast
<
ck
::
index_t
>
(
conv_param
.
conv_filter_dilations_
[
d
]);
input_left_pads_i32
[
d
]
=
static_cast
<
ck
::
index_t
>
(
conv_param
.
input_left_pads_
[
d
]);
input_right_pads_i32
[
d
]
=
static_cast
<
ck
::
index_t
>
(
conv_param
.
input_right_pads_
[
d
]);
}
// do GEMM
// do GEMM
auto
conv
=
DeviceConvNdBwdDataInstance
{};
auto
conv
=
DeviceConvNdBwdDataInstance
{};
auto
invoker
=
conv
.
MakeInvoker
();
auto
invoker
=
conv
.
MakeInvoker
();
...
@@ -87,16 +110,16 @@ int run_conv_bwd_data(bool do_verification,
...
@@ -87,16 +110,16 @@ int run_conv_bwd_data(bool do_verification,
conv
.
MakeArgumentPointer
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
conv
.
MakeArgumentPointer
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
WeiDataType
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
static_cast
<
WeiDataType
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
conv_param
.
N_
,
static_cast
<
ck
::
index_t
>
(
conv_param
.
N_
)
,
conv_param
.
K_
,
static_cast
<
ck
::
index_t
>
(
conv_param
.
K_
)
,
conv_param
.
C_
,
static_cast
<
ck
::
index_t
>
(
conv_param
.
C_
)
,
conv_param
.
input_spatial_lengths_
,
input_spatial_lengths_
i32
,
conv_param
.
filter_spatial_lengths_
,
filter_spatial_lengths_
i32
,
conv_param
.
GetO
utput
S
patial
L
engths
()
,
o
utput
_s
patial
_l
engths
_i32
,
conv_param
.
conv_filter_strides_
,
conv_filter_strides_
i32
,
conv_param
.
conv_filter_dilations_
,
conv_filter_dilations_
i32
,
conv_param
.
input_left_pads_
,
input_left_pads_
i32
,
conv_param
.
input_right_pads_
,
input_right_pads_
i32
,
in_element_op
,
in_element_op
,
wei_element_op
,
wei_element_op
,
out_element_op
);
out_element_op
);
...
...
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp
View file @
be58e518
...
@@ -126,6 +126,29 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
...
@@ -126,6 +126,29 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
const
BElementwiseOperation
&
b_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
=
0
;
const
CDEElementwiseOperation
&
cde_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
APointers
p_a
,
BPointers
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
};
...
...
include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
View file @
be58e518
...
@@ -359,14 +359,14 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -359,14 +359,14 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
using
GemmToConv
FwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
;
using
ConvToGemm
FwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
;
static
constexpr
auto
matrix_padder
=
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
template
<
typename
ALay
>
template
<
typename
ALay
>
__host__
__device__
static
auto
__host__
__device__
static
auto
MakeAGridDescriptor_M_K
(
const
GemmToConv
FwdTransformer
&
conv_to_gemm_transformer
)
MakeAGridDescriptor_M_K
(
const
ConvToGemm
FwdTransformer
&
conv_to_gemm_transformer
)
{
{
const
auto
in_gemmmraw_gemmkraw_desc
=
const
auto
in_gemmmraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>();
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>();
...
@@ -379,7 +379,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -379,7 +379,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template
<
typename
BLay
>
template
<
typename
BLay
>
__host__
__device__
static
auto
__host__
__device__
static
auto
MakeBGridDescriptor_N_K
(
const
GemmToConv
FwdTransformer
&
conv_to_gemm_transformer
)
MakeBGridDescriptor_N_K
(
const
ConvToGemm
FwdTransformer
&
conv_to_gemm_transformer
)
{
{
const
auto
wei_gemmnraw_gemmkraw_desc
=
const
auto
wei_gemmnraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>();
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>();
...
@@ -392,7 +392,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -392,7 +392,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template
<
typename
ELay
>
template
<
typename
ELay
>
__host__
__device__
static
auto
__host__
__device__
static
auto
MakeEGridDescriptor_M_N
(
const
GemmToConv
FwdTransformer
&
conv_to_gemm_transformer
)
MakeEGridDescriptor_M_N
(
const
ConvToGemm
FwdTransformer
&
conv_to_gemm_transformer
)
{
{
const
auto
out_gemmmraw_gemmnraw_desc
=
const
auto
out_gemmmraw_gemmnraw_desc
=
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
ELay
>();
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
ELay
>();
...
@@ -405,7 +405,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -405,7 +405,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// Shape of Ds and E must be aligned. Strides can be different.
// Shape of Ds and E must be aligned. Strides can be different.
// Pass e_g_n_k_wos_lengths for logical broadcast.
// Pass e_g_n_k_wos_lengths for logical broadcast.
static
auto
MakeDsGridDescriptor_M_N
(
const
GemmToConv
FwdTransformer
&
conv_to_gemm_transformer
)
static
auto
MakeDsGridDescriptor_M_N
(
const
ConvToGemm
FwdTransformer
&
conv_to_gemm_transformer
)
{
{
return
generate_tuple
(
return
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
...
@@ -417,7 +417,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -417,7 +417,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
}
// desc for problem definition
// desc for problem definition
constexpr
static
GemmToConv
FwdTransformer
dummy_conv_to_gemm_transformer
;
constexpr
static
ConvToGemm
FwdTransformer
dummy_conv_to_gemm_transformer
;
using
AGridDesc_M_K
=
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_M_K
<
ALayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_M_K
<
ALayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
using
BGridDesc_N_K
=
using
BGridDesc_N_K
=
...
@@ -617,7 +617,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -617,7 +617,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// D batch stride
// D batch stride
compute_ptr_offset_of_batch_
.
BatchStrideDs_
(
i
)
=
ds_g_n_k_wos_strides
[
i
][
0
];
compute_ptr_offset_of_batch_
.
BatchStrideDs_
(
i
)
=
ds_g_n_k_wos_strides
[
i
][
0
];
GemmToConv
FwdTransformer
conv_to_gemm_transformer_d
{
a_g_n_c_wis_lengths
,
ConvToGemm
FwdTransformer
conv_to_gemm_transformer_d
{
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
b_g_k_c_xs_strides
,
...
@@ -686,7 +686,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -686,7 +686,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// tensor descriptors for problem definiton
// tensor descriptors for problem definiton
index_t
num_group_
;
index_t
num_group_
;
GemmToConv
FwdTransformer
conv_to_gemm_transformer_
;
ConvToGemm
FwdTransformer
conv_to_gemm_transformer_
;
AGridDesc_M_K
a_grid_desc_m_k_
;
AGridDesc_M_K
a_grid_desc_m_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
...
@@ -943,6 +943,77 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -943,6 +943,77 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
b_element_op
,
b_element_op
,
cde_element_op
};
cde_element_op
};
}
}
static
__device__
__host__
auto
MakeArgument
(
APointers
p_as
,
BPointers
p_bs
,
const
ck
::
Array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
ck
::
Array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
ck
::
Array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
ck
::
Array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
ck
::
Array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
ck
::
Array
<
ck
::
Array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
ck
::
Array
<
ck
::
Array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
ck
::
Array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
ck
::
Array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
ck
::
Array
<
long_index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
ck
::
Array
<
long_index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
ck
::
Array
<
long_index_t
,
NDimSpatial
>&
input_left_pads
,
const
ck
::
Array
<
long_index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
{
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_i32
;
array_convert
(
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_lengths
);
array_convert
(
a_g_n_c_wis_strides_i32
,
a_g_n_c_wis_strides
);
array_convert
(
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_lengths
);
array_convert
(
b_g_k_c_xs_strides_i32
,
b_g_k_c_xs_strides
);
for
(
index_t
d
=
0
;
d
<
NumDTensor
;
d
++
)
{
array_convert
(
ds_g_n_k_wos_lengths_i32
[
d
],
ds_g_n_k_wos_lengths
[
d
]);
array_convert
(
ds_g_n_k_wos_strides_i32
[
d
],
ds_g_n_k_wos_strides
[
d
]);
}
array_convert
(
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_lengths
);
array_convert
(
e_g_n_k_wos_strides_i32
,
e_g_n_k_wos_strides
);
array_convert
(
conv_filter_strides_i32
,
conv_filter_strides
);
array_convert
(
conv_filter_dilations_i32
,
conv_filter_dilations
);
array_convert
(
input_left_pads_i32
,
input_left_pads
);
array_convert
(
input_right_pads_i32
,
input_right_pads
);
return
Argument
{
p_as
,
p_bs
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_strides_i32
,
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_strides_i32
,
ds_g_n_k_wos_lengths_i32
,
ds_g_n_k_wos_strides_i32
,
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_strides_i32
,
conv_filter_strides_i32
,
conv_filter_dilations_i32
,
input_left_pads_i32
,
input_right_pads_i32
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
};
};
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp
View file @
be58e518
...
@@ -64,7 +64,7 @@ struct DeviceColumnToImageImpl
...
@@ -64,7 +64,7 @@ struct DeviceColumnToImageImpl
static
constexpr
auto
spatial_offset
=
Number
<
3
>
{};
static
constexpr
auto
spatial_offset
=
Number
<
3
>
{};
using
GemmToConv
FwdTransformer
=
using
ConvToGemm
FwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvolutionForwardSpecialization
::
Default
>
;
TransformConvFwdToGemm
<
NDimSpatial
,
ConvolutionForwardSpecialization
::
Default
>
;
static
constexpr
auto
matrix_padder
=
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpecialization
::
MKPadding
,
index_t
,
index_t
,
index_t
>
{
MatrixPadder
<
GemmSpecialization
::
MKPadding
,
index_t
,
index_t
,
index_t
>
{
...
@@ -233,7 +233,7 @@ struct DeviceColumnToImageImpl
...
@@ -233,7 +233,7 @@ struct DeviceColumnToImageImpl
:
independent_filter_stride
;
:
independent_filter_stride
;
}
}
GemmToConv
FwdTransformer
conv_to_gemm_transformer
{
a_g_n_c_wis_lengths
,
ConvToGemm
FwdTransformer
conv_to_gemm_transformer
{
a_g_n_c_wis_lengths
,
image_g_n_c_wis_strides
,
image_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_lengths
,
{},
// not needed for A Descriptor
{},
// not needed for A Descriptor
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
View file @
be58e518
...
@@ -238,14 +238,14 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
...
@@ -238,14 +238,14 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
using
GemmToConv
FwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
;
using
ConvToGemm
FwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
;
static
constexpr
auto
matrix_padder
=
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
K0PerBlock
};
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
K0PerBlock
};
template
<
typename
ALay
>
template
<
typename
ALay
>
static
auto
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
GemmToConv
FwdTransformer
&
conv_to_gemm_transformer
)
MakeAGridDescriptor_AK0_M_AK1
(
const
ConvToGemm
FwdTransformer
&
conv_to_gemm_transformer
)
{
{
const
auto
in_gemmmraw_gemmkraw_desc
=
const
auto
in_gemmmraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>();
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>();
...
@@ -266,7 +266,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
...
@@ -266,7 +266,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
template
<
typename
BLay
>
template
<
typename
BLay
>
static
auto
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
GemmToConv
FwdTransformer
&
conv_to_gemm_transformer
)
MakeBGridDescriptor_BK0_N_BK1
(
const
ConvToGemm
FwdTransformer
&
conv_to_gemm_transformer
)
{
{
const
auto
wei_gemmnraw_gemmkraw_desc
=
const
auto
wei_gemmnraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>();
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>();
...
@@ -287,7 +287,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
...
@@ -287,7 +287,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
}
}
template
<
typename
ELay
>
template
<
typename
ELay
>
static
auto
MakeEGridDescriptor_M_N
(
const
GemmToConv
FwdTransformer
&
conv_to_gemm_transformer
)
static
auto
MakeEGridDescriptor_M_N
(
const
ConvToGemm
FwdTransformer
&
conv_to_gemm_transformer
)
{
{
const
auto
out_gemmmraw_gemmnraw_desc
=
const
auto
out_gemmmraw_gemmnraw_desc
=
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
ELay
>();
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
ELay
>();
...
@@ -298,7 +298,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
...
@@ -298,7 +298,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
return
out_gemmm_gemmn_desc
;
return
out_gemmm_gemmn_desc
;
}
}
static
auto
MakeDsGridDescriptor_M_N
(
const
GemmToConv
FwdTransformer
&
conv_to_gemm_transformer
)
static
auto
MakeDsGridDescriptor_M_N
(
const
ConvToGemm
FwdTransformer
&
conv_to_gemm_transformer
)
{
{
return
generate_tuple
(
return
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
...
@@ -310,7 +310,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
...
@@ -310,7 +310,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
}
}
// desc for problem definition
// desc for problem definition
constexpr
static
GemmToConv
FwdTransformer
dummy_conv_to_gemm_transformer
;
constexpr
static
ConvToGemm
FwdTransformer
dummy_conv_to_gemm_transformer
;
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_AK0_M_AK1
<
ALayout
>
(
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_AK0_M_AK1
<
ALayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
dummy_conv_to_gemm_transformer
))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_BK0_N_BK1
<
BLayout
>
(
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_BK0_N_BK1
<
BLayout
>
(
...
@@ -447,7 +447,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
...
@@ -447,7 +447,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
GemmToConv
FwdTransformer
conv_to_gemm_transformer_d
{
a_g_n_c_wis_lengths
,
ConvToGemm
FwdTransformer
conv_to_gemm_transformer_d
{
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
b_g_k_c_xs_strides
,
...
@@ -511,7 +511,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
...
@@ -511,7 +511,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// tensor descriptors for problem definiton
// tensor descriptors for problem definiton
index_t
num_group_
;
index_t
num_group_
;
GemmToConv
FwdTransformer
conv_to_gemm_transformer_
;
ConvToGemm
FwdTransformer
conv_to_gemm_transformer_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
...
@@ -836,6 +836,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
...
@@ -836,6 +836,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
cde_element_op
};
cde_element_op
};
}
}
static
auto
MakeArgument
(
const
void
*
p_a
,
const
void
*
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
{
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_i32
;
array_convert
(
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_lengths
);
array_convert
(
a_g_n_c_wis_strides_i32
,
a_g_n_c_wis_strides
);
array_convert
(
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_lengths
);
array_convert
(
b_g_k_c_xs_strides_i32
,
b_g_k_c_xs_strides
);
for
(
index_t
d
=
0
;
d
<
NumDTensor
;
d
++
)
{
array_convert
(
ds_g_n_k_wos_lengths_i32
[
d
],
ds_g_n_k_wos_lengths
[
d
]);
array_convert
(
ds_g_n_k_wos_strides_i32
[
d
],
ds_g_n_k_wos_strides
[
d
]);
}
array_convert
(
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_lengths
);
array_convert
(
e_g_n_k_wos_strides_i32
,
e_g_n_k_wos_strides
);
array_convert
(
conv_filter_strides_i32
,
conv_filter_strides
);
array_convert
(
conv_filter_dilations_i32
,
conv_filter_dilations
);
array_convert
(
input_left_pads_i32
,
input_left_pads
);
array_convert
(
input_right_pads_i32
,
input_right_pads
);
return
Argument
{
p_a
,
p_b
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_strides_i32
,
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_strides_i32
,
ds_g_n_k_wos_lengths_i32
,
ds_g_n_k_wos_strides_i32
,
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_strides_i32
,
conv_filter_strides_i32
,
conv_filter_dilations_i32
,
input_left_pads_i32
,
input_right_pads_i32
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
...
@@ -880,6 +953,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
...
@@ -880,6 +953,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
cde_element_op
);
cde_element_op
);
}
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
override
{
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_i32
;
array_convert
(
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_lengths
);
array_convert
(
a_g_n_c_wis_strides_i32
,
a_g_n_c_wis_strides
);
array_convert
(
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_lengths
);
array_convert
(
b_g_k_c_xs_strides_i32
,
b_g_k_c_xs_strides
);
for
(
index_t
d
=
0
;
d
<
NumDTensor
;
d
++
)
{
array_convert
(
ds_g_n_k_wos_lengths_i32
[
d
],
ds_g_n_k_wos_lengths
[
d
]);
array_convert
(
ds_g_n_k_wos_strides_i32
[
d
],
ds_g_n_k_wos_strides
[
d
]);
}
array_convert
(
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_lengths
);
array_convert
(
e_g_n_k_wos_strides_i32
,
e_g_n_k_wos_strides
);
array_convert
(
conv_filter_strides_i32
,
conv_filter_strides
);
array_convert
(
conv_filter_dilations_i32
,
conv_filter_dilations
);
array_convert
(
input_left_pads_i32
,
input_left_pads
);
array_convert
(
input_right_pads_i32
,
input_right_pads
);
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_strides_i32
,
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_strides_i32
,
ds_g_n_k_wos_lengths_i32
,
ds_g_n_k_wos_strides_i32
,
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_strides_i32
,
conv_filter_strides_i32
,
conv_filter_dilations_i32
,
input_left_pads_i32
,
input_right_pads_i32
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
View file @
be58e518
...
@@ -234,14 +234,14 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
...
@@ -234,14 +234,14 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
using
GemmToConv
FwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
;
using
ConvToGemm
FwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
;
static
constexpr
auto
matrix_padder
=
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
K0PerBlock
};
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
K0PerBlock
};
template
<
typename
ALay
>
template
<
typename
ALay
>
static
auto
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
GemmToConv
FwdTransformer
&
conv_to_gemm_transformer
)
MakeAGridDescriptor_AK0_M_AK1
(
const
ConvToGemm
FwdTransformer
&
conv_to_gemm_transformer
)
{
{
const
auto
in_gemmmraw_gemmkraw_desc
=
const
auto
in_gemmmraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>();
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>();
...
@@ -263,7 +263,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
...
@@ -263,7 +263,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
template
<
typename
BLay
>
template
<
typename
BLay
>
static
auto
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
GemmToConv
FwdTransformer
&
conv_to_gemm_transformer
)
MakeBGridDescriptor_BK0_N_BK1
(
const
ConvToGemm
FwdTransformer
&
conv_to_gemm_transformer
)
{
{
const
auto
wei_gemmnraw_gemmkraw_desc
=
const
auto
wei_gemmnraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>();
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>();
...
@@ -284,7 +284,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
...
@@ -284,7 +284,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
}
}
template
<
typename
CLay
>
template
<
typename
CLay
>
static
auto
MakeCGridDescriptor_M_N
(
const
GemmToConv
FwdTransformer
&
conv_to_gemm_transformer
)
static
auto
MakeCGridDescriptor_M_N
(
const
ConvToGemm
FwdTransformer
&
conv_to_gemm_transformer
)
{
{
const
auto
out_gemmmraw_gemmnraw_desc
=
const
auto
out_gemmmraw_gemmnraw_desc
=
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
CLay
>();
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
CLay
>();
...
@@ -296,7 +296,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
...
@@ -296,7 +296,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
}
}
// desc for problem definition
// desc for problem definition
constexpr
static
GemmToConv
FwdTransformer
dummy_conv_to_gemm_transformer
;
constexpr
static
ConvToGemm
FwdTransformer
dummy_conv_to_gemm_transformer
;
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_AK0_M_AK1
<
ALayout
>
(
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_AK0_M_AK1
<
ALayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
dummy_conv_to_gemm_transformer
))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_BK0_N_BK1
<
BLayout
>
(
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_BK0_N_BK1
<
BLayout
>
(
...
@@ -452,7 +452,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
...
@@ -452,7 +452,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
// tensor descriptors for problem definiton
// tensor descriptors for problem definiton
index_t
num_group_
;
index_t
num_group_
;
GemmToConv
FwdTransformer
conv_to_gemm_transformer_
;
ConvToGemm
FwdTransformer
conv_to_gemm_transformer_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
View file @
be58e518
...
@@ -316,7 +316,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -316,7 +316,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
using
GemmToConv
FwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
using
ConvToGemm
FwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
,
ConvForwardSpecialization
,
true
/*SplitN*/
,
true
/*SplitN*/
,
ALayout
,
ALayout
,
...
@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
template
<
typename
ALay
>
template
<
typename
ALay
>
static
auto
MakeAGridDescriptor_M_K
(
const
GemmToConv
FwdTransformer
&
conv_to_gemm_transformer
)
static
auto
MakeAGridDescriptor_M_K
(
const
ConvToGemm
FwdTransformer
&
conv_to_gemm_transformer
)
{
{
const
auto
in_gemmmraw_gemmkraw_desc
=
const
auto
in_gemmmraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>();
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>();
...
@@ -339,7 +339,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -339,7 +339,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
}
template
<
typename
BLay
>
template
<
typename
BLay
>
static
auto
MakeBGridDescriptor_N_K
(
const
GemmToConv
FwdTransformer
&
conv_to_gemm_transformer
)
static
auto
MakeBGridDescriptor_N_K
(
const
ConvToGemm
FwdTransformer
&
conv_to_gemm_transformer
)
{
{
const
auto
wei_gemmnraw_gemmkraw_desc
=
const
auto
wei_gemmnraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>();
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>();
...
@@ -351,7 +351,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -351,7 +351,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
}
template
<
typename
ELay
>
template
<
typename
ELay
>
static
auto
MakeEGridDescriptor_M_N
(
const
GemmToConv
FwdTransformer
&
conv_to_gemm_transformer
)
static
auto
MakeEGridDescriptor_M_N
(
const
ConvToGemm
FwdTransformer
&
conv_to_gemm_transformer
)
{
{
const
auto
out_gemmmraw_gemmnraw_desc
=
const
auto
out_gemmmraw_gemmnraw_desc
=
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
ELay
>();
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
ELay
>();
...
@@ -364,7 +364,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -364,7 +364,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// Shape of Ds and E must be aligned. Strides can be different.
// Shape of Ds and E must be aligned. Strides can be different.
// Pass e_g_n_k_wos_lengths for logical broadcast.
// Pass e_g_n_k_wos_lengths for logical broadcast.
static
auto
MakeDsGridDescriptor_M_N
(
const
GemmToConv
FwdTransformer
&
conv_to_gemm_transformer
)
static
auto
MakeDsGridDescriptor_M_N
(
const
ConvToGemm
FwdTransformer
&
conv_to_gemm_transformer
)
{
{
return
generate_tuple
(
return
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
...
@@ -376,7 +376,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -376,7 +376,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
}
// desc for problem definition
// desc for problem definition
constexpr
static
GemmToConv
FwdTransformer
dummy_conv_to_gemm_transformer
;
constexpr
static
ConvToGemm
FwdTransformer
dummy_conv_to_gemm_transformer
;
using
AGridDesc_M_K
=
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_M_K
<
ALayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_M_K
<
ALayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
using
BGridDesc_N_K
=
using
BGridDesc_N_K
=
...
@@ -595,7 +595,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -595,7 +595,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
compute_ptr_offset_of_n_
.
BatchStrideDs_
(
i
)
=
compute_ptr_offset_of_n_
.
BatchStrideDs_
(
i
)
=
ds_g_n_k_wos_strides
[
i
][
1
]
*
conv_N_per_block_
;
ds_g_n_k_wos_strides
[
i
][
1
]
*
conv_N_per_block_
;
GemmToConv
FwdTransformer
conv_to_gemm_transformer_d
{
a_g_n_c_wis_lengths
,
ConvToGemm
FwdTransformer
conv_to_gemm_transformer_d
{
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
b_g_k_c_xs_strides
,
...
@@ -674,7 +674,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -674,7 +674,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// tensor descriptors for problem definiton
// tensor descriptors for problem definiton
index_t
num_group_
;
index_t
num_group_
;
GemmToConv
FwdTransformer
conv_to_gemm_transformer_
;
ConvToGemm
FwdTransformer
conv_to_gemm_transformer_
;
index_t
conv_N_per_block_
;
index_t
conv_N_per_block_
;
...
@@ -1129,11 +1129,84 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -1129,11 +1129,84 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
cde_element_op
};
cde_element_op
};
}
}
static
auto
MakeArgument
(
APointers
p_as
,
BPointers
p_bs
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
{
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_i32
;
array_convert
(
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_lengths
);
array_convert
(
a_g_n_c_wis_strides_i32
,
a_g_n_c_wis_strides
);
array_convert
(
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_lengths
);
array_convert
(
b_g_k_c_xs_strides_i32
,
b_g_k_c_xs_strides
);
for
(
index_t
d
=
0
;
d
<
NumDTensor
;
d
++
)
{
array_convert
(
ds_g_n_k_wos_lengths_i32
[
d
],
ds_g_n_k_wos_lengths
[
d
]);
array_convert
(
ds_g_n_k_wos_strides_i32
[
d
],
ds_g_n_k_wos_strides
[
d
]);
}
array_convert
(
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_lengths
);
array_convert
(
e_g_n_k_wos_strides_i32
,
e_g_n_k_wos_strides
);
array_convert
(
conv_filter_strides_i32
,
conv_filter_strides
);
array_convert
(
conv_filter_dilations_i32
,
conv_filter_dilations
);
array_convert
(
input_left_pads_i32
,
input_left_pads
);
array_convert
(
input_right_pads_i32
,
input_right_pads
);
return
Argument
{
p_as
,
p_bs
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_strides_i32
,
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_strides_i32
,
ds_g_n_k_wos_lengths_i32
,
ds_g_n_k_wos_strides_i32
,
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_strides_i32
,
conv_filter_strides_i32
,
conv_filter_dilations_i32
,
input_left_pads_i32
,
input_right_pads_i32
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
APointers
p_a
,
APointers
p_a
s
,
BPointers
p_b
,
BPointers
p_b
s
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
void
*
p_e
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
...
@@ -1152,8 +1225,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -1152,8 +1225,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
const
BElementwiseOperation
&
b_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
override
const
CDEElementwiseOperation
&
cde_element_op
)
override
{
{
return
std
::
make_unique
<
Argument
>
(
p_a
,
return
std
::
make_unique
<
Argument
>
(
p_a
s
,
p_b
,
p_b
s
,
p_ds
,
p_ds
,
p_e
,
p_e
,
a_g_n_c_wis_lengths
,
a_g_n_c_wis_lengths
,
...
@@ -1173,6 +1246,80 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -1173,6 +1246,80 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
cde_element_op
);
cde_element_op
);
}
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
APointers
p_as
,
BPointers
p_bs
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
override
{
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_i32
;
array_convert
(
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_lengths
);
array_convert
(
a_g_n_c_wis_strides_i32
,
a_g_n_c_wis_strides
);
array_convert
(
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_lengths
);
array_convert
(
b_g_k_c_xs_strides_i32
,
b_g_k_c_xs_strides
);
for
(
index_t
d
=
0
;
d
<
NumDTensor
;
d
++
)
{
array_convert
(
ds_g_n_k_wos_lengths_i32
[
d
],
ds_g_n_k_wos_lengths
[
d
]);
array_convert
(
ds_g_n_k_wos_strides_i32
[
d
],
ds_g_n_k_wos_strides
[
d
]);
}
array_convert
(
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_lengths
);
array_convert
(
e_g_n_k_wos_strides_i32
,
e_g_n_k_wos_strides
);
array_convert
(
conv_filter_strides_i32
,
conv_filter_strides
);
array_convert
(
conv_filter_dilations_i32
,
conv_filter_dilations
);
array_convert
(
input_left_pads_i32
,
input_left_pads
);
array_convert
(
input_right_pads_i32
,
input_right_pads
);
return
std
::
make_unique
<
Argument
>
(
p_as
,
p_bs
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_strides_i32
,
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_strides_i32
,
ds_g_n_k_wos_lengths_i32
,
ds_g_n_k_wos_strides_i32
,
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_strides_i32
,
conv_filter_strides_i32
,
conv_filter_dilations_i32
,
input_left_pads_i32
,
input_right_pads_i32
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
View file @
be58e518
...
@@ -293,7 +293,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
...
@@ -293,7 +293,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
using
GemmToConv
FwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
using
ConvToGemm
FwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
,
ConvForwardSpecialization
,
true
/*SplitN*/
,
true
/*SplitN*/
,
ADataType
,
ADataType
,
...
@@ -304,7 +304,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
...
@@ -304,7 +304,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
template
<
typename
ALay
>
template
<
typename
ALay
>
static
auto
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
GemmToConv
FwdTransformer
&
conv_to_gemm_transformer
)
MakeAGridDescriptor_AK0_M_AK1
(
const
ConvToGemm
FwdTransformer
&
conv_to_gemm_transformer
)
{
{
const
auto
in_gemmmraw_gemmkraw_desc
=
const
auto
in_gemmmraw_gemmkraw_desc
=
...
@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
...
@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
template
<
typename
BLay
>
template
<
typename
BLay
>
static
auto
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
GemmToConv
FwdTransformer
&
conv_to_gemm_transformer
)
MakeBGridDescriptor_BK0_N_BK1
(
const
ConvToGemm
FwdTransformer
&
conv_to_gemm_transformer
)
{
{
const
auto
wei_gemmnraw_gemmkraw_desc
=
const
auto
wei_gemmnraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>();
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>();
...
@@ -348,7 +348,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
...
@@ -348,7 +348,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
}
}
template
<
typename
ELay
>
template
<
typename
ELay
>
static
auto
MakeEGridDescriptor_M_N
(
const
GemmToConv
FwdTransformer
&
conv_to_gemm_transformer
)
static
auto
MakeEGridDescriptor_M_N
(
const
ConvToGemm
FwdTransformer
&
conv_to_gemm_transformer
)
{
{
const
auto
out_gemmmraw_gemmnraw_desc
=
const
auto
out_gemmmraw_gemmnraw_desc
=
...
@@ -361,7 +361,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
...
@@ -361,7 +361,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
}
}
// desc for problem definition
// desc for problem definition
constexpr
static
GemmToConv
FwdTransformer
dummy_conv_to_gemm_transformer
;
constexpr
static
ConvToGemm
FwdTransformer
dummy_conv_to_gemm_transformer
;
using
EGridDesc_M_N
=
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
...
@@ -495,7 +495,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
...
@@ -495,7 +495,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// tensor descriptors for problem definiton
// tensor descriptors for problem definiton
index_t
num_group_
;
index_t
num_group_
;
GemmToConv
FwdTransformer
conv_to_gemm_transformer_
;
ConvToGemm
FwdTransformer
conv_to_gemm_transformer_
;
index_t
conv_N_per_block_
;
index_t
conv_N_per_block_
;
...
@@ -978,6 +978,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
...
@@ -978,6 +978,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
return
false
;
return
false
;
}
}
// Gridwise gemm v3 doesn't verify descriptors size
if
(
!
arg
.
conv_to_gemm_transformer_
.
AreDescriptorsSmallerThan2GB
())
{
return
false
;
}
// check Gridwise GEMM
// check Gridwise GEMM
const
index_t
GemmM
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
GemmM
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
GemmN
=
arg
.
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
);
const
index_t
GemmN
=
arg
.
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
);
...
@@ -1037,6 +1043,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
...
@@ -1037,6 +1043,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
cde_element_op
};
cde_element_op
};
}
}
static
auto
MakeArgument
(
const
void
*
p_as
,
const
void
*
p_bs
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
{
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_i32
;
array_convert
(
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_lengths
);
array_convert
(
a_g_n_c_wis_strides_i32
,
a_g_n_c_wis_strides
);
array_convert
(
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_lengths
);
array_convert
(
b_g_k_c_xs_strides_i32
,
b_g_k_c_xs_strides
);
for
(
index_t
d
=
0
;
d
<
NumDTensor
;
d
++
)
{
array_convert
(
ds_g_n_k_wos_lengths_i32
[
d
],
ds_g_n_k_wos_lengths
[
d
]);
array_convert
(
ds_g_n_k_wos_strides_i32
[
d
],
ds_g_n_k_wos_strides
[
d
]);
}
array_convert
(
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_lengths
);
array_convert
(
e_g_n_k_wos_strides_i32
,
e_g_n_k_wos_strides
);
array_convert
(
conv_filter_strides_i32
,
conv_filter_strides
);
array_convert
(
conv_filter_dilations_i32
,
conv_filter_dilations
);
array_convert
(
input_left_pads_i32
,
input_left_pads
);
array_convert
(
input_right_pads_i32
,
input_right_pads
);
return
Argument
{
p_as
,
p_bs
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_strides_i32
,
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_strides_i32
,
ds_g_n_k_wos_lengths_i32
,
ds_g_n_k_wos_strides_i32
,
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_strides_i32
,
conv_filter_strides_i32
,
conv_filter_dilations_i32
,
input_left_pads_i32
,
input_right_pads_i32
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
...
@@ -1081,6 +1160,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
...
@@ -1081,6 +1160,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
cde_element_op
);
cde_element_op
);
}
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
override
{
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_i32
;
array_convert
(
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_lengths
);
array_convert
(
a_g_n_c_wis_strides_i32
,
a_g_n_c_wis_strides
);
array_convert
(
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_lengths
);
array_convert
(
b_g_k_c_xs_strides_i32
,
b_g_k_c_xs_strides
);
for
(
index_t
d
=
0
;
d
<
NumDTensor
;
d
++
)
{
array_convert
(
ds_g_n_k_wos_lengths_i32
[
d
],
ds_g_n_k_wos_lengths
[
d
]);
array_convert
(
ds_g_n_k_wos_strides_i32
[
d
],
ds_g_n_k_wos_strides
[
d
]);
}
array_convert
(
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_lengths
);
array_convert
(
e_g_n_k_wos_strides_i32
,
e_g_n_k_wos_strides
);
array_convert
(
conv_filter_strides_i32
,
conv_filter_strides
);
array_convert
(
conv_filter_dilations_i32
,
conv_filter_dilations
);
array_convert
(
input_left_pads_i32
,
input_left_pads
);
array_convert
(
input_right_pads_i32
,
input_right_pads
);
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_strides_i32
,
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_strides_i32
,
ds_g_n_k_wos_lengths_i32
,
ds_g_n_k_wos_strides_i32
,
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_strides_i32
,
conv_filter_strides_i32
,
conv_filter_dilations_i32
,
input_left_pads_i32
,
input_right_pads_i32
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp
View file @
be58e518
...
@@ -309,13 +309,13 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
...
@@ -309,13 +309,13 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
using
GemmToConv
FwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
;
using
ConvToGemm
FwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
;
static
constexpr
auto
matrix_padder
=
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
template
<
typename
ALay
>
template
<
typename
ALay
>
static
auto
MakeAGridDescriptor_M_K
(
const
GemmToConv
FwdTransformer
&
conv_to_gemm_transformer
)
static
auto
MakeAGridDescriptor_M_K
(
const
ConvToGemm
FwdTransformer
&
conv_to_gemm_transformer
)
{
{
const
auto
in_gemmmraw_gemmkraw_desc
=
const
auto
in_gemmmraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>();
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>();
...
@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
...
@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
}
}
template
<
typename
BLay
>
template
<
typename
BLay
>
static
auto
MakeBGridDescriptor_N_K
(
const
GemmToConv
FwdTransformer
&
conv_to_gemm_transformer
)
static
auto
MakeBGridDescriptor_N_K
(
const
ConvToGemm
FwdTransformer
&
conv_to_gemm_transformer
)
{
{
const
auto
wei_gemmnraw_gemmkraw_desc
=
const
auto
wei_gemmnraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>();
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>();
...
@@ -339,7 +339,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
...
@@ -339,7 +339,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
}
}
template
<
typename
ELay
>
template
<
typename
ELay
>
static
auto
MakeEGridDescriptor_M_N
(
const
GemmToConv
FwdTransformer
&
conv_to_gemm_transformer
)
static
auto
MakeEGridDescriptor_M_N
(
const
ConvToGemm
FwdTransformer
&
conv_to_gemm_transformer
)
{
{
const
auto
out_gemmmraw_gemmnraw_desc
=
const
auto
out_gemmmraw_gemmnraw_desc
=
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
ELay
>();
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
ELay
>();
...
@@ -420,7 +420,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
...
@@ -420,7 +420,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
return
GetPaddedRGridDescriptor
(
r_grid_desc_mraw
,
NHoWo
);
return
GetPaddedRGridDescriptor
(
r_grid_desc_mraw
,
NHoWo
);
}
}
constexpr
static
GemmToConv
FwdTransformer
dummy_conv_to_gemm_transformer
;
constexpr
static
ConvToGemm
FwdTransformer
dummy_conv_to_gemm_transformer
;
using
AGridDesc_M_K
=
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_M_K
<
ALayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_M_K
<
ALayout
>
(
dummy_conv_to_gemm_transformer
))
>
;
using
BGridDesc_N_K
=
using
BGridDesc_N_K
=
...
@@ -599,7 +599,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
...
@@ -599,7 +599,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
// D batch stride
// D batch stride
compute_ptr_offset_of_batch_
.
BatchStrideDs_
(
i
)
=
ds_g_n_k_wos_strides
[
i
][
0
];
compute_ptr_offset_of_batch_
.
BatchStrideDs_
(
i
)
=
ds_g_n_k_wos_strides
[
i
][
0
];
GemmToConv
FwdTransformer
conv_to_gemm_transformer_d
{
a_g_n_c_wis_lengths
,
ConvToGemm
FwdTransformer
conv_to_gemm_transformer_d
{
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
b_g_k_c_xs_strides
,
...
@@ -649,7 +649,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
...
@@ -649,7 +649,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
EDataType
*
p_e_grid_
;
EDataType
*
p_e_grid_
;
typename
GridwiseGemm
::
RsGridPointer
p_rs_grid_
;
typename
GridwiseGemm
::
RsGridPointer
p_rs_grid_
;
GemmToConv
FwdTransformer
conv_to_gemm_transformer_
;
ConvToGemm
FwdTransformer
conv_to_gemm_transformer_
;
// tensor descriptors for problem definiton
// tensor descriptors for problem definiton
AGridDesc_M_K
a_grid_desc_m_k_
;
AGridDesc_M_K
a_grid_desc_m_k_
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
View file @
be58e518
...
@@ -135,13 +135,13 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -135,13 +135,13 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
static
constexpr
auto
BEnableLds
=
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
||
(
NumGemmKPrefetchStage
>
1
);
BEnableLds_auto
||
BEnableLds_manu
||
(
NumGemmKPrefetchStage
>
1
);
using
GemmToConv
FwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
;
using
ConvToGemm
FwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
;
static
constexpr
auto
matrix_padder
=
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
template
<
typename
ALay
>
template
<
typename
ALay
>
static
auto
MakeAGridDescriptor
(
const
GemmToConv
FwdTransformer
&
conv_to_gemm_transformer
)
static
auto
MakeAGridDescriptor
(
const
ConvToGemm
FwdTransformer
&
conv_to_gemm_transformer
)
{
{
const
auto
in_gemmmraw_gemmkraw_desc
=
const
auto
in_gemmmraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>();
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>();
...
@@ -185,7 +185,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -185,7 +185,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
}
}
template
<
typename
BLay
>
template
<
typename
BLay
>
static
auto
MakeBGridDescriptor
(
const
GemmToConv
FwdTransformer
&
conv_to_gemm_transformer
)
static
auto
MakeBGridDescriptor
(
const
ConvToGemm
FwdTransformer
&
conv_to_gemm_transformer
)
{
{
const
auto
wei_gemmnraw_gemmkraw_desc
=
const
auto
wei_gemmnraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>();
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>();
...
@@ -229,7 +229,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -229,7 +229,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
}
}
template
<
typename
ELay
>
template
<
typename
ELay
>
static
auto
MakeEGridDescriptor_M_N
(
const
GemmToConv
FwdTransformer
&
conv_to_gemm_transformer
)
static
auto
MakeEGridDescriptor_M_N
(
const
ConvToGemm
FwdTransformer
&
conv_to_gemm_transformer
)
{
{
const
auto
out_gemmmraw_gemmnraw_desc
=
const
auto
out_gemmmraw_gemmnraw_desc
=
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
ELay
>();
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
ELay
>();
...
@@ -240,7 +240,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -240,7 +240,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
return
out_gemmm_gemmn_desc
;
return
out_gemmm_gemmn_desc
;
}
}
static
auto
MakeDsGridDescriptor_M_N
(
const
GemmToConv
FwdTransformer
&
conv_to_gemm_transformer
)
static
auto
MakeDsGridDescriptor_M_N
(
const
ConvToGemm
FwdTransformer
&
conv_to_gemm_transformer
)
{
{
return
generate_tuple
(
return
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
...
@@ -252,7 +252,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -252,7 +252,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
}
}
// desc for problem definition
// desc for problem definition
constexpr
static
GemmToConv
FwdTransformer
dummy_conv_to_gemm_transformer
;
constexpr
static
ConvToGemm
FwdTransformer
dummy_conv_to_gemm_transformer
;
using
AGridDesc
=
using
AGridDesc
=
decltype
(
DeviceOp
::
MakeAGridDescriptor
<
ALayout
>
(
dummy_conv_to_gemm_transformer
));
decltype
(
DeviceOp
::
MakeAGridDescriptor
<
ALayout
>
(
dummy_conv_to_gemm_transformer
));
using
BGridDesc
=
using
BGridDesc
=
...
@@ -406,7 +406,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -406,7 +406,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
GemmToConv
FwdTransformer
conv_to_gemm_transformer_d
{
a_g_n_c_wis_lengths
,
ConvToGemm
FwdTransformer
conv_to_gemm_transformer_d
{
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
b_g_k_c_xs_strides
,
...
@@ -448,7 +448,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -448,7 +448,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
// tensor descriptors for problem definiton
// tensor descriptors for problem definiton
index_t
num_group_
;
index_t
num_group_
;
GemmToConv
FwdTransformer
conv_to_gemm_transformer_
;
ConvToGemm
FwdTransformer
conv_to_gemm_transformer_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
...
@@ -772,6 +772,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -772,6 +772,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
cde_element_op
};
cde_element_op
};
}
}
static
auto
MakeArgument
(
const
void
*
p_a
,
const
void
*
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
{
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_i32
;
array_convert
(
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_lengths
);
array_convert
(
a_g_n_c_wis_strides_i32
,
a_g_n_c_wis_strides
);
array_convert
(
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_lengths
);
array_convert
(
b_g_k_c_xs_strides_i32
,
b_g_k_c_xs_strides
);
for
(
index_t
d
=
0
;
d
<
NumDTensor
;
d
++
)
{
array_convert
(
ds_g_n_k_wos_lengths_i32
[
d
],
ds_g_n_k_wos_lengths
[
d
]);
array_convert
(
ds_g_n_k_wos_strides_i32
[
d
],
ds_g_n_k_wos_strides
[
d
]);
}
array_convert
(
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_lengths
);
array_convert
(
e_g_n_k_wos_strides_i32
,
e_g_n_k_wos_strides
);
array_convert
(
conv_filter_strides_i32
,
conv_filter_strides
);
array_convert
(
conv_filter_dilations_i32
,
conv_filter_dilations
);
array_convert
(
input_left_pads_i32
,
input_left_pads
);
array_convert
(
input_right_pads_i32
,
input_right_pads
);
return
Argument
{
p_a
,
p_b
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_strides_i32
,
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_strides_i32
,
ds_g_n_k_wos_lengths_i32
,
ds_g_n_k_wos_strides_i32
,
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_strides_i32
,
conv_filter_strides_i32
,
conv_filter_dilations_i32
,
input_left_pads_i32
,
input_right_pads_i32
,
1
,
1
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
...
@@ -818,6 +893,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -818,6 +893,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
cde_element_op
);
cde_element_op
);
}
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
long_index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
long_index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
override
{
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_i32
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_i32
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_i32
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_i32
;
array_convert
(
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_lengths
);
array_convert
(
a_g_n_c_wis_strides_i32
,
a_g_n_c_wis_strides
);
array_convert
(
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_lengths
);
array_convert
(
b_g_k_c_xs_strides_i32
,
b_g_k_c_xs_strides
);
for
(
index_t
d
=
0
;
d
<
NumDTensor
;
d
++
)
{
array_convert
(
ds_g_n_k_wos_lengths_i32
[
d
],
ds_g_n_k_wos_lengths
[
d
]);
array_convert
(
ds_g_n_k_wos_strides_i32
[
d
],
ds_g_n_k_wos_strides
[
d
]);
}
array_convert
(
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_lengths
);
array_convert
(
e_g_n_k_wos_strides_i32
,
e_g_n_k_wos_strides
);
array_convert
(
conv_filter_strides_i32
,
conv_filter_strides
);
array_convert
(
conv_filter_dilations_i32
,
conv_filter_dilations
);
array_convert
(
input_left_pads_i32
,
input_left_pads
);
array_convert
(
input_right_pads_i32
,
input_right_pads
);
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_strides_i32
,
b_g_k_c_xs_lengths_i32
,
b_g_k_c_xs_strides_i32
,
ds_g_n_k_wos_lengths_i32
,
ds_g_n_k_wos_strides_i32
,
e_g_n_k_wos_lengths_i32
,
e_g_n_k_wos_strides_i32
,
conv_filter_strides_i32
,
conv_filter_dilations_i32
,
input_left_pads_i32
,
input_right_pads_i32
,
1
,
1
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
0 → 100644
View file @
be58e518
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp
View file @
be58e518
...
@@ -57,7 +57,7 @@ struct DeviceImageToColumnImpl
...
@@ -57,7 +57,7 @@ struct DeviceImageToColumnImpl
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
using
GemmToConv
FwdTransformer
=
using
ConvToGemm
FwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvolutionForwardSpecialization
::
Default
>
;
TransformConvFwdToGemm
<
NDimSpatial
,
ConvolutionForwardSpecialization
::
Default
>
;
static
constexpr
auto
matrix_padder
=
static
constexpr
auto
matrix_padder
=
...
@@ -97,7 +97,7 @@ struct DeviceImageToColumnImpl
...
@@ -97,7 +97,7 @@ struct DeviceImageToColumnImpl
b_g_k_c_xs_lengths
[
I2
]
=
C
;
b_g_k_c_xs_lengths
[
I2
]
=
C
;
c_g_n_k_wos_lengths
[
I1
]
=
N
;
c_g_n_k_wos_lengths
[
I1
]
=
N
;
GemmToConv
FwdTransformer
conv_to_gemm_transformer
{
a_g_n_c_wis_lengths
,
ConvToGemm
FwdTransformer
conv_to_gemm_transformer
{
a_g_n_c_wis_lengths
,
image_g_n_c_wis_strides
,
image_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_lengths
,
{},
// not needed for A Descriptor
{},
// not needed for A Descriptor
...
...
include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp
View file @
be58e518
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -258,7 +258,7 @@ struct GridwiseMultipleReduction_mk_to_m_multiblock
...
@@ -258,7 +258,7 @@ struct GridwiseMultipleReduction_mk_to_m_multiblock
if
(
thread_k_cluster_id
==
0
)
if
(
thread_k_cluster_id
==
0
)
{
{
if
(
block_group_size
==
0
&&
!
float_equal_zero
{}(
beta_values
[
iR
]))
if
(
!
float_equal_zero
{}(
beta_values
[
iR
]))
{
{
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
OutDataType
,
MThreadSliceSize
,
true
>
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
OutDataType
,
MThreadSliceSize
,
true
>
priorDstValueBuf
;
priorDstValueBuf
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp
View file @
be58e518
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -244,7 +244,7 @@ struct GridwiseReduction_mk_to_m_multiblock
...
@@ -244,7 +244,7 @@ struct GridwiseReduction_mk_to_m_multiblock
if
(
thread_k_cluster_id
==
0
)
if
(
thread_k_cluster_id
==
0
)
{
{
if
(
block_group_size
==
0
&&
!
float_equal_zero
{}(
beta
))
if
(
!
float_equal_zero
{}(
beta
))
{
{
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
OutDataType
,
MThreadSliceSize
,
true
>
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
OutDataType
,
MThreadSliceSize
,
true
>
priorDstValueBuf
;
priorDstValueBuf
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp
View file @
be58e518
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -111,6 +111,15 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
...
@@ -111,6 +111,15 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
{
constexpr
long_index_t
TwoGB
=
(
long_index_t
{
1
}
<<
31
);
if
(
!
(
a_grid_desc_k0_m_k1
.
GetElementSpaceSize
()
*
sizeof
(
FloatAB
)
<=
TwoGB
&&
b_grid_desc_k0_n_k1
.
GetElementSpaceSize
()
*
sizeof
(
FloatAB
)
<=
TwoGB
&&
c_grid_desc_m_n
.
GetElementSpaceSize
()
*
sizeof
(
FloatC
)
<=
TwoGB
))
{
return
false
;
}
const
auto
M
=
a_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
M
=
a_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp
View file @
be58e518
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -649,6 +649,15 @@ struct GridwiseGemmDl_bkm_bkn_mn_v1r3
...
@@ -649,6 +649,15 @@ struct GridwiseGemmDl_bkm_bkn_mn_v1r3
const
BGridDesc_B_K0_N_K1
&
b_grid_desc_b_k0_n_k1
,
const
BGridDesc_B_K0_N_K1
&
b_grid_desc_b_k0_n_k1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
{
constexpr
long_index_t
TwoGB
=
(
long_index_t
{
1
}
<<
31
);
if
(
!
(
a_grid_desc_b_k0_m_k1
.
GetElementSpaceSize
()
*
sizeof
(
FloatAB
)
<=
TwoGB
&&
b_grid_desc_b_k0_n_k1
.
GetElementSpaceSize
()
*
sizeof
(
FloatAB
)
<=
TwoGB
&&
c_grid_desc_m_n
.
GetElementSpaceSize
()
*
sizeof
(
FloatC
)
<=
TwoGB
))
{
return
false
;
}
const
auto
M
=
a_grid_desc_b_k0_m_k1
.
GetLength
(
I2
);
const
auto
M
=
a_grid_desc_b_k0_m_k1
.
GetLength
(
I2
);
const
auto
N
=
b_grid_desc_b_k0_n_k1
.
GetLength
(
I2
);
const
auto
N
=
b_grid_desc_b_k0_n_k1
.
GetLength
(
I2
);
const
auto
K0
=
a_grid_desc_b_k0_m_k1
.
GetLength
(
I1
);
const
auto
K0
=
a_grid_desc_b_k0_m_k1
.
GetLength
(
I1
);
...
...
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp
View file @
be58e518
This diff is collapsed.
Click to expand it.
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment