Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
4100d1d8
Commit
4100d1d8
authored
Aug 23, 2023
by
Alan Turner
Browse files
Merge remote-tracking branch 'origin/develop' into migx-flash-attn
parents
48717006
c8a8385f
Changes
609
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1034 additions
and
73 deletions
+1034
-73
include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp
.../ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp
+1
-1
include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp
include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp
+2
-2
include/ck/tensor_operation/gpu/block/blockwise_welford.hpp
include/ck/tensor_operation/gpu/block/blockwise_welford.hpp
+13
-11
include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp
...sor_operation/gpu/block/reduction_functions_blockwise.hpp
+1
-1
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp
...ion/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp
+15
-0
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp
...n/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp
+164
-0
include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp
...n/gpu/device/convolution_backward_data_specialization.hpp
+1
-2
include/ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp
...ude/ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp
+39
-0
include/ck/tensor_operation/gpu/device/device_gemm_streamk.hpp
...de/ck/tensor_operation/gpu/device/device_gemm_streamk.hpp
+64
-0
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp
...r_operation/gpu/device/device_grouped_conv_bwd_weight.hpp
+10
-11
include/ck/tensor_operation/gpu/device/device_index_pool_bwd.hpp
.../ck/tensor_operation/gpu/device/device_index_pool_bwd.hpp
+32
-0
include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp
include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp
+10
-7
include/ck/tensor_operation/gpu/device/device_put_element.hpp
...ude/ck/tensor_operation/gpu/device/device_put_element.hpp
+36
-0
include/ck/tensor_operation/gpu/device/device_softmax.hpp
include/ck/tensor_operation/gpu/device/device_softmax.hpp
+11
-6
include/ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp
include/ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp
+18
-0
include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp
...tion/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp
+575
-0
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
...pl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
+13
-11
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp
...ion/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp
+15
-7
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp
...gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp
+1
-3
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
...ation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
+13
-11
No files found.
include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp
View file @
4100d1d8
...
...
@@ -11,7 +11,7 @@
namespace
ck
{
// C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1]
// A and B are vis
a
ble to the whole block, C is distributed among each thread
// A and B are vis
i
ble to the whole block, C is distributed among each thread
// Assume:
// 1. A:
// 1. ABlockDesc_BK0_BM_BK1 is known at compile-time
...
...
include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp
View file @
4100d1d8
...
...
@@ -35,8 +35,8 @@ struct BlockwiseSoftmax
static
constexpr
index_t
MRepeat
=
ThreadSliceDesc_M_K
{}.
GetLength
(
I0
);
static
constexpr
index_t
KRepeat
=
ThreadSliceDesc_M_K
{}.
GetLength
(
I1
);
using
ThreadSliceDesc_M
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
ThreadSliceDesc_M_K
{}.
GetLength
(
I0
))));
using
ThreadSliceDesc_M
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
ThreadSliceDesc_M_K
{}.
GetLength
(
I0
))));
using
ThreadwiseMaxReduce
=
typename
conditional
<
IgnoreNaN
,
...
...
include/ck/tensor_operation/gpu/block/blockwise_welford.hpp
View file @
4100d1d8
...
...
@@ -4,7 +4,7 @@
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/
reduction_common
.hpp"
#include "ck/utility/
get_shift
.hpp"
namespace
ck
{
...
...
@@ -35,10 +35,11 @@ struct BlockwiseWelford
static
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
template
<
typename
CountDataType
>
__device__
static
inline
void
Merge
(
T
&
mean_a
,
T
&
var_a
,
int
&
count_a
,
T
mean_b
,
T
var_b
,
int
count_b
)
Merge
(
T
&
mean_a
,
T
&
var_a
,
CountDataType
&
count_a
,
T
mean_b
,
T
var_b
,
CountDataType
count_b
)
{
int
count
=
count_a
+
count_b
;
CountDataType
count
=
count_a
+
count_b
;
T
count_b_over_count
=
count
==
0
?
type_convert
<
T
>
(
0
)
:
type_convert
<
T
>
(
count_b
)
/
count
;
T
delta
=
mean_b
-
mean_a
;
mean_a
+=
delta
*
count_b_over_count
;
...
...
@@ -46,11 +47,12 @@ struct BlockwiseWelford
count_a
=
count
;
}
__device__
static
void
Run
(
T
&
mean_value
,
T
&
var_value
,
int
&
count
)
template
<
typename
CountDataType
>
__device__
static
void
Run
(
T
&
mean_value
,
T
&
var_value
,
CountDataType
&
count
)
{
__shared__
T
mean_block_buf
[
BlockSize
];
__shared__
T
var_block_buf
[
BlockSize
];
__shared__
int
count_block_buf
[
BlockSize
];
__shared__
CountDataType
count_block_buf
[
BlockSize
];
constexpr
auto
cluster_len_shift
=
get_shift
<
BufferLength_K
>
();
...
...
@@ -76,13 +78,13 @@ struct BlockwiseWelford
index_t
offset2
=
block_buf_desc_m_k
.
CalculateOffset
(
thread_cluster_idx
+
make_tuple
(
0
,
indOffset
));
T
mean1
=
mean_block_buf
[
offset1
];
T
var1
=
var_block_buf
[
offset1
];
int
count1
=
count_block_buf
[
offset1
];
T
mean1
=
mean_block_buf
[
offset1
];
T
var1
=
var_block_buf
[
offset1
];
CountDataType
count1
=
count_block_buf
[
offset1
];
T
mean2
=
mean_block_buf
[
offset2
];
T
var2
=
var_block_buf
[
offset2
];
int
count2
=
count_block_buf
[
offset2
];
T
mean2
=
mean_block_buf
[
offset2
];
T
var2
=
var_block_buf
[
offset2
];
CountDataType
count2
=
count_block_buf
[
offset2
];
Merge
(
mean1
,
var1
,
count1
,
mean2
,
var2
,
count2
);
...
...
include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp
View file @
4100d1d8
...
...
@@ -4,7 +4,7 @@
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/
reduction_common
.hpp"
#include "ck/utility/
get_shift
.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
namespace
ck
{
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp
View file @
4100d1d8
...
...
@@ -94,6 +94,21 @@ struct ThreadGroupTensorSliceTransfer_v4r1
}
}
__device__
void
SetSrcSliceOrigin
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_block_slice_origin
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
const
auto
thread_cluster_idx
=
thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
ThreadGroup
::
GetThreadId
()));
const
auto
thread_data_idx_begin
=
thread_cluster_idx
*
thread_slice_lengths
;
threadwise_transfer_
.
SetSrcSliceOrigin
(
src_desc
,
src_block_slice_origin
+
thread_data_idx_begin
);
}
}
template
<
typename
SrcBuffer
,
index_t
ThreadScratchId
=
0
>
__device__
void
RunRead
(
const
SrcDesc
&
src_desc
,
const
SrcBuffer
&
src_buf
,
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp
0 → 100644
View file @
4100d1d8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp"
namespace
ck
{
// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template
<
typename
ThreadGroup
,
typename
ElementwiseOperation
,
typename
SliceLengths
,
typename
ThreadClusterLengths
,
typename
ThreadClusterArrangeOrder
,
typename
SrcData
,
typename
DstData
,
typename
SrcDesc
,
typename
DstDesc
,
typename
DimAccessOrder
,
index_t
VectorDim
,
index_t
ScalarPerVector
,
bool
ThreadTransferSrcResetCoordinateAfterRun
,
bool
ThreadTransferDstResetCoordinateAfterRun
>
struct
ThreadGroupTensorSliceTransfer_v6r1r2
{
static
constexpr
index_t
nDim
=
remove_reference_t
<
SrcDesc
>::
GetNumOfDimension
();
static
constexpr
auto
thread_slice_lengths
=
SliceLengths
{}
/
ThreadClusterLengths
{};
using
Index
=
MultiIndex
<
nDim
>
;
__device__
constexpr
ThreadGroupTensorSliceTransfer_v6r1r2
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_block_slice_origin
,
const
DstDesc
&
dst_desc
,
const
Index
&
dst_block_slice_origin
,
const
ElementwiseOperation
&
element_op
)
:
threadwise_transfer_
(
src_desc
,
make_zero_multi_index
<
nDim
>
(),
dst_desc
,
make_zero_multi_index
<
nDim
>
(),
element_op
)
{
static_assert
(
nDim
==
remove_cvref_t
<
SrcDesc
>::
GetNumOfDimension
()
&&
nDim
==
remove_cvref_t
<
DstDesc
>::
GetNumOfDimension
()
&&
nDim
==
ThreadClusterLengths
::
Size
()
&&
nDim
==
ThreadClusterArrangeOrder
::
Size
()
&&
nDim
==
DimAccessOrder
::
Size
(),
"wrong! nDim not consistent"
);
static_assert
(
is_same
<
SliceLengths
,
decltype
(
thread_slice_lengths
*
ThreadClusterLengths
{})
>
{},
"wrong! threads should be mapped to cover entire slicing window"
);
static_assert
(
ThreadGroup
::
GetNumOfThread
()
>=
thread_cluster_desc_
.
GetElementSize
(),
"wrong! ThreadGroup::GetNumOfThread() too small"
);
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
const
auto
thread_cluster_idx
=
thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
ThreadGroup
::
GetThreadId
()));
const
auto
thread_data_idx_begin
=
thread_cluster_idx
*
thread_slice_lengths
;
threadwise_transfer_
.
SetSrcSliceOrigin
(
src_desc
,
src_block_slice_origin
+
thread_data_idx_begin
);
threadwise_transfer_
.
SetDstSliceOrigin
(
dst_desc
,
dst_block_slice_origin
+
thread_data_idx_begin
);
}
}
template
<
typename
SrcBuffer
,
typename
DstBuffer
,
InMemoryDataOperationEnum
DstInMemOp
>
__device__
void
Run
(
const
SrcDesc
&
src_desc
,
const
SrcBuffer
&
src_buf
,
const
DstDesc
&
dst_desc
,
DstBuffer
&
dst_buf
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
template
Run
<
SrcBuffer
,
DstBuffer
,
DstInMemOp
>(
src_desc
,
src_buf
,
dst_desc
,
dst_buf
);
}
}
__device__
void
MoveSrcSliceWindow
(
const
SrcDesc
&
src_desc
,
const
Index
&
step
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
MoveSrcSliceWindow
(
src_desc
,
step
);
}
}
__device__
void
MoveDstSliceWindow
(
const
DstDesc
&
dst_desc
,
const
Index
&
step
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
MoveDstSliceWindow
(
dst_desc
,
step
);
}
}
__device__
void
SetSrcSliceOrigin
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_block_slice_origin
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
const
auto
thread_cluster_idx
=
thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
ThreadGroup
::
GetThreadId
()));
const
auto
thread_data_idx_begin
=
thread_cluster_idx
*
thread_slice_lengths
;
threadwise_transfer_
.
SetSrcSliceOrigin
(
src_desc
,
src_block_slice_origin
+
thread_data_idx_begin
);
}
}
__device__
void
SetDstSliceOrigin
(
const
DstDesc
&
dst_desc
,
const
Index
&
dst_block_slice_origin
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
const
auto
thread_cluster_idx
=
thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
ThreadGroup
::
GetThreadId
()));
const
auto
thread_data_idx_begin
=
thread_cluster_idx
*
thread_slice_lengths
;
threadwise_transfer_
.
SetDstSliceOrigin
(
dst_desc
,
dst_block_slice_origin
+
thread_data_idx_begin
);
}
}
private:
static
constexpr
auto
thread_cluster_desc_
=
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
using
ThreadwiseTransfer
=
ThreadwiseTensorSliceTransfer_v6r1r2
<
SrcData
,
DstData
,
SrcDesc
,
DstDesc
,
ElementwiseOperation
,
decltype
(
thread_slice_lengths
),
DimAccessOrder
,
VectorDim
,
ScalarPerVector
,
ThreadTransferSrcResetCoordinateAfterRun
,
ThreadTransferDstResetCoordinateAfterRun
>
;
ThreadwiseTransfer
threadwise_transfer_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp
View file @
4100d1d8
...
...
@@ -19,8 +19,7 @@ getConvBackwardDataSpecializationString(const ConvolutionBackwardDataSpecializat
switch
(
s
)
{
case
ConvolutionBackwardDataSpecialization
::
Default
:
return
"Default"
;
case
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
:
return
"FFilter1x1Stride1Pad0"
;
case
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
:
return
"Filter1x1Stride1Pad0"
;
default:
return
"Unrecognized specialization!"
;
}
}
...
...
include/ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp
0 → 100644
View file @
4100d1d8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
index_t
NDimSpatial
,
typename
DOutDataType
,
typename
DInDataType
,
typename
DOutLayout
,
typename
DInLayout
>
struct
DeviceAvgPoolBwd
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_dout
,
void
*
p_din
,
std
::
vector
<
ck
::
index_t
>
dout_n_k_wos_lengths
,
std
::
vector
<
ck
::
index_t
>
dout_n_k_wos_strides
,
std
::
vector
<
ck
::
index_t
>
din_n_k_wos_length
,
std
::
vector
<
ck
::
index_t
>
din_n_k_wos_strides
,
std
::
vector
<
ck
::
index_t
>
window_k_c_xs_lengths
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
window_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_gemm_streamk.hpp
0 → 100644
View file @
4100d1d8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGemmStreamK
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
ck
::
index_t
NumSKBlocks
=
0
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
using
DeviceGemmStreamKPtr
=
std
::
unique_ptr
<
DeviceGemmStreamK
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp
View file @
4100d1d8
...
...
@@ -27,17 +27,16 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator
MakeArgumentPointer
(
const
void
*
p_in
,
void
*
p_wei
,
const
void
*
p_out
,
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
// input
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
// output
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
...
...
include/ck/tensor_operation/gpu/device/device_index_pool_bwd.hpp
0 → 100644
View file @
4100d1d8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// For pooling which used indexable operation, such as MaxPool, MinPool...etc
template
<
typename
DOutDataType
,
typename
IndexDataType
,
typename
DInDataType
>
struct
DeviceIndexPoolBwd
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_dout
,
const
void
*
p_indices
,
void
*
p_din
,
index_t
dout_length
,
index_t
din_length
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
window_strides
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp
View file @
4100d1d8
...
...
@@ -17,6 +17,8 @@ template <index_t InOutRank,
typename
InDataType
,
typename
OutDataType
,
typename
IndexDataType
,
typename
InLayout
,
typename
OutLayout
,
ReduceTensorOp
ReduceOpId
,
bool
OutputIndex
>
struct
DevicePoolFwd
:
public
BaseOperator
...
...
@@ -25,13 +27,14 @@ struct DevicePoolFwd : public BaseOperator
MakeArgumentPointer
(
const
void
*
p_in_dev
,
void
*
p_out_dev
,
void
*
p_out_indices_dev
,
std
::
vector
<
ck
::
index_t
>
input_lengths
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
output_lengths
,
std
::
vector
<
ck
::
index_t
>
input_stride
,
std
::
vector
<
ck
::
index_t
>
output_stride
,
std
::
vector
<
ck
::
index_t
>
indices_stride
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
input_n_c_wis_lengths
,
std
::
vector
<
ck
::
index_t
>
window_xs_lengths
,
std
::
vector
<
ck
::
index_t
>
output_n_c_wos_lengths
,
std
::
vector
<
ck
::
index_t
>
input_n_c_wis_stride
,
std
::
vector
<
ck
::
index_t
>
output_n_c_wis_stride
,
std
::
vector
<
ck
::
index_t
>
indices_n_c_wis_stride
,
std
::
vector
<
ck
::
index_t
>
window_xs_strides
,
std
::
vector
<
ck
::
index_t
>
window_xs_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
index_t
>
pooling_dims
)
=
0
;
...
...
include/ck/tensor_operation/gpu/device/device_put_element.hpp
0 → 100644
View file @
4100d1d8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/utility/reduction_enums.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// output[indices] = input
template
<
typename
InDataType
,
typename
IndexDataType
,
typename
OutDataType
,
typename
ElementwiseOperation
,
InMemoryDataOperationEnum
Op
>
struct
DevicePutElement
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_input
,
const
void
*
p_indices
,
void
*
p_output
,
index_t
input_length
,
index_t
output_length
,
ElementwiseOperation
elementwise_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_softmax.hpp
View file @
4100d1d8
...
...
@@ -18,7 +18,8 @@ template <typename InDataType,
typename
OutDataType
,
typename
InElementwiseOp
,
typename
AccElementwiseOp
,
index_t
Rank
>
index_t
Rank
,
index_t
NumReduceDim
>
struct
DeviceSoftmax
:
public
BaseOperator
{
//
...
...
@@ -49,8 +50,6 @@ struct DeviceSoftmax : public BaseOperator
AccElementwiseOp
acc_elementwise_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
index_t
GetRank
()
const
=
0
;
virtual
index_t
GetNumReduceDim
()
const
=
0
;
};
template
<
typename
InDataType
,
...
...
@@ -58,9 +57,15 @@ template <typename InDataType,
typename
OutDataType
,
typename
InElementwiseOp
,
typename
AccElementwiseOp
,
index_t
Rank
>
using
DeviceSoftmaxPtr
=
std
::
unique_ptr
<
DeviceSoftmax
<
InDataType
,
AccDataType
,
OutDataType
,
InElementwiseOp
,
AccElementwiseOp
,
Rank
>>
;
index_t
Rank
,
index_t
NumReduceDim
>
using
DeviceSoftmaxPtr
=
std
::
unique_ptr
<
DeviceSoftmax
<
InDataType
,
AccDataType
,
OutDataType
,
InElementwiseOp
,
AccElementwiseOp
,
Rank
,
NumReduceDim
>>
;
}
// namespace device
}
// namespace tensor_operation
...
...
library/
include/ck/
library/
tensor_operation
_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce2
.hpp
→
include/ck/tensor_operation
/gpu/device/gemm_dl_algorithm
.hpp
View file @
4100d1d8
...
...
@@ -3,20 +3,16 @@
#pragma once
#include <vector>
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_softmax_i8_i8_rank3_reduce2_instances
(
std
::
vector
<
DeviceSoftmaxPtr
<
I8
,
F32
,
I8
,
PassThrough
,
PassThrough
,
3
>>&
instances
);
enum
struct
GemmDlAlgorithm
{
Default
,
// Uses DOT vector instructions
Dpp8
,
// Uses DOT vector instructions with DPP8 SEL modifier to reduce data loads from LDS
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp
0 → 100644
View file @
4100d1d8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// In and Din = [N, C, Di, Hi, Wi]
// Out and Dout = [N, C, Do, Ho, Wo]
// Out = AvgPoolFwd(In)
// Din = AvgPoolBwd(Dout)
// Pooling dimension = D, H, W
template
<
typename
DOutDataType
,
typename
DInDataType
,
typename
ComputeDataType
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MThreadClusterSize
,
ck
::
index_t
KThreadClusterSize
,
ck
::
index_t
MThreadSliceSize
,
ck
::
index_t
KThreadSliceSize
,
ck
::
index_t
InSrcOutDstVectorSize
>
struct
DeviceAvgPool3dBwd_NDHWC_NDHWC
:
public
DeviceAvgPoolBwd
<
3
,
DOutDataType
,
DInDataType
,
tensor_layout
::
convolution
::
NDHWC
,
tensor_layout
::
convolution
::
NDHWC
>
{
static
constexpr
ck
::
index_t
NDimSpatial
=
3
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
ck
::
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
ck
::
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
auto
Make3DGridDescriptor_Out_M_K_In_M
(
const
std
::
vector
<
ck
::
index_t
>&
dout_n_c_wos_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
din_n_c_wos_length
,
const
std
::
vector
<
ck
::
index_t
>&
dout_n_c_wos_strides
,
const
std
::
vector
<
ck
::
index_t
>&
din_n_c_wos_strides
,
const
std
::
vector
<
ck
::
index_t
>&
window_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
window_strides
,
const
std
::
vector
<
ck
::
index_t
>&
window_dilations
,
const
std
::
vector
<
ck
::
index_t
>&
input_left_pads
,
const
std
::
vector
<
ck
::
index_t
>&
input_right_pads
,
const
std
::
vector
<
ck
::
index_t
>&
tildes
)
{
index_t
i_ztilde
=
tildes
[
0
];
index_t
i_ytilde
=
tildes
[
1
];
index_t
i_xtilde
=
tildes
[
2
];
const
index_t
N
=
dout_n_c_wos_lengths
[
0
];
const
index_t
C
=
dout_n_c_wos_lengths
[
1
];
const
index_t
Di
=
din_n_c_wos_length
[
2
];
const
index_t
Hi
=
din_n_c_wos_length
[
3
];
const
index_t
Wi
=
din_n_c_wos_length
[
4
];
const
index_t
Do
=
dout_n_c_wos_lengths
[
2
];
const
index_t
Ho
=
dout_n_c_wos_lengths
[
3
];
const
index_t
Wo
=
dout_n_c_wos_lengths
[
4
];
const
index_t
Z
=
window_lengths
[
0
];
const
index_t
Y
=
window_lengths
[
1
];
const
index_t
X
=
window_lengths
[
2
];
const
index_t
InLeftPadD
=
input_left_pads
[
0
];
const
index_t
InLeftPadH
=
input_left_pads
[
1
];
const
index_t
InLeftPadW
=
input_left_pads
[
2
];
const
index_t
InRightPadD
=
input_right_pads
[
0
];
const
index_t
InRightPadH
=
input_right_pads
[
1
];
const
index_t
InRightPadW
=
input_right_pads
[
2
];
const
index_t
ConvStrideD
=
window_strides
[
0
];
const
index_t
ConvStrideH
=
window_strides
[
1
];
const
index_t
ConvStrideW
=
window_strides
[
2
];
const
index_t
ConvDilationD
=
window_dilations
[
0
];
const
index_t
ConvDilationH
=
window_dilations
[
1
];
const
index_t
ConvDilationW
=
window_dilations
[
2
];
const
auto
out_n_do_ho_wo_c_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
,
C
),
make_tuple
(
dout_n_c_wos_strides
[
0
],
dout_n_c_wos_strides
[
2
],
dout_n_c_wos_strides
[
3
],
dout_n_c_wos_strides
[
4
],
dout_n_c_wos_strides
[
1
]));
const
auto
GcdStrideDilationD
=
math
::
gcd
(
ConvStrideD
,
ConvDilationD
);
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
ZTilde
=
ConvStrideD
/
GcdStrideDilationD
;
const
auto
YTilde
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTilde
=
ConvStrideW
/
GcdStrideDilationW
;
const
auto
ZDot
=
math
::
integer_divide_ceil
(
Z
,
ZTilde
);
const
auto
YDot
=
math
::
integer_divide_ceil
(
Y
,
YTilde
);
const
auto
XDot
=
math
::
integer_divide_ceil
(
X
,
XTilde
);
const
auto
DTilde
=
Do
+
math
::
integer_divide_ceil
(
ConvDilationD
*
(
Z
-
I1
),
ConvStrideD
);
const
auto
HTilde
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
I1
),
ConvStrideH
);
const
auto
WTilde
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
I1
),
ConvStrideW
);
// only work on Tildes that contribute to non-padding area of input tensor
const
auto
IDTildeSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadD
-
ConvDilationD
*
(
ZTilde
-
I1
)),
ConvStrideD
);
const
auto
IHTildeSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadH
-
ConvDilationH
*
(
YTilde
-
I1
)),
ConvStrideH
);
const
auto
IWTildeSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadW
-
ConvDilationW
*
(
XTilde
-
I1
)),
ConvStrideW
);
const
auto
IDTildeSliceEnd
=
math
::
min
(
DTilde
,
math
::
integer_divide_ceil
(
InLeftPadD
+
Di
-
I1
,
ConvStrideD
)
+
I1
);
const
auto
IHTildeSliceEnd
=
math
::
min
(
HTilde
,
math
::
integer_divide_ceil
(
InLeftPadH
+
Hi
-
I1
,
ConvStrideH
)
+
I1
);
const
auto
IWTildeSliceEnd
=
math
::
min
(
WTilde
,
math
::
integer_divide_ceil
(
InLeftPadW
+
Wi
-
I1
,
ConvStrideW
)
+
I1
);
const
auto
DTildeSlice
=
IDTildeSliceEnd
-
IDTildeSliceBegin
;
const
auto
HTildeSlice
=
IHTildeSliceEnd
-
IHTildeSliceBegin
;
const
auto
WTildeSlice
=
IWTildeSliceEnd
-
IWTildeSliceBegin
;
// ReduceK is different for each Reduce
const
auto
ZDotSlice
=
math
::
integer_divide_ceil
(
Z
-
i_ztilde
,
ZTilde
);
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytilde
,
YTilde
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
// Problem size of reduction kernel
const
index_t
MRaw
=
N
*
DTildeSlice
*
HTildeSlice
*
WTildeSlice
*
C
;
const
index_t
MPad
=
math
::
integer_least_multiple
(
MRaw
,
M_BlockTileSize
)
-
MRaw
;
const
index_t
KRaw
=
ZDotSlice
*
YDotSlice
*
XDotSlice
;
const
index_t
KPad
=
math
::
integer_least_multiple
(
KRaw
,
K_BlockTileSize
)
-
KRaw
;
// Out[ReduceM, ReduceK]
const
auto
out_n_dop_hop_wop_c_grid_desc
=
transform_tensor_descriptor
(
out_n_do_ho_wo_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Do
,
I0
,
I0
),
make_pad_transform
(
Ho
,
I0
,
I0
),
make_pad_transform
(
Wo
,
I0
,
I0
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
const
auto
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_c_grid_desc
=
transform_tensor_descriptor
(
out_n_dop_hop_wop_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
ZDot
,
DTilde
),
make_tuple
(
-
ConvDilationD
/
GcdStrideDilationD
,
I1
)),
make_embed_transform
(
make_tuple
(
YDot
,
HTilde
),
make_tuple
(
-
ConvDilationH
/
GcdStrideDilationH
,
I1
)),
make_embed_transform
(
make_tuple
(
XDot
,
WTilde
),
make_tuple
(
-
ConvDilationW
/
GcdStrideDilationW
,
I1
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
,
6
>
{},
Sequence
<
7
>
{}));
const
auto
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc
=
transform_tensor_descriptor
(
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_slice_transform
(
ZDot
,
I0
,
ZDotSlice
),
make_slice_transform
(
DTilde
,
IDTildeSliceBegin
,
DTildeSlice
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{},
Sequence
<
6
>
{},
Sequence
<
7
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{},
Sequence
<
6
>
{},
Sequence
<
7
>
{}));
const
auto
out_grid_desc_reducemraw_reducekraw
=
transform_tensor_descriptor
(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
DTildeSlice
,
HTildeSlice
,
WTildeSlice
,
C
)),
make_merge_transform
(
make_tuple
(
ZDotSlice
,
YDotSlice
,
XDotSlice
))),
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
,
7
>
{},
Sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_grid_desc_reducem_reducek
=
transform_tensor_descriptor
(
out_grid_desc_reducemraw_reducekraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// In[ReduceM]
const
auto
in_n_di_hi_wi_c_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
),
make_tuple
(
din_n_c_wos_strides
[
0
],
din_n_c_wos_strides
[
2
],
din_n_c_wos_strides
[
3
],
din_n_c_wos_strides
[
4
],
din_n_c_wos_strides
[
1
]));
const
auto
in_n_dip_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_di_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Di
,
InLeftPadD
,
InRightPadD
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
const
auto
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc
=
transform_tensor_descriptor
(
in_n_dip_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
XTilde
,
DTilde
),
make_tuple
(
ConvDilationD
,
ConvStrideD
)),
make_embed_transform
(
make_tuple
(
YTilde
,
HTilde
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
XTilde
,
WTilde
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
,
6
>
{},
Sequence
<
7
>
{}));
const
auto
in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc
=
transform_tensor_descriptor
(
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_freeze_transform
(
i_ztilde
),
make_slice_transform
(
DTilde
,
IDTildeSliceBegin
,
DTildeSlice
),
make_freeze_transform
(
i_ytilde
),
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_freeze_transform
(
i_xtilde
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{},
Sequence
<
6
>
{},
Sequence
<
7
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<>
{},
Sequence
<
1
>
{},
Sequence
<>
{},
Sequence
<
2
>
{},
Sequence
<>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
const
auto
in_grid_desc_reducemraw
=
transform_tensor_descriptor
(
in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
DTildeSlice
,
HTildeSlice
,
WTildeSlice
,
C
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
in_grid_desc_reducem
=
transform_tensor_descriptor
(
in_grid_desc_reducemraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
make_tuple
(
out_grid_desc_reducem_reducek
,
in_grid_desc_reducem
);
}
using
DoutDinGridDesc
=
decltype
(
Make3DGridDescriptor_Out_M_K_In_M
({
0
,
0
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
,
0
},
{
0
,
0
,
0
},
{
0
,
0
,
0
},
{
0
,
0
,
0
},
{
0
,
0
,
0
},
{
0
,
0
,
0
},
{
0
,
0
,
0
}));
using
DoutGridDesc_M_K
=
remove_cvref_t
<
tuple_element_t
<
0
,
DoutDinGridDesc
>>
;
using
DinGridDesc_M
=
remove_cvref_t
<
tuple_element_t
<
1
,
DoutDinGridDesc
>>
;
// FIXME
// for NDHWC, the dim C is the fastest dimension, and is not reduced.
// Hence, it is in M dimension for reduction kernel.
static
constexpr
index_t
OutSrcInDstVectorDim
=
0
;
// 0: M, 1: K
using
PassThrough
=
tensor_operation
::
element_wise
::
PassThrough
;
using
Div
=
tensor_operation
::
element_wise
::
UnaryDivide
;
using
gridwise_reduce
=
GridwiseReduction_mk_to_m_threadwise
<
DOutDataType
,
DInDataType
,
ComputeDataType
,
int
,
DoutGridDesc_M_K
,
DinGridDesc_M
,
reduce
::
Add
,
PassThrough
,
Div
,
InMemoryDataOperationEnum
::
Set
,
false
,
// propagate_nan
BlockSize
,
MThreadSliceSize
,
KThreadSliceSize
,
OutSrcInDstVectorDim
,
InSrcOutDstVectorSize
,
InSrcOutDstVectorSize
>
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
DOutDataType
*
p_dout
,
DInDataType
*
p_din
,
std
::
vector
<
ck
::
index_t
>
dout_n_c_wos_lengths
,
std
::
vector
<
ck
::
index_t
>
din_n_c_wos_length
,
std
::
vector
<
ck
::
index_t
>
dout_n_c_wos_strides
,
std
::
vector
<
ck
::
index_t
>
din_n_c_wos_strides
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
window_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
:
p_dout_grid_
{
p_dout
},
p_din_grid_
{
p_din
},
dout_n_c_wos_lengths_
{
dout_n_c_wos_lengths
},
din_n_c_wos_length_
{
din_n_c_wos_length
},
dout_n_c_wos_strides_
{
dout_n_c_wos_strides
},
din_n_c_wos_strides_
{
din_n_c_wos_strides
},
num_reduce_
{
1
},
div_element_op_
{
window_lengths
[
0
]
*
window_lengths
[
1
]
*
window_lengths
[
2
]}
{
std
::
vector
<
ck
::
index_t
>
Tildes
(
NDimSpatial
);
for
(
int
i
=
0
;
i
<
NDimSpatial
;
++
i
)
{
int
GcdStrideDilation
=
math
::
gcd
(
window_strides
[
i
],
window_dilations
[
i
]);
Tildes
[
i
]
=
window_strides
[
i
]
/
GcdStrideDilation
;
num_reduce_
*=
Tildes
[
i
];
}
for
(
index_t
i_ztilde
=
0
;
i_ztilde
<
Tildes
[
0
];
++
i_ztilde
)
{
for
(
index_t
i_ytilde
=
0
;
i_ytilde
<
Tildes
[
1
];
++
i_ytilde
)
{
for
(
index_t
i_xtilde
=
0
;
i_xtilde
<
Tildes
[
2
];
++
i_xtilde
)
{
// check slice is valid
const
auto
ZDotSlice
=
math
::
integer_divide_ceil
(
window_lengths
[
0
]
-
i_ztilde
,
Tildes
[
0
]);
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
window_lengths
[
1
]
-
i_ytilde
,
Tildes
[
1
]);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
window_lengths
[
2
]
-
i_xtilde
,
Tildes
[
2
]);
if
(
ZDotSlice
*
YDotSlice
*
XDotSlice
<=
0
)
{
continue
;
}
const
auto
dout_din_grid_desc
=
Make3DGridDescriptor_Out_M_K_In_M
(
dout_n_c_wos_lengths
,
din_n_c_wos_length
,
dout_n_c_wos_strides
,
din_n_c_wos_strides
,
window_lengths
,
window_strides
,
window_dilations
,
input_left_pads
,
input_right_pads
,
{
i_ztilde
,
i_ytilde
,
i_xtilde
});
dout_grid_desc_m_k_container_
.
push_back
(
dout_din_grid_desc
[
I0
]);
din_grid_desc_m_container_
.
push_back
(
dout_din_grid_desc
[
I1
]);
}
}
}
}
const
DOutDataType
*
p_dout_grid_
;
DInDataType
*
p_din_grid_
;
std
::
vector
<
ck
::
index_t
>
dout_n_c_wos_lengths_
;
std
::
vector
<
ck
::
index_t
>
din_n_c_wos_length_
;
std
::
vector
<
ck
::
index_t
>
dout_n_c_wos_strides_
;
std
::
vector
<
ck
::
index_t
>
din_n_c_wos_strides_
;
int
num_reduce_
;
std
::
vector
<
DoutGridDesc_M_K
>
dout_grid_desc_m_k_container_
;
std
::
vector
<
DinGridDesc_M
>
din_grid_desc_m_container_
;
Div
div_element_op_
;
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
float
ave_time
=
0
;
for
(
index_t
i
=
0
;
i
<
arg
.
num_reduce_
;
i
++
)
{
const
auto
kernel
=
kernel_reduce_threadwise
<
gridwise_reduce
,
false
,
false
,
false
,
// don't have index input
DOutDataType
,
DInDataType
,
ComputeDataType
,
int
,
DoutGridDesc_M_K
,
DinGridDesc_M
,
PassThrough
,
Div
>
;
ck
::
index_t
M
=
arg
.
dout_grid_desc_m_k_container_
[
i
].
GetLength
(
I0
);
const
index_t
grid_size
=
(
M
/
M_BlockTileSize
);
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
dout_grid_desc_m_k_container_
[
i
],
arg
.
din_grid_desc_m_container_
[
i
],
PassThrough
{},
arg
.
div_element_op_
,
float
(
1
),
arg
.
p_dout_grid_
,
nullptr
,
float
(
0
),
arg
.
p_din_grid_
,
nullptr
);
}
return
ave_time
;
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
constexpr
index_t
Rank
=
NDimSpatial
+
2
;
int
doutFastestDim
=
-
1
;
int
dinFastestDim
=
-
1
;
for
(
int
i
=
0
;
i
<
Rank
;
++
i
)
{
if
(
arg
.
dout_n_c_wos_strides_
[
i
]
==
1
)
doutFastestDim
=
i
;
if
(
arg
.
din_n_c_wos_strides_
[
i
]
==
1
)
dinFastestDim
=
i
;
}
if
(
doutFastestDim
==
-
1
||
dinFastestDim
==
-
1
)
{
if
constexpr
(
InSrcOutDstVectorSize
!=
1
)
return
false
;
}
else
{
if
(
arg
.
dout_n_c_wos_lengths_
[
doutFastestDim
]
%
InSrcOutDstVectorSize
!=
0
)
return
false
;
if
(
arg
.
din_n_c_wos_length_
[
dinFastestDim
]
%
InSrcOutDstVectorSize
!=
0
)
return
false
;
}
return
true
;
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_dout
,
void
*
p_din
,
std
::
vector
<
ck
::
index_t
>
dout_n_c_wos_lengths
,
std
::
vector
<
ck
::
index_t
>
din_n_c_wos_length
,
std
::
vector
<
ck
::
index_t
>
dout_n_c_wos_strides
,
std
::
vector
<
ck
::
index_t
>
din_n_c_wos_strides
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
window_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
override
{
constexpr
index_t
Rank
=
NDimSpatial
+
2
;
if
(
dout_n_c_wos_strides
.
size
()
!=
Rank
||
din_n_c_wos_strides
.
size
()
!=
Rank
||
dout_n_c_wos_lengths
.
size
()
!=
Rank
||
din_n_c_wos_length
.
size
()
!=
Rank
)
throw
std
::
runtime_error
(
"dimension is incorrect"
);
if
(
window_lengths
.
size
()
!=
NDimSpatial
||
window_strides
.
size
()
!=
NDimSpatial
||
window_dilations
.
size
()
!=
NDimSpatial
||
input_left_pads
.
size
()
!=
NDimSpatial
||
input_right_pads
.
size
()
!=
NDimSpatial
)
throw
std
::
runtime_error
(
"dimension is incorrect"
);
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
DOutDataType
*>
(
p_dout
),
static_cast
<
DInDataType
*>
(
p_din
),
dout_n_c_wos_lengths
,
din_n_c_wos_length
,
dout_n_c_wos_strides
,
din_n_c_wos_strides
,
window_lengths
,
window_strides
,
window_dilations
,
input_left_pads
,
input_right_pads
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceAvgPool3dBwd<"
<<
BlockSize
<<
","
;
str
<<
"M_C"
<<
MThreadClusterSize
<<
"_S"
<<
MThreadSliceSize
<<
","
;
str
<<
"K_C"
<<
KThreadClusterSize
<<
"_S"
<<
KThreadSliceSize
<<
","
;
str
<<
"InSrcOutDstVectorSize_"
<<
InSrcOutDstVectorSize
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
View file @
4100d1d8
...
...
@@ -588,14 +588,18 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
LoopSched
>
;
// desc for blockwise copy
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
using
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
using
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
using
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
using
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
// block-to-e-tile map
using
Block2ETileMap
=
...
...
@@ -840,9 +844,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx942"
))
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp
View file @
4100d1d8
...
...
@@ -378,13 +378,16 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopSched
>
;
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
using
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
decltype
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}));
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
using
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
decltype
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}));
using
Block2ETileMap
=
typename
GridwiseGemm
::
DefaultBlock2ETileMap
;
// Argument
...
...
@@ -571,6 +574,11 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_n_k_
,
ck
::
Tuple
<>
{},
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp
View file @
4100d1d8
...
...
@@ -589,9 +589,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx942"
))
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
View file @
4100d1d8
...
...
@@ -368,14 +368,18 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
LoopSched
>
;
// desc for blockwise copy
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
using
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
using
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
using
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
using
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
// block-to-e-tile map
using
Block2ETileMap
=
...
...
@@ -580,9 +584,7 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx942"
))
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
...
...
Prev
1
…
3
4
5
6
7
8
9
10
11
…
31
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment