gaoqiong / composable_kernel

Commit 8dd7156d, authored Jul 25, 2023 by ltqin

Merge branch 'mha-train-develop' into attn-train-develop-qloop-mask

Parents: d5f629e7, b5a3ea2d
Changes: showing 20 of 533 changed files, with 322 additions and 1814 deletions.
example/49_maxpool2d_bwd/maxpool2d_bwd_fp16.cpp  +62  -0
example/49_maxpool2d_bwd/maxpool2d_bwd_fp32.cpp  +62  -0
example/50_put_element/CMakeLists.txt  +1  -0
example/50_put_element/put_element_fp16.cpp  +88  -0
include/ck/ck.hpp  +56  -22
include/ck/host_utility/device_prop.hpp  +1  -1
include/ck/host_utility/hip_check_error.hpp  +1  -1
include/ck/host_utility/io.hpp  +1  -1
include/ck/host_utility/kernel_launch.hpp  +6  -5
include/ck/host_utility/stream_utility.hpp  +43  -0
include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp  +0  -275
include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp  +0  -355
include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp  +0  -150
include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp  +0  -132
include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp  +0  -150
include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp  +0  -135
include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp  +0  -147
include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp  +1  -1
include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp  +0  -260
include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp  +0  -179
example/49_maxpool2d_bwd/maxpool2d_bwd_fp16.cpp (new file, mode 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/utility/reduction_enums.hpp"

#include "maxpool2d_bwd_common.hpp"

using InDataType      = ck::half_t;
using OutDataType     = ck::half_t;
using IndexDataType   = int32_t;
using ComputeDataType = float;
using DInDataType     = ck::half_t;
using DOutDataType    = ck::half_t;

static constexpr bool PropagateNan = false;

int main()
{
    bool do_verification = true;
    bool time_kernel     = false;

    // Pool shape
    ck::index_t N               = 1;
    ck::index_t C               = 1;
    ck::index_t Y               = 3;
    ck::index_t X               = 3;
    ck::index_t Hi              = 32;
    ck::index_t Wi              = 32;
    ck::index_t window_stride_h = 1;
    ck::index_t window_stride_w = 1;
    ck::index_t in_left_pad_h   = 0;
    ck::index_t in_left_pad_w   = 0;
    ck::index_t in_right_pad_h  = 0;
    ck::index_t in_right_pad_w  = 0;

    bool pass = maxpool_bwd_test<InDataType,
                                 OutDataType,
                                 IndexDataType,
                                 ComputeDataType,
                                 DInDataType,
                                 DOutDataType,
                                 PropagateNan>(do_verification,
                                               time_kernel,
                                               N,
                                               C,
                                               Y,
                                               X,
                                               Hi,
                                               Wi,
                                               window_stride_h,
                                               window_stride_w,
                                               in_left_pad_h,
                                               in_left_pad_w,
                                               in_right_pad_h,
                                               in_right_pad_w);

    return (pass ? 0 : 1);
}
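For orientation, `maxpool_bwd_test` lives in `maxpool2d_bwd_common.hpp`, which is not part of this diff. A minimal host-side sketch of the semantics such a harness verifies (hypothetical names; the real test also runs the forward pass to produce the saved indices):

#include <algorithm>
#include <cstdint>
#include <vector>

// Max-pool backward scatters each output gradient onto the input element that
// won the forward max, identified by a saved flattened index per output.
void maxpool_bwd_host_ref(const std::vector<float>& dout,
                          const std::vector<int32_t>& indices, // argmax per output
                          std::vector<float>& din)             // pre-sized input gradient
{
    std::fill(din.begin(), din.end(), 0.0f);

    // With overlapping windows (stride < window, as in this fp16 example:
    // 3x3 window, stride 1) one input element can win several windows,
    // so gradients accumulate rather than overwrite.
    for(std::size_t o = 0; o < dout.size(); ++o)
        din[indices[o]] += dout[o];
}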
example/49_maxpool2d_bwd/maxpool2d_bwd_fp32.cpp (new file, mode 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/utility/reduction_enums.hpp"

#include "maxpool2d_bwd_common.hpp"

using InDataType      = float;
using OutDataType     = float;
using IndexDataType   = int32_t;
using ComputeDataType = float;
using DInDataType     = float;
using DOutDataType    = float;

static constexpr bool PropagateNan = false;

int main()
{
    bool do_verification = true;
    bool time_kernel     = false;

    // Pool shape
    ck::index_t N               = 1;
    ck::index_t C               = 1;
    ck::index_t Y               = 2;
    ck::index_t X               = 2;
    ck::index_t Hi              = 32;
    ck::index_t Wi              = 32;
    ck::index_t window_stride_h = 2;
    ck::index_t window_stride_w = 2;
    ck::index_t in_left_pad_h   = 0;
    ck::index_t in_left_pad_w   = 0;
    ck::index_t in_right_pad_h  = 0;
    ck::index_t in_right_pad_w  = 0;

    bool pass = maxpool_bwd_test<InDataType,
                                 OutDataType,
                                 IndexDataType,
                                 ComputeDataType,
                                 DInDataType,
                                 DOutDataType,
                                 PropagateNan>(do_verification,
                                               time_kernel,
                                               N,
                                               C,
                                               Y,
                                               X,
                                               Hi,
                                               Wi,
                                               window_stride_h,
                                               window_stride_w,
                                               in_left_pad_h,
                                               in_left_pad_w,
                                               in_right_pad_h,
                                               in_right_pad_w);

    return (pass ? 0 : 1);
}
example/50_put_element/CMakeLists.txt (new file, mode 0 → 100644)

add_example_executable(example_put_element_fp16 put_element_fp16.cpp)
example/50_put_element/put_element_fp16.cpp (new file, mode 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

using XDataType     = ck::half_t;
using YDataType     = ck::half_t;
using IndexDataType = int32_t;

using YElementwiseOp = ck::tensor_operation::element_wise::PassThrough;

using DeviceInstance =
    ck::tensor_operation::device::DevicePutElementImpl<XDataType,     // XDataType
                                                       IndexDataType, // IndexDataType
                                                       YDataType,     // YDataType
                                                       YElementwiseOp,
                                                       ck::InMemoryDataOperationEnum::Set,
                                                       1>;

int main()
{
    bool do_verification = true;
    bool time_kernel     = false;

    int N = 1024;

    Tensor<XDataType> x(HostTensorDescriptor{N, 1});
    Tensor<IndexDataType> indices(HostTensorDescriptor{N, 1});
    Tensor<YDataType> y(HostTensorDescriptor{N, 1});

    x.GenerateTensorValue(GeneratorTensor_3<XDataType>{-1.0, 1.0});

    for(int i = 0; i < N; ++i)
        indices(i) = i;

    DeviceMem x_device_buf(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
    DeviceMem y_device_buf(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
    DeviceMem indices_device_buf(sizeof(IndexDataType) * indices.mDesc.GetElementSpaceSize());

    x_device_buf.ToDevice(x.mData.data());
    indices_device_buf.ToDevice(indices.mData.data());

    auto put_instance    = DeviceInstance{};
    auto put_invoker_ptr = put_instance.MakeInvokerPointer();
    auto put_argument_ptr = put_instance.MakeArgumentPointer(
        static_cast<XDataType*>(x_device_buf.GetDeviceBuffer()),
        static_cast<IndexDataType*>(indices_device_buf.GetDeviceBuffer()),
        static_cast<YDataType*>(y_device_buf.GetDeviceBuffer()),
        N,
        N,
        YElementwiseOp{});

    if(!put_instance.IsSupportedArgument(put_argument_ptr.get()))
    {
        throw std::runtime_error("argument is not supported!");
    }

    float ave_time =
        put_invoker_ptr->Run(put_argument_ptr.get(), StreamConfig{nullptr, time_kernel});

    std::cout << "perf: " << ave_time << " ms" << std::endl;

    bool pass = true;

    if(do_verification)
    {
        Tensor<YDataType> y_host(HostTensorDescriptor{N, 1});

        for(int i = 0; i < N; ++i)
        {
            IndexDataType idx = indices(i);

            y_host(idx) = x(i);
        }

        y_device_buf.FromDevice(y.mData.data());

        pass = ck::utils::check_err(y, y_host);
    }

    return (pass ? 0 : 1);
}
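The verification loop above spells out the operator's contract: with `InMemoryDataOperationEnum::Set` and a `PassThrough` elementwise op, the kernel is a plain scatter, `y[indices[i]] = x[i]`. The example uses a permutation (`indices(i) = i`), which keeps the result deterministic; a hypothetical host model of what duplicate indices would mean under the same Set semantics:

#include <vector>

// Sketch only: on collisions, Set semantics make the surviving value
// order-dependent (on the GPU, whichever thread writes last wins),
// so deterministic tests should use unique indices.
std::vector<float> put_element_set_host(const std::vector<float>& x,
                                        const std::vector<int>& indices,
                                        int y_size)
{
    std::vector<float> y(y_size, 0.0f);

    for(std::size_t i = 0; i < x.size(); ++i)
        y[indices[i]] = x[i]; // last writer wins

    return y;
}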
include/ck/ck.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -18,36 +18,49 @@
 #define CK_USE_LAUNCH_BOUNDS 1

 #ifdef CK_USE_LAUNCH_BOUNDS
 // for most kernels
 #define CK_MAX_THREAD_PER_BLOCK 256
 #define CK_MIN_BLOCK_PER_CU 2

 // for wavelet GEMM kernel
 #define CK_WAVELET_MAX_THREAD_PER_BLOCK 512
 #define CK_WAVELET_MIN_BLOCK_PER_CU 2
 #endif

-// check GPU target
-#ifdef __HIP_DEVICE_COMPILE__
-#if !(defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
-      defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__))
-#error Not supported target
-#endif
+// kernel attribute: amdgpu_waves_per_eu()
+#ifdef CK_USE_WAVES_PER_EU
+// for 1-wave kernels, control arguments of amdgpu_waves_per_eu() attribute
+#ifndef CK_MIN_WAVES_PER_EU
+#define CK_MIN_WAVES_PER_EU 0
+#endif
+#ifndef CK_MAX_WAVES_PER_EU
+#define CK_MAX_WAVES_PER_EU 0
+#endif
+#else
+#define CK_USE_WAVES_PER_EU 0
 #endif

 // buffer resource
 #ifndef __HIP_DEVICE_COMPILE__ // for host code
 #define CK_BUFFER_RESOURCE_3RD_DWORD -1
 #elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
-    defined(__gfx90a__) // for GPU code
+    defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
+    defined(__gfx942__) // for GPU code
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
 #elif defined(__gfx1030__) // for GPU code
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
-#elif defined(__gfx1100__) // for GPU code
-#define CK_BUFFER_RESOURCE_3RD_DWORD 0x10020000
+#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
+#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
 #endif

 // FMA instruction
 #ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing
 #elif defined(__gfx803__) || defined(__gfx900__) // for GPU code
 #define CK_USE_AMD_V_MAC_F32
-#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx1030__) // for GPU code
+#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) || \
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // for GPU code
 #define CK_USE_AMD_V_FMAC_F32
 #define CK_USE_AMD_V_DOT2_F32_F16
 #define CK_USE_AMD_V_DOT4_I32_I8
...
@@ -56,18 +69,23 @@
 // MFMA instruction
 #ifndef __HIP_DEVICE_COMPILE__ // for host code
 #define CK_USE_AMD_MFMA
-#elif defined(__gfx908__) || defined(__gfx90a__) // for GPU code
+#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
+    defined(__gfx942__) // for GPU code
 #define CK_USE_AMD_MFMA
 #endif

-#if defined(__gfx90a__)
+#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
 #define CK_USE_AMD_MFMA_BF16_1K_OP
 #endif

+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#define CK_USE_AMD_MFMA_GFX940
+#endif
+
 // WMMA instruction
 #ifndef __HIP_DEVICE_COMPILE__ // for host code
 #define CK_USE_AMD_WMMA
-#elif defined(__gfx1100__) // for GPU code
+#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
 #define CK_USE_AMD_WMMA
 #endif
...
@@ -83,13 +101,15 @@
 // buffer atomic add: floating point
 #ifndef __HIP_DEVICE_COMPILE__ // for host code
 #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
-#elif defined(__gfx908__) || defined(__gfx90a__) // for GPU code
+#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
+    defined(__gfx942__) // for GPU code
 #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
 #else // for GPU code
 #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
 #endif

-#if defined(__gfx90a__) // for GPU code
+#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
+    defined(__gfx942__)) // for GPU code
 #define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1
 #else
 #define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0
...
@@ -143,6 +163,10 @@
 #define CK_EXPERIMENTAL_INTER_WAVE_INSTANCES 1
 // experimental feature: add instances using pipeline v2
 #define CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES 1
+// experimental feature: optimize pipeline v2 by IGLP strategy (value=ID of strategy)
+#ifndef CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT
+#define CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT 0
+#endif

 // hack: have underlying assumption that need to be satsified, otherwise it's a bug
 // hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
...
@@ -163,16 +187,26 @@
 // tuning parameter
 #define CK_WORKAROUND_SWDEV_325164 0
+// workaround: a BF16 attention kernel for gfx908 is likely affected by a compiler issue
+#ifdef __gfx908__
+#define CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE 1
+#else // __gfx90a__, ...
+#define CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE 0
+#endif // __gfx908__
 // workaround: compiler not emiting reciprocal instruction frm __frcp_rn()
 #define CK_WORKAROUND_SWDEV_383542 1
 // workaround: compiler issue on gfx908
 #define CK_WORKAROUND_SWDEV_388832 1
+// workaround: Grouped Conv2d_bwd_data fails for already implemented instance
+#define CK_WORKAROUND_SWDEV_3318619 0

 // flag to enable (1) or disable (0) the debugging output in some kernels
 #define DEBUG_LOG 0

+// denorm test fix, required to work around dissue
+#ifndef CK_WORKAROUND_DENORM_FIX
+#define CK_WORKAROUND_DENORM_FIX 0
+#elif
+// enable only on MI200
+#define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
+#endif // CK_WORKAROUND_DENORM_FIX
+
 namespace ck {

 enum struct InMemoryDataOperationEnum
...
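A short sketch (not part of this commit) of how kernels typically consume these per-target switches: the gfx940/941/942 additions above simply widen the set of architectures for which the compile-time macros turn on, and device code branches on them.

#include <hip/hip_runtime.h>
#include "ck/ck.hpp"

// Illustrative only: pick an FMA path from the macros ck.hpp defines
// during device compilation for the current target.
__device__ float accumulate(float a, float b, float acc)
{
#if defined(CK_USE_AMD_V_FMAC_F32)
    return __builtin_fmaf(a, b, acc); // v_fmac_f32-capable targets (now incl. gfx940/941/942)
#else
    return a * b + acc; // fallback, e.g. v_mac_f32 targets
#endif
}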
include/ck/host_utility/device_prop.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
include/ck/host_utility/hip_check_error.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
include/ck/host_utility/io.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
include/ck/host_utility/kernel_launch.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -20,6 +20,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
 #if CK_TIME_KERNEL
     if(stream_config.time_kernel_)
     {
+#if DEBUG_LOG
         printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
                __func__,
                grid_dim.x,
...
@@ -29,15 +30,15 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
                block_dim.y,
                block_dim.z);

-        const int nrepeat = 10;
-
         printf("Warm up 1 time\n");
+#endif
         // warm up
         kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);

+        const int nrepeat = 10;
+#if DEBUG_LOG
         printf("Start running %d times...\n", nrepeat);
-
+#endif
         hipEvent_t start, stop;

         hip_check_error(hipEventCreate(&start));
...
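For context, the hunks above move `nrepeat` out of the `DEBUG_LOG`-only block so the timing loop still compiles when logging is off. A minimal sketch (simplified signature, assumed details) of the warm-up-then-time pattern that `launch_and_time_kernel` implements:

#include <hip/hip_runtime.h>

// Launch once to warm up, then time nrepeat launches with hipEvents and
// return the average per-launch time in milliseconds.
template <typename Kernel, typename... Args>
float time_kernel_sketch(Kernel kernel, dim3 grid, dim3 block, hipStream_t stream, Args... args)
{
    kernel<<<grid, block, 0, stream>>>(args...); // warm up

    const int nrepeat = 10;

    hipEvent_t start, stop;
    (void)hipEventCreate(&start);
    (void)hipEventCreate(&stop);

    (void)hipEventRecord(start, stream);
    for(int i = 0; i < nrepeat; ++i)
        kernel<<<grid, block, 0, stream>>>(args...);
    (void)hipEventRecord(stop, stream);
    (void)hipEventSynchronize(stop);

    float total_ms = 0.f;
    (void)hipEventElapsedTime(&total_ms, start, stop);

    return total_ms / nrepeat; // average per launch
}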
include/ck/host_utility/stream_utility.hpp (new file, mode 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <hip/hip_runtime.h>

#include "ck/stream_config.hpp"
#include "ck/host_utility/hip_check_error.hpp"

static inline int getAvailableComputeUnitCount(const StreamConfig& stream_config)
{
    constexpr int MAX_MASK_DWORDS = 64; // assume at most 64*32 = 2048 CUs

    uint32_t cuMask[MAX_MASK_DWORDS];

    for(int i = 0; i < MAX_MASK_DWORDS; i++)
        cuMask[i] = 0;

    auto countSetBits = [](uint32_t dword) {
        int count = 0;

        while(dword != 0)
        {
            if(dword & 0x1)
                count++;

            dword = dword >> 1;
        };

        return (count);
    };

    hip_check_error(hipExtStreamGetCUMask(stream_config.stream_id_, MAX_MASK_DWORDS, &cuMask[0]));

    int ret = 0;

    for(int i = 0; i < MAX_MASK_DWORDS; i++)
        ret += countSetBits(cuMask[i]);

    return (ret);
};
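A brief usage sketch (assumed, not from this diff): the helper reads the stream's CU mask via `hipExtStreamGetCUMask` and popcounts it, so launch heuristics can respect a CU-masked stream rather than assuming the whole device.

#include "ck/ck.hpp"
#include "ck/host_utility/stream_utility.hpp"

// Hypothetical caller: size a grid by the CUs this stream may actually use.
int pick_grid_size(const StreamConfig& stream_config)
{
    const int cu_count = getAvailableComputeUnitCount(stream_config);

    return cu_count * CK_MIN_BLOCK_PER_CU; // one heuristic; CK_MIN_BLOCK_PER_CU is from ck.hpp
}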
include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp (deleted, mode 100644 → 0)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

// Number of GEMMs = YTilde * XTilde
// GemmM = C
// GemmN = N * HTildeSlice * WTildeSlice
// GemmK = K * YDotSlice * XDotSlice
template <typename... Wei,
          typename... In,
          typename... Out,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          index_t IYTildeValue,
          index_t IXTildeValue,
          index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
    const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
    const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
    const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    Number<IYTildeValue>,
    Number<IXTildeValue>,
    Number<GemmK1Value>)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto GemmK1 = Number<GemmK1Value>{};

    constexpr auto IYTilde = Number<IYTildeValue>{};
    constexpr auto IXTilde = Number<IXTildeValue>{};

    const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
    const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
    const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);

    const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
    const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);

    const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
    const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);

    const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
    const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);

    const auto ConvStrideH = conv_strides[I0];
    const auto ConvStrideW = conv_strides[I1];

    const auto ConvDilationH = conv_dilations[I0];
    const auto ConvDilationW = conv_dilations[I1];

    const auto InLeftPadH = in_left_pads[I0];
    const auto InLeftPadW = in_left_pads[I1];

    const auto InRightPadH = in_right_pads[I0];
    const auto InRightPadW = in_right_pads[I1];

    const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
    const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

    const auto YTilde = ConvStrideH / GcdStrideDilationH;
    const auto XTilde = ConvStrideW / GcdStrideDilationW;

    const auto YDot = math::integer_divide_ceil(Y, YTilde);
    const auto XDot = math::integer_divide_ceil(X, XTilde);

    const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
    const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);

    // only work on HTilde and WTilde that contribute to non-padding area of input tensor
    const auto IHTildeSliceBegin = math::integer_divide_floor(
        math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
    const auto IWTildeSliceBegin = math::integer_divide_floor(
        math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);

    const auto IHTildeSliceEnd =
        math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
    const auto IWTildeSliceEnd =
        math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);

    const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
    const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;

    // GemmK is different for each GEMM
    const auto YDotSlice = math::integer_divide_ceil(Y - IYTilde, YTilde);
    const auto XDotSlice = math::integer_divide_ceil(X - IXTilde, XTilde);

    const auto K1 = GemmK1;
    const auto K0 = K / K1;

    // weight tensor
    const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
        wei_k_y_x_c_grid_desc,
        make_tuple(make_pass_through_transform(K),
                   make_embed_transform(make_tuple(YDot, YTilde),
                                        make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
                   make_embed_transform(make_tuple(XDot, XTilde),
                                        make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

    const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor(
        wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
                   make_slice_transform(YDot, I0, YDotSlice),
                   make_slice_transform(XDot, I0, XDotSlice),
                   make_freeze_transform(IYTilde),
                   make_freeze_transform(IXTilde),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{},
                   Sequence<1>{},
                   Sequence<3>{},
                   Sequence<2>{},
                   Sequence<4>{},
                   Sequence<5>{}),
        make_tuple(Sequence<0, 1>{},
                   Sequence<2>{},
                   Sequence<3>{},
                   Sequence<>{},
                   Sequence<>{},
                   Sequence<4>{}));

#if 1
    const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
        wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
        make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
                   make_pass_through_transform(C),
                   make_pass_through_transform(K1)),
        make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else
    const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
        wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
        make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
                   make_pass_through_transform(C),
                   make_pass_through_transform(K1)),
        make_tuple(Sequence<0, 2, 3>{}, Sequence<4>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#endif

    // output tensor
    // this add padding check
    const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
        out_n_ho_wo_k_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pad_transform(Ho, I0, I0),
                   make_pad_transform(Wo, I0, I0),
                   make_pass_through_transform(K)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
        out_n_hop_wop_k_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_embed_transform(make_tuple(YDot, HTilde),
                                        make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
                   make_embed_transform(make_tuple(XDot, WTilde),
                                        make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
                   make_pass_through_transform(K)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

    const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
        transform_tensor_descriptor(
            out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_slice_transform(YDot, I0, YDotSlice),
                       make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
                       make_slice_transform(XDot, I0, XDotSlice),
                       make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
                       make_unmerge_transform(make_tuple(K0, K1))),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<4>{},
                       Sequence<5>{}),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<4>{},
                       Sequence<5, 6>{}));

#if 1
    const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
        out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
        make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
                   make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
                   make_pass_through_transform(K1)),
        make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else
    const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
        out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
        make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
                   make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
                   make_pass_through_transform(K1)),
        make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#endif

    // input tensor
    const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
        in_n_hi_wi_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pad_transform(Hi, InLeftPadH, InRightPadH),
                   make_pad_transform(Wi, InLeftPadW, InRightPadW),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
        in_n_hip_wip_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_embed_transform(make_tuple(YTilde, HTilde),
                                        make_tuple(ConvDilationH, ConvStrideH)),
                   make_embed_transform(make_tuple(XTilde, WTilde),
                                        make_tuple(ConvDilationW, ConvStrideW)),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

    const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
        in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_freeze_transform(IYTilde),
                   make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
                   make_freeze_transform(IXTilde),
                   make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{},
                   Sequence<1>{},
                   Sequence<2>{},
                   Sequence<3>{},
                   Sequence<4>{},
                   Sequence<5>{}),
        make_tuple(Sequence<0>{},
                   Sequence<>{},
                   Sequence<1>{},
                   Sequence<>{},
                   Sequence<2>{},
                   Sequence<3>{}));

    const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
        in_n_htildeslice_wtildeslice_c_grid_desc,
        make_tuple(make_pass_through_transform(C),
                   make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))),
        make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc,
                      out_gemmk0_gemmn_gemmk1_grid_desc,
                      in_gemmm_gemmn_grid_desc);
}

} // namespace ck
#endif
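A worked instance of the index arithmetic in the deleted header above (illustrative values, not from this commit):

// ConvStrideH = 2, ConvDilationH = 1  =>  GcdStrideDilationH = gcd(2, 1) = 1
// YTilde = ConvStrideH / GcdStrideDilationH = 2      // two GEMMs along H
// Y = 3  =>  YDot = ceil(Y / YTilde) = ceil(3 / 2) = 2
// YDotSlice(IYTilde = 0) = ceil((3 - 0) / 2) = 2
// YDotSlice(IYTilde = 1) = ceil((3 - 1) / 2) = 1

The per-GEMM slices sum to Y = 3, so the YTilde * XTilde GEMMs together cover every filter tap exactly once, which is why the header says GemmK (= K * YDotSlice * XDotSlice) differs between GEMMs.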
include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp (deleted, mode 100644 → 0)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

// A: out
// B: wei
// C: in
// Number of GEMMs = YTilde * XTilde
// GemmM = N * HTildeSlice * WTildeSlice
// GemmN = C
// GemmK = K * YDotSlice * XDotSlice
template <typename... Wei,
          typename... In,
          typename... Out,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          typename IYTilde,
          typename IXTilde,
          index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
    const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
    const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
    const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    IYTilde i_ytilde,
    IXTilde i_xtilde,
    Number<GemmK1Value>)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto GemmK1 = Number<GemmK1Value>{};

    const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
    const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
    const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);

    const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
    const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);

    const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
    const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);

    const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
    const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);

    const auto ConvStrideH = conv_strides[I0];
    const auto ConvStrideW = conv_strides[I1];

    const auto ConvDilationH = conv_dilations[I0];
    const auto ConvDilationW = conv_dilations[I1];

    const auto InLeftPadH = in_left_pads[I0];
    const auto InLeftPadW = in_left_pads[I1];

    const auto InRightPadH = in_right_pads[I0];
    const auto InRightPadW = in_right_pads[I1];

    const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
    const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

    const auto YTilde = ConvStrideH / GcdStrideDilationH;
    const auto XTilde = ConvStrideW / GcdStrideDilationW;

    const auto YDot = math::integer_divide_ceil(Y, YTilde);
    const auto XDot = math::integer_divide_ceil(X, XTilde);

    const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
    const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);

    // only work on HTilde and WTilde that contribute to non-padding area of input tensor
    const auto IHTildeSliceBegin = math::integer_divide_floor(
        math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
    const auto IWTildeSliceBegin = math::integer_divide_floor(
        math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);

    const auto IHTildeSliceEnd =
        math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
    const auto IWTildeSliceEnd =
        math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);

    const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
    const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;

    // GemmK is different for each GEMM
    const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
    const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);

    const auto K1 = GemmK1;
    const auto K0 = K / K1;

    // A: output tensor
    // this add padding check
    const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
        out_n_ho_wo_k_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pad_transform(Ho, I0, I0),
                   make_pad_transform(Wo, I0, I0),
                   make_pass_through_transform(K)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
        out_n_hop_wop_k_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_embed_transform(make_tuple(YDot, HTilde),
                                        make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
                   make_embed_transform(make_tuple(XDot, WTilde),
                                        make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
                   make_pass_through_transform(K)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

    const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
        transform_tensor_descriptor(
            out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_slice_transform(YDot, I0, YDotSlice),
                       make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
                       make_slice_transform(XDot, I0, XDotSlice),
                       make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
                       make_unmerge_transform(make_tuple(K0, K1))),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<4>{},
                       Sequence<5>{}),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<4>{},
                       Sequence<5, 6>{}));

#if 1
    const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
        out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
        make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
                   make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
                   make_pass_through_transform(K1)),
        make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else
    const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
        out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
        make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
                   make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
                   make_pass_through_transform(K1)),
        make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#endif

    // B: weight tensor
    const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
        wei_k_y_x_c_grid_desc,
        make_tuple(make_pass_through_transform(K),
                   make_embed_transform(make_tuple(YDot, YTilde),
                                        make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
                   make_embed_transform(make_tuple(XDot, XTilde),
                                        make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

    const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor(
        wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
                   make_slice_transform(YDot, I0, YDotSlice),
                   make_slice_transform(XDot, I0, XDotSlice),
                   make_freeze_transform(i_ytilde),
                   make_freeze_transform(i_xtilde),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{},
                   Sequence<1>{},
                   Sequence<3>{},
                   Sequence<2>{},
                   Sequence<4>{},
                   Sequence<5>{}),
        make_tuple(Sequence<0, 1>{},
                   Sequence<2>{},
                   Sequence<3>{},
                   Sequence<>{},
                   Sequence<>{},
                   Sequence<4>{}));

#if 1
    const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
        wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
        make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
                   make_pass_through_transform(C),
                   make_pass_through_transform(K1)),
        make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else
    const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
        wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
        make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
                   make_pass_through_transform(C),
                   make_pass_through_transform(K1)),
        make_tuple(Sequence<0, 2, 3>{}, Sequence<4>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#endif

    // C: input tensor
    const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
        in_n_hi_wi_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pad_transform(Hi, InLeftPadH, InRightPadH),
                   make_pad_transform(Wi, InLeftPadW, InRightPadW),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
        in_n_hip_wip_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_embed_transform(make_tuple(YTilde, HTilde),
                                        make_tuple(ConvDilationH, ConvStrideH)),
                   make_embed_transform(make_tuple(XTilde, WTilde),
                                        make_tuple(ConvDilationW, ConvStrideW)),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

    const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
        in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_freeze_transform(i_ytilde),
                   make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
                   make_freeze_transform(i_xtilde),
                   make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{},
                   Sequence<1>{},
                   Sequence<2>{},
                   Sequence<3>{},
                   Sequence<4>{},
                   Sequence<5>{}),
        make_tuple(Sequence<0>{},
                   Sequence<>{},
                   Sequence<1>{},
                   Sequence<>{},
                   Sequence<2>{},
                   Sequence<3>{}));

    const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
        in_n_htildeslice_wtildeslice_c_grid_desc,
        make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
                      wei_gemmk0_gemmn_gemmk1_grid_desc,
                      in_gemmm_gemmn_grid_desc);
}

// A: out
// B: wei
// C: in
// Number of GEMMs = 1
// GemmM = N * Ho * Wo
// GemmN = C
// GemmK = K
template <typename... Wei,
          typename... In,
          typename... Out,
          typename ConvStrides,
          index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1(
    const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
    const TensorDescriptor<Wei...>& /* wei_k_y_x_c_grid_desc */,
    const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
    const ConvStrides& conv_strides,
    Number<GemmK1Value>)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto GemmK1 = Number<GemmK1Value>{};

    const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
    const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
    const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);

    const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
    const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);

    const auto ConvStrideH = conv_strides[I0];
    const auto ConvStrideW = conv_strides[I1];

    const auto K1 = GemmK1;
    const auto K0 = K / K1;

    // A: output tensor
    const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
        make_tuple(make_pass_through_transform(N * Ho * Wo),
                   make_unmerge_transform(make_tuple(K0, K1))),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0, 2>{}));

    // B: weight tensor
    const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(K, C)),
        make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

    // C: input tensor
    const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
        in_n_hi_wi_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
                   make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

    const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
        in_n_y_ho_x_wo_c_grid_desc,
        make_tuple(make_freeze_transform(I0),
                   make_freeze_transform(I0),
                   make_merge_transform(make_tuple(N, Ho, Wo)),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
        make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));

    return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
                      wei_gemmk0_gemmn_gemmk1_grid_desc,
                      in_gemmm_gemmn_grid_desc);
}

} // namespace ck
#endif
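The `_1x1` variant above can skip the tilde machinery entirely: with a 1x1, unit-dilation, zero-pad filter, YTilde = XTilde = 1 and backward data collapses to a single GEMM,

    in[n, ho * ConvStrideH, wo * ConvStrideW, c] = sum over k of out[n, ho, wo, k] * wei[k, c]

i.e. GemmM = N * Ho * Wo, GemmN = C, GemmK = K, matching the file's own header comment; input positions the stride skips receive no gradient, which is what the embed and freeze transforms on `in_n_hi_wi_c_grid_desc` encode.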
include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp (deleted, mode 100644 → 0)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

// GemmM = K
// GemmK = N * Ho * Wo
// GemmN = C * Y * X
template <typename... Wei,
          typename... In,
          typename... Out,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          index_t GemmK1Value,
          typename GemmKBatchType,
          typename GemmKPadType>
__host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad(
    const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
    const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
    const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    Number<GemmK1Value>,
    GemmKBatchType GemmKBatch,
    GemmKPadType GemmKPad)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto GemmK1 = Number<GemmK1Value>{};

    const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
    const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
    const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);

    const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
    const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);

    const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
    const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);

    const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
    const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);

    const auto ConvStrideH = conv_strides[I0];
    const auto ConvStrideW = conv_strides[I1];

    const auto ConvDilationH = conv_dilations[I0];
    const auto ConvDilationW = conv_dilations[I1];

    const auto InLeftPadH = in_left_pads[I0];
    const auto InLeftPadW = in_left_pads[I1];

    const auto InRightPadH = in_right_pads[I0];
    const auto InRightPadW = in_right_pads[I1];

    const auto GemmM      = K;
    const auto GemmN      = C * Y * X;
    const auto GemmKTotal = N * Ho * Wo;

    const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1);

    // A: output tensor
    const auto out_gemmktotal_gemmm_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
        make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
        out_gemmktotal_gemmm_grid_desc,
        make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                   make_pass_through_transform(GemmM)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
        out_gemmkpad_gemmm_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
                   make_pass_through_transform(GemmM)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

    // B: input tensor
    const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
        in_n_c_hi_wi_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pass_through_transform(C),
                   make_pad_transform(Hi, InLeftPadH, InRightPadH),
                   make_pad_transform(Wi, InLeftPadW, InRightPadW)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
        in_n_c_hip_wip_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pass_through_transform(C),
                   make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                   make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

    const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
        in_n_c_y_ho_x_wo_grid_desc,
        make_tuple(make_merge_transform(make_tuple(C, Y, X)),
                   make_merge_transform(make_tuple(N, Ho, Wo))),
        make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
        in_gemmktotal_gemmn_grid_desc,
        make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                   make_pass_through_transform(GemmN)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
        in_gemmkpad_gemmn_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
                   make_pass_through_transform(GemmN)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

    // C: weight tensor
    const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
        make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                      in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                      wei_gemmm_gemmn_grid_desc);
}

} // namespace ck
#endif
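The atomic variant splits the reduction dimension: `GemmKTotal = N * Ho * Wo` is right-padded up to `GemmKPad = GemmKBatch * GemmK0 * GemmK1` so it unmerges evenly, each of the `GemmKBatch` slices reduces its share, and partial results are combined with atomic adds (hence the buffer-atomic-add capability gated in ck.hpp). A worked example with assumed values:

// N = 2, Ho = Wo = 7          =>  GemmKTotal = 2 * 7 * 7 = 98
// GemmKBatch = 2, GemmK1 = 4  =>  pick GemmKPad = 104 (next multiple of 8)
// GemmK0 = GemmKPad / (GemmKBatch * GemmK1) = 104 / 8 = 13

The right-pad transform supplies the 6 phantom reduction elements, which must contribute nothing to the sum (out-of-range buffer reads returning zero is the usual mechanism).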
include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp (deleted, mode 100644 → 0)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

// GemmM = K
// GemmK = N * Ho * Wo
// GemmN = C * Y * X
template <typename... Wei,
          typename... In,
          typename... Out,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
    const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
    const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
    const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    Number<GemmK1Value>)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto GemmK1 = Number<GemmK1Value>{};

    const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
    const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
    const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);

    const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
    const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);

    const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
    const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);

    const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
    const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);

    const auto ConvStrideH = conv_strides[I0];
    const auto ConvStrideW = conv_strides[I1];

    const auto ConvDilationH = conv_dilations[I0];
    const auto ConvDilationW = conv_dilations[I1];

    const auto InLeftPadH = in_left_pads[I0];
    const auto InLeftPadW = in_left_pads[I1];

    const auto InRightPadH = in_right_pads[I0];
    const auto InRightPadW = in_right_pads[I1];

    const auto GemmM  = K;
    const auto GemmN  = C * Y * X;
    const auto GemmK  = N * Ho * Wo;
    const auto GemmK0 = GemmK / GemmK1;

    // weight tensor
    const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
        make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    // input tensor
    const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
        in_n_c_hi_wi_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pass_through_transform(C),
                   make_pad_transform(Hi, InLeftPadH, InRightPadH),
                   make_pad_transform(Wi, InLeftPadW, InRightPadW)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
        in_n_c_hip_wip_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pass_through_transform(C),
                   make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                   make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

    const auto in_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
        in_n_c_y_ho_x_wo_grid_desc,
        make_tuple(make_merge_transform(make_tuple(C, Y, X)),
                   make_merge_transform(make_tuple(N, Ho, Wo))),
        make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
        in_gemmk_gemmn_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
                   make_pass_through_transform(GemmN)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

    // output tensor
    const auto out_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
        make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
        out_gemmk_gemmm_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
                   make_pass_through_transform(GemmM)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

    return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
                      in_gemmk0_gemmn_gemmk1_grid_desc,
                      wei_gemmm_gemmn_grid_desc);
}

} // namespace ck
#endif
include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp
deleted
100644 → 0
View file @
d5f629e7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

// A: in
// B: out
// C: wei
// GemmM = Y * X * C
// GemmN = K
// GemmKTotal = N * Ho * Wo
template <typename... In,
          typename... Wei,
          typename... Out,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          index_t GemmK1Value,
          typename GemmKBatchType,
          typename GemmKPadType>
__host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad(
    const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
    const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
    const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    Number<GemmK1Value>,
    GemmKBatchType GemmKBatch,
    GemmKPadType GemmKPad)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto GemmK1 = Number<GemmK1Value>{};

    const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
    const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
    const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);

    const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
    const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);

    const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
    const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);

    const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
    const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);

    const auto ConvStrideH = conv_strides[I0];
    const auto ConvStrideW = conv_strides[I1];

    const auto ConvDilationH = conv_dilations[I0];
    const auto ConvDilationW = conv_dilations[I1];

    const auto InLeftPadH = in_left_pads[I0];
    const auto InLeftPadW = in_left_pads[I1];

    const auto InRightPadH = in_right_pads[I0];
    const auto InRightPadW = in_right_pads[I1];

    const auto GemmM      = Y * X * C;
    const auto GemmN      = K;
    const auto GemmKTotal = N * Ho * Wo;

    const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1);

    // A: input tensor
    const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
        in_n_hi_wi_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pad_transform(Hi, InLeftPadH, InRightPadH),
                   make_pad_transform(Wi, InLeftPadW, InRightPadW),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
        in_n_hip_wip_c_grid_desc,
        make_tuple(
            make_pass_through_transform(N),
            make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
            make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
            make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

    const auto in_gemmktotal_gemmm_grid_desc = transform_tensor_descriptor(
        in_n_y_ho_x_wo_c_grid_desc,
        make_tuple(make_merge_transform(make_tuple(Y, X, C)),
                   make_merge_transform(make_tuple(N, Ho, Wo))),
        make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    const auto in_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
        in_gemmktotal_gemmm_grid_desc,
        make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                   make_pass_through_transform(GemmM)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
        in_gemmkpad_gemmm_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
                   make_pass_through_transform(GemmM)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

    // B: output tensor
    const auto out_gemmktotal_gemmn_grid_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));

    const auto out_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
        out_gemmktotal_gemmn_grid_desc,
        make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                   make_pass_through_transform(GemmN)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
        out_gemmkpad_gemmn_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
                   make_pass_through_transform(GemmN)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

    // C: weight tensor
    const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
        make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    return make_tuple(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                      out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                      wei_gemmm_gemmn_grid_desc);
}

} // namespace ck
#endif
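In this atomic variant, GemmKTotal = N * Ho * Wo is right-padded to GemmKPad and then unmerged into (GemmKBatch, GemmK0, GemmK1), so GemmK0 = GemmKPad / (GemmKBatch * GemmK1) is exact only if the caller chooses GemmKPad as a multiple of GemmKBatch * GemmK1. A hypothetical host-side sketch of that choice (CK's drivers compute this elsewhere; the helper name is illustrative):

// Hypothetical sketch of picking GemmKPad so the unmerge into
// (GemmKBatch, GemmK0, GemmK1) is exact; not a CK API.
#include <cassert>

int pick_gemmk_pad(int gemmk_total, int gemmk_batch, int gemmk1)
{
    const int unit = gemmk_batch * gemmk1;
    return ((gemmk_total + unit - 1) / unit) * unit; // round up to a multiple of unit
}

int main()
{
    // GemmKTotal = N * Ho * Wo = 1000, GemmKBatch = 4, GemmK1 = 8
    const int pad = pick_gemmk_pad(1000, 4, 8);
    assert(pad == 1024 && pad / (4 * 8) == 32); // GemmK0 = 32
    return 0;
}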
include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp
deleted
100644 → 0
View file @
d5f629e7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

// A: in
// B: out
// C: wei
// GemmM = Y * X * C
// GemmN = K
// GemmK = N * Ho * Wo
template <typename... In,
          typename... Wei,
          typename... Out,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
    const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
    const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
    const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    Number<GemmK1Value>)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto GemmK1 = Number<GemmK1Value>{};

    const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
    const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
    const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);

    const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
    const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);

    const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
    const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);

    const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
    const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);

    const auto ConvStrideH = conv_strides[I0];
    const auto ConvStrideW = conv_strides[I1];

    const auto ConvDilationH = conv_dilations[I0];
    const auto ConvDilationW = conv_dilations[I1];

    const auto InLeftPadH = in_left_pads[I0];
    const auto InLeftPadW = in_left_pads[I1];

    const auto InRightPadH = in_right_pads[I0];
    const auto InRightPadW = in_right_pads[I1];

    const auto GemmM = Y * X * C;
    const auto GemmN = K;
    const auto GemmK = N * Ho * Wo;

    const auto GemmK0 = GemmK / GemmK1;

    // A: input tensor
    const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
        in_n_hi_wi_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pad_transform(Hi, InLeftPadH, InRightPadH),
                   make_pad_transform(Wi, InLeftPadW, InRightPadW),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
        in_n_hip_wip_c_grid_desc,
        make_tuple(
            make_pass_through_transform(N),
            make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
            make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
            make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

    const auto in_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
        in_n_y_ho_x_wo_c_grid_desc,
        make_tuple(make_merge_transform(make_tuple(Y, X, C)),
                   make_merge_transform(make_tuple(N, Ho, Wo))),
        make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
        in_gemmk_gemmm_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
                   make_pass_through_transform(GemmM)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

    // B: output tensor
    const auto out_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
        make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
        out_gemmk_gemmn_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
                   make_pass_through_transform(GemmN)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

    // C: weight tensor
    const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
        make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
                      out_gemmk0_gemmn_gemmk1_grid_desc,
                      wei_gemmm_gemmn_grid_desc);
}

} // namespace ck
#endif
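Unlike the atomic variant above, this header splits GemmK = N * Ho * Wo directly as GemmK0 = GemmK / GemmK1 with no right-padding, so the split is valid only when N * Ho * Wo is divisible by GemmK1. A small illustrative check (the function name is hypothetical, not part of CK):

// Illustrative divisibility check for the unpadded GemmK split.
#include <cassert>

bool is_valid_gemmk_split(int n, int ho, int wo, int gemmk1)
{
    return (n * ho * wo) % gemmk1 == 0;
}

int main()
{
    assert(is_valid_gemmk_split(128, 14, 14, 8)); // 25088 % 8 == 0
    assert(!is_valid_gemmk_split(1, 7, 7, 8));    // 49 % 8 != 0
    return 0;
}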
include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp
deleted
100644 → 0
View file @
d5f629e7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

// A: out
// B: in
// C: wei
// GemmM = K
// GemmN = Y * X * C
// GemmKTotal = N * Ho * Wo
template <typename... In,
          typename... Wei,
          typename... Out,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          index_t GemmK1Value,
          typename GemmKBatchType,
          typename GemmKPadType>
__host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad(
    const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
    const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
    const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    Number<GemmK1Value>,
    GemmKBatchType GemmKBatch,
    GemmKPadType GemmKPad)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto GemmK1 = Number<GemmK1Value>{};

    const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
    const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
    const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);

    const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
    const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);

    const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
    const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);

    const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
    const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);

    const auto ConvStrideH = conv_strides[I0];
    const auto ConvStrideW = conv_strides[I1];

    const auto ConvDilationH = conv_dilations[I0];
    const auto ConvDilationW = conv_dilations[I1];

    const auto InLeftPadH = in_left_pads[I0];
    const auto InLeftPadW = in_left_pads[I1];

    const auto InRightPadH = in_right_pads[I0];
    const auto InRightPadW = in_right_pads[I1];

    const auto GemmM      = K;
    const auto GemmN      = Y * X * C;
    const auto GemmKTotal = N * Ho * Wo;

    const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1);

    // A: output tensor
    const auto out_gemmktotal_gemmm_grid_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));

    const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
        out_gemmktotal_gemmm_grid_desc,
        make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                   make_pass_through_transform(GemmM)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
        out_gemmkpad_gemmm_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
                   make_pass_through_transform(GemmM)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

    // B: input tensor
    const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
        in_n_hi_wi_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pad_transform(Hi, InLeftPadH, InRightPadH),
                   make_pad_transform(Wi, InLeftPadW, InRightPadW),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
        in_n_hip_wip_c_grid_desc,
        make_tuple(
            make_pass_through_transform(N),
            make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
            make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
            make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

    const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
        in_n_y_ho_x_wo_c_grid_desc,
        make_tuple(make_merge_transform(make_tuple(Y, X, C)),
                   make_merge_transform(make_tuple(N, Ho, Wo))),
        make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
        in_gemmktotal_gemmn_grid_desc,
        make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                   make_pass_through_transform(GemmN)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
        in_gemmkpad_gemmn_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
                   make_pass_through_transform(GemmN)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

    // C: weight tensor
    const auto wei_gemmm_gemmn_grid_desc =
        make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));

    return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                      in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                      wei_gemmm_gemmn_grid_desc);
}

} // namespace ck
#endif
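The descriptors returned here describe the backward-weight GEMM with A = out, B = in (im2col view), and C = wei, accumulating over GemmKTotal = N * Ho * Wo. A naive host-side reference over buffers that are already flattened/im2col'ed, for intuition only (names and layouts are illustrative):

// Naive host-side reference of the backward-weight GEMM described above.
#include <vector>

void backward_weight_gemm_ref(const std::vector<float>& out_kt_m, // (KTotal, M) row-major
                              const std::vector<float>& in_kt_n,  // (KTotal, N) row-major
                              std::vector<float>& wei_m_n,        // (M, N) row-major
                              int k_total,
                              int m,
                              int n)
{
    for(int im = 0; im < m; ++im)
    {
        for(int jn = 0; jn < n; ++jn)
        {
            float acc = 0.f;
            for(int kt = 0; kt < k_total; ++kt)
            {
                acc += out_kt_m[kt * m + im] * in_kt_n[kt * n + jn];
            }
            wei_m_n[im * n + jn] = acc;
        }
    }
}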
include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION3D_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION3D_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
...
...
include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp
deleted
100644 → 0
View file @
d5f629e7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <typename... Wei,
          typename... In,
          typename... Out,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(
    const TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
    const TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
    const TensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
    const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
    const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);

    const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
    const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);

    const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
    const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);

    const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
    const auto X = wei_k_c_y_x_global_desc.GetLength(I3);

    const auto ConvStrideH = conv_strides[I0];
    const auto ConvStrideW = conv_strides[I1];

    const auto ConvDilationH = conv_dilations[I0];
    const auto ConvDilationW = conv_dilations[I1];

    const auto InLeftPadH = in_left_pads[I0];
    const auto InLeftPadW = in_left_pads[I1];

    const auto InRightPadH = in_right_pads[I0];
    const auto InRightPadW = in_right_pads[I1];

    // weight tensor
    const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
        make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    // input tensor
    const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
        in_n_c_hi_wi_global_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pass_through_transform(C),
                   make_pad_transform(Hi, InLeftPadH, InRightPadH),
                   make_pad_transform(Wi, InLeftPadW, InRightPadW)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
        in_n_c_hip_wip_global_desc,
        make_tuple(
            make_pass_through_transform(N),
            make_pass_through_transform(C),
            make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
            make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

    const auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
        in_n_c_y_ho_x_wo_global_desc,
        make_tuple(make_merge_transform(make_tuple(C, Y, X)),
                   make_merge_transform(make_tuple(N, Ho, Wo))),
        make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    // output tensor
    const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
        make_tuple(make_pass_through_transform(K),
                   make_merge_transform(make_tuple(N, Ho * Wo))),
        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    return make_tuple(
        wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc);
}

template <typename... Wei,
          typename... In,
          typename... Out,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad(
    const TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
    const TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
    const TensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
    const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
    const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);

    const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
    const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);

    const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
    const auto X = wei_k_c_y_x_global_desc.GetLength(I3);

    const auto ConvStrideH = conv_strides[I0];
    const auto ConvStrideW = conv_strides[I1];

    const auto ConvDilationH = conv_dilations[I0];
    const auto ConvDilationW = conv_dilations[I1];

    const auto InLeftPadH = in_left_pads[I0];
    const auto InLeftPadW = in_left_pads[I1];

    const auto InRightPadH = in_right_pads[I0];
    const auto InRightPadW = in_right_pads[I1];

    assert(InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && InRightPadW == 0);

    // weight tensor
    const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
        make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    // input tensor
    const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
        in_n_c_hi_wi_global_desc,
        make_tuple(
            make_pass_through_transform(N),
            make_pass_through_transform(C),
            make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
            make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

    const auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
        in_n_c_y_ho_x_wo_global_desc,
        make_tuple(make_merge_transform(make_tuple(C, Y, X)),
                   make_merge_transform(make_tuple(N, Ho, Wo))),
        make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    // output tensor
    const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
        make_tuple(make_pass_through_transform(K),
                   make_merge_transform(make_tuple(N, Ho * Wo))),
        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    return make_tuple(
        wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc);
}

template <typename... Wei,
          typename... In,
          typename... Out,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1(
    const TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
    const TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
    const TensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
    const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
    const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);

    const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
    const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);

    const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
    const auto X = wei_k_c_y_x_global_desc.GetLength(I3);

    const auto ConvStrideH = conv_strides[I0];
    const auto ConvStrideW = conv_strides[I1];

    const auto ConvDilationH = conv_dilations[I0];
    const auto ConvDilationW = conv_dilations[I1];

    const auto InLeftPadH = in_left_pads[I0];
    const auto InLeftPadW = in_left_pads[I1];

    const auto InRightPadH = in_right_pads[I0];
    const auto InRightPadW = in_right_pads[I1];

    assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
           ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 &&
           InRightPadW == 0);

    // weight tensor
    const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(K, C)),
        make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    // input tensor
    const auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
        in_n_c_hi_wi_global_desc,
        make_tuple(make_pass_through_transform(C),
                   make_merge_transform(make_tuple(N, Ho, Wo))),
        make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    // output tensor
    const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
        make_tuple(make_pass_through_transform(K),
                   make_merge_transform(make_tuple(N, Ho * Wo))),
        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    return make_tuple(
        wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc);
}

} // namespace ck
#endif
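The pad, embed, and merge chain in these forward transforms is im2col expressed at the descriptor level: a merged (GemmK, GemmN) coordinate (c, y, x, n, ho, wo) addresses input element (n, c, ho * ConvStrideH + y * ConvDilationH - InLeftPadH, wo * ConvStrideW + x * ConvDilationW - InLeftPadW), with out-of-range coordinates falling in the zero pad region. Spelled out as plain index math in a hypothetical sketch (not a CK API):

// Hypothetical sketch of the input coordinate addressed by a merged
// (GemmK, GemmN) pair; hi/wi outside [0, Hi) x [0, Wi) lie in the pad region.
struct InCoord
{
    int n, c, hi, wi;
};

InCoord im2col_coord(int c, int y, int x,   // unmerged GemmK = (C, Y, X)
                     int n, int ho, int wo, // unmerged GemmN = (N, Ho, Wo)
                     int stride_h, int stride_w,
                     int dilation_h, int dilation_w,
                     int pad_h, int pad_w)
{
    return {n,
            c,
            ho * stride_h + y * dilation_h - pad_h,
            wo * stride_w + x * dilation_w - pad_w};
}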
include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp
deleted
100644 → 0
View file @
d5f629e7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <typename... Wei,
          typename... In,
          typename... Out,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad(
    const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
    const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
    const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
    const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
    const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);

    const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
    const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);

    const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
    const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);

    const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
    const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);

    const auto ConvStrideH = conv_strides[I0];
    const auto ConvStrideW = conv_strides[I1];

    const auto ConvDilationH = conv_dilations[I0];
    const auto ConvDilationW = conv_dilations[I1];

    const auto InLeftPadH = in_left_pads[I0];
    const auto InLeftPadW = in_left_pads[I1];

    const auto InRightPadH = in_right_pads[I0];
    const auto InRightPadW = in_right_pads[I1];

    // weight tensor
    const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
        make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    // input tensor
    const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
        in_n_hi_wi_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pad_transform(Hi, InLeftPadH, InRightPadH),
                   make_pad_transform(Wi, InLeftPadW, InRightPadW),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
        in_n_hip_wip_c_grid_desc,
        make_tuple(
            make_pass_through_transform(N),
            make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
            make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
            make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

    const auto in_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
        in_n_y_ho_x_wo_c_grid_desc,
        make_tuple(make_merge_transform(make_tuple(Y, X, C)),
                   make_merge_transform(make_tuple(N, Ho, Wo))),
        make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    // output tensor
    const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
        make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    return make_tuple(
        wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc);
}

template <typename... Wei,
          typename... In,
          typename... Out,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1(
    const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
    const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
    const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
    const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
    const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);

    const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
    const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);

    const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
    const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);

    const auto ConvStrideH = conv_strides[I0];
    const auto ConvStrideW = conv_strides[I1];

    const auto ConvDilationH = conv_dilations[I0];
    const auto ConvDilationW = conv_dilations[I1];

    const auto InLeftPadH = in_left_pads[I0];
    const auto InLeftPadW = in_left_pads[I1];

    const auto InRightPadH = in_right_pads[I0];
    const auto InRightPadW = in_right_pads[I1];

    assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
           ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 &&
           InRightPadW == 0);

    // weight tensor
    const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(K, C)),
        make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    // input tensor
    const auto in_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)),
        make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    // output tensor
    const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
        make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    return make_tuple(
        wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc);
}

} // namespace ck
#endif
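The _1x1 specializations above drop the embed/merge machinery because, with Y = X = 1, unit strides and dilations, and zero padding, Ho = Hi and Wo = Wi, so an NHWC input is already a packed (N * Ho * Wo, C) matrix and the convolution is a plain GEMM. A naive host-side sketch of that equivalent GEMM (buffer names and layouts are illustrative only):

// Naive host-side sketch of the plain GEMM the 1x1 path reduces to.
#include <vector>

void conv1x1_nhwc_ref(const std::vector<float>& in,  // (N*Ho*Wo, C) row-major
                      const std::vector<float>& wei, // (K, C) row-major
                      std::vector<float>& out,       // (N*Ho*Wo, K) row-major
                      int nhw,
                      int c,
                      int k)
{
    for(int row = 0; row < nhw; ++row)
    {
        for(int kk = 0; kk < k; ++kk)
        {
            float acc = 0.f;
            for(int cc = 0; cc < c; ++cc)
            {
                acc += in[row * c + cc] * wei[kk * c + cc];
            }
            out[row * k + kk] = acc;
        }
    }
}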