gaoqiong / composable_kernel · Commits · 8c4897d1

Unverified commit 8c4897d1, authored Aug 30, 2023 by Rostyslav Geyyer, committed by GitHub on Aug 30, 2023.

    Merge branch 'develop' into lwpck-756

Parents: 9ba9ebec, 9e86ebd6
Changes: 542
Showing 20 changed files with 1839 additions and 313 deletions (+1839, −313):

include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp       +2   −2
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp     +213 −0
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp  +751 −256
include/ck/utility/amd_buffer_addressing.hpp                                           +15  −1
include/ck/utility/amd_gemm_dpp.hpp                                                    +22  −0
include/ck/utility/amd_xdlops.hpp                                                      +2   −2
include/ck/utility/get_shift.hpp                                                       +20  −0
include/ck/utility/inner_product.hpp                                                   +15  −7
include/ck/utility/inner_product_dpp8.hpp                                              +142 −0
include/ck/utility/magic_division.hpp                                                  +72  −0
include/ck/utility/math.hpp                                                            +16  −0
include/ck/utility/reduction_common.hpp                                                +0   −12
include/ck/utility/type_convert.hpp                                                    +0   −18
include/ck/utility/workgroup_barrier.hpp                                               +73  −0
include/ck/utility/workgroup_synchronization.hpp                                       +74  −0
include/ck/version.h.in                                                                +40  −0
library/include/ck/library/reference_tensor_operation/cpu/reference_avgpool_bwd.hpp    +354 −0
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp  +3   −3
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp           +2   −2
library/include/ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp       +23  −10
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp

```diff
@@ -104,13 +104,13 @@ struct ThreadwiseTensorSliceTransfer_v6r1
             // apply pointwise operation
             static_for<0, ScalarPerVector, 1>{}([&](auto i) {
-                SrcData v;
+                DstData v;
 
                 // apply element-wise operation
                 element_op_(v, src_vector_container.template AsType<SrcData>()[i]);
 
                 // apply type convert
-                dst_vector_container.template AsType<DstData>()(i) = type_convert<DstData>(v);
+                dst_vector_container.template AsType<DstData>()(i) = v;
             });
 
             const bool is_dst_valid =
```
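With this change the per-element temporary is typed `DstData`, so the element-wise operation performs the `SrcData`-to-`DstData` conversion itself and the explicit `type_convert` on the store side goes away. A minimal sketch of an element-wise functor compatible with that `op(y, x)` calling convention — the functor below is illustrative, not part of this commit:

```cpp
#include "ck/utility/type_convert.hpp"

// Hypothetical element-wise op: CK element ops are invoked as op(y, x), where the
// type of the output reference decides the destination type. Converting inside the
// op is what lets the transfer drop its own type_convert on the store path.
struct ConvertingPassThrough
{
    template <typename Y, typename X>
    __host__ __device__ void operator()(Y& y, const X& x) const
    {
        y = ck::type_convert<Y>(x);
    }
};
```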
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp (new file, mode 100644)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"

namespace ck {

// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
// and sometimes useless instructions:
//   1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as
//      argument instead
//   2. Don't construct a new tensor coordinate everytime when using it, update and reuse the
//      same tensor coordinate instead
//   3. Don't use a pointer to VGPR buffer, use vector instead

// Assume:
//   1. src_desc and dst_desc are not known at compile-time
//   2. SrcBuffer and DstBuffer are DynamicBuffer
//   3. src_slice_origin and dst_slice_origin are not known at compile-time
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename ElementwiseOperation,
          typename SliceLengths,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          bool SrcResetCoordinateAfterRun,
          bool DstResetCoordinateAfterRun>
struct ThreadwiseTensorSliceTransfer_v6r1r2
{
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
    using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));

    static constexpr auto I0 = Number<0>{};

    __device__ constexpr ThreadwiseTensorSliceTransfer_v6r1r2(
        const SrcDesc& src_desc,
        const Index& src_slice_origin,
        const DstDesc& dst_desc,
        const Index& dst_slice_origin,
        const ElementwiseOperation& element_op)
        : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
          dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
          element_op_(element_op)
    {
        static_assert(SliceLengths::At(Number<VectorDim>{}) % ScalarPerVector == 0,
                      "wrong! cannot evenly divide");
    }

    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
    {
        src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
    }

    __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
    {
        dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
    }

    template <typename SrcBuffer, typename DstBuffer, InMemoryDataOperationEnum DstInMemOp>
    __device__ void Run(const SrcDesc& src_desc,
                        const SrcBuffer& src_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf)
    {
        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});

        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(scalar_per_access)>>;

        // loop over space-filling curve
        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();

        static_for<0, num_access, 1>{}([&](auto idx_1d) {
            using src_vector_type = vector_type_maker_t<SrcData, ScalarPerVector>;
            using src_vector_t    = typename src_vector_type::type;

            using dst_vector_type = vector_type_maker_t<DstData, ScalarPerVector>;
            using dst_vector_t    = typename dst_vector_type::type;

            const bool is_src_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);

            // copy data from src_buf into src_vector_container
            auto src_vector_container = src_vector_type{
                src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};

            auto dst_vector_container = dst_vector_type{};

            // apply pointwise operation
            static_for<0, ScalarPerVector, 1>{}([&](auto i) {
                SrcData v;

                // apply element-wise operation
                element_op_(v, src_vector_container.template AsType<SrcData>()[i]);

                // apply type convert
                dst_vector_container.template AsType<DstData>()(i) = type_convert<DstData>(v);
            });

            const bool is_dst_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);

            // copy data from dst_vector into dst_buf
            dst_buf.template Update<DstInMemOp, dst_vector_t>(
                dst_coord_.GetOffset(),
                is_dst_valid,
                dst_vector_container.template AsType<dst_vector_t>()[I0]);

            // move coordinate
            if constexpr(idx_1d.value != num_access - 1)
            {
                constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);

                move_tensor_coordinate(
                    src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
                move_tensor_coordinate(
                    dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
            }
        });

        // move coordinate back to slice origin (or not)
        if constexpr(SrcResetCoordinateAfterRun)
        {
            const auto src_reset_step =
                make_tensor_coordinate_step(src_desc, GetCoordinateResetStep());

            move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
        }

        if constexpr(DstResetCoordinateAfterRun)
        {
            const auto dst_reset_step =
                make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep());

            move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
        }
    }

    __device__ static constexpr auto GetCoordinateResetStep()
    {
        constexpr auto scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});

        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(scalar_per_access)>>;

        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();

        if constexpr(num_access == 0)
        {
            return typename SpaceFillingCurve::Index{};
        }
        else
        {
            constexpr auto reset_step =
                SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});

            return reset_step;
        }
    }

    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
                                       const Index& src_slice_origin_step_idx)
    {
        // if src coord was not reset by RunRead(), then need to adjust the step here
        const auto adjusted_step_idx = SrcResetCoordinateAfterRun
                                           ? src_slice_origin_step_idx
                                           : src_slice_origin_step_idx + GetCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);

        move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }

    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
                                       const Index& dst_slice_origin_step_idx)
    {
        // if dst coord was not reset by Run(), then need to adjust the step here
        const auto adjusted_step_idx = DstResetCoordinateAfterRun
                                           ? dst_slice_origin_step_idx
                                           : dst_slice_origin_step_idx + GetCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);

        move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
    }

    private:
    SrcCoord src_coord_;
    DstCoord dst_coord_;
    const ElementwiseOperation element_op_;
};

} // namespace ck
```
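For orientation, a minimal sketch of how a transfer of this shape might be instantiated inside a kernel. The descriptor shapes, slice lengths, `PassThrough` element op, and the pre-existing `src_buf`/`dst_buf` buffers are all assumptions for illustration, not code from this commit:

```cpp
// Hypothetical device-side fragment (assumes src_buf / dst_buf are DynamicBuffers
// over global or LDS memory; descriptor shapes are illustrative).
using ElementOp = ck::tensor_operation::element_wise::PassThrough;

const auto src_desc = ck::make_naive_tensor_descriptor_packed(ck::make_tuple(1, 8));
const auto dst_desc = ck::make_naive_tensor_descriptor_packed(ck::make_tuple(1, 8));

auto transfer =
    ck::ThreadwiseTensorSliceTransfer_v6r1r2<float,              // SrcData
                                             ck::half_t,         // DstData
                                             decltype(src_desc),
                                             decltype(dst_desc),
                                             ElementOp,
                                             ck::Sequence<1, 8>, // SliceLengths
                                             ck::Sequence<0, 1>, // DimAccessOrder
                                             1,                  // VectorDim
                                             8,                  // ScalarPerVector
                                             true,               // SrcResetCoordinateAfterRun
                                             true>{              // DstResetCoordinateAfterRun
        src_desc, ck::make_multi_index(0, 0), dst_desc, ck::make_multi_index(0, 0), ElementOp{}};

transfer.template Run<decltype(src_buf), decltype(dst_buf), ck::InMemoryDataOperationEnum::Set>(
    src_desc, src_buf, dst_desc, dst_buf);
```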
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp

```diff
@@ -18,32 +18,53 @@ template <
     index_t NDimSpatial,
     typename ALayout,
     ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization>
-constexpr auto make_out_n_ho_wo_k_grid_desc(const index_t N,
+constexpr auto make_out_grid_desc(const index_t N,
+                                  const index_t Do,
                                   const index_t Ho,
                                   const index_t Wo,
                                   const index_t K,
                                   const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_strides)
 {
+    const auto KStride = Number<1>{};
+
     if constexpr(is_same_v<ALayout, tensor_layout::convolution::NHWGK>)
     {
         const index_t NStride  = out_g_n_k_wos_strides[1];
         const index_t HiStride = out_g_n_k_wos_strides[3];
         const index_t WiStride = out_g_n_k_wos_strides[4];
-        const auto CStride     = Number<1>{};
 
         if constexpr(ConvBwdDataSpecialization ==
                      ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
                          Filter1x1Stride1Pad0)
         {
             return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K),
-                                                make_tuple(WiStride, CStride));
+                                                make_tuple(WiStride, KStride));
         }
         else
         {
             return make_naive_tensor_descriptor(
                 make_tuple(N, Ho, Wo, K),
-                make_tuple(NStride, HiStride, WiStride, CStride));
+                make_tuple(NStride, HiStride, WiStride, KStride));
         }
     }
+    else if constexpr(is_same_v<ALayout, tensor_layout::convolution::NDHWGK>)
+    {
+        const index_t NStride  = out_g_n_k_wos_strides[1];
+        const index_t DoStride = out_g_n_k_wos_strides[3];
+        const index_t HoStride = out_g_n_k_wos_strides[4];
+        const index_t WoStride = out_g_n_k_wos_strides[5];
+
+        if constexpr(ConvBwdDataSpecialization ==
+                     ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
+                         Filter1x1Stride1Pad0)
+        {
+            return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K),
+                                                make_tuple(WoStride, KStride));
+        }
+        else
+        {
+            return make_naive_tensor_descriptor(
+                make_tuple(N, Do, Ho, Wo, K),
+                make_tuple(NStride, DoStride, HoStride, WoStride, KStride));
+        }
+    }
     else if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK>)
@@ -60,12 +81,80 @@ make_out_n_ho_wo_k_grid_desc(const index_t N,
             return make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K));
         }
     }
+    else if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNDHWK>)
+    {
+        // assume packed
+        if constexpr(ConvBwdDataSpecialization ==
+                     ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
+                         Filter1x1Stride1Pad0)
+        {
+            return make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K));
+        }
+        else
+        {
+            return make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K));
+        }
+    }
     else
     {
         throw std::runtime_error("wrong! unsupported layout: " + ALayout::name());
     }
 }
 
+template <typename BLayout>
+constexpr auto make_wei_grid_desc(
+    const index_t K, const index_t Z, const index_t Y, const index_t X, const index_t C)
+{
+    if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC>)
+    {
+        return make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C));
+    }
+    else if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKZYXC>)
+    {
+        return make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C));
+    }
+    else
+    {
+        throw std::runtime_error("wrong! unsupported layout: " + BLayout::name());
+    }
+}
+
+template <index_t NDimSpatial, typename CLayout>
+constexpr auto make_in_grid_desc(const index_t N,
+                                 const index_t Di,
+                                 const index_t Hi,
+                                 const index_t Wi,
+                                 const index_t C,
+                                 const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_strides)
+{
+    if constexpr(is_same_v<CLayout, tensor_layout::convolution::GNHWC> ||
+                 is_same_v<CLayout, tensor_layout::convolution::NHWGC> ||
+                 is_same_v<CLayout, tensor_layout::convolution::G_NHW_C>)
+    {
+        return make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C),
+                                            make_tuple(in_g_n_c_wis_strides[1],
+                                                       in_g_n_c_wis_strides[3],
+                                                       in_g_n_c_wis_strides[4],
+                                                       in_g_n_c_wis_strides[2]));
+    }
+    else if constexpr(is_same_v<CLayout, tensor_layout::convolution::GNDHWC> ||
+                      is_same_v<CLayout, tensor_layout::convolution::NDHWGC>)
+    {
+        return make_naive_tensor_descriptor(make_tuple(N, Di, Hi, Wi, C),
+                                            make_tuple(in_g_n_c_wis_strides[1],
+                                                       in_g_n_c_wis_strides[3],
+                                                       in_g_n_c_wis_strides[4],
+                                                       in_g_n_c_wis_strides[5],
+                                                       in_g_n_c_wis_strides[2]));
+    }
+    else
+    {
+        throw std::runtime_error("wrong! unsupported layout: " + CLayout::name());
+    }
+}
+
 } // namespace
 
 template <
```
```diff
@@ -82,10 +171,26 @@ struct TransformConvBwdDataToGemm_v1
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
 
+    static constexpr auto NonSpatialDimsNum = Number<3>{};
+
+    static constexpr auto DIdx = Number<NonSpatialDimsNum>{};
+    static constexpr auto HIdx =
+        NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
+    static constexpr auto WIdx =
+        NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{} : Number<NonSpatialDimsNum + 2>{};
+
+    static constexpr auto ZIdx = Number<NonSpatialDimsNum>{};
+    static constexpr auto YIdx =
+        NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
+    static constexpr auto XIdx =
+        NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{} : Number<NonSpatialDimsNum + 2>{};
+
     template <typename ALayout,
-              typename std::enable_if<NDimSpatial == 2 &&
-                                          (is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
-                                           is_same_v<ALayout, tensor_layout::convolution::NHWGK>),
+              typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
+                                          (is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
+                                           is_same_v<ALayout, tensor_layout::convolution::GNDHWK> ||
+                                           is_same_v<ALayout, tensor_layout::convolution::NHWGK> ||
+                                           is_same_v<ALayout, tensor_layout::convolution::NDHWGK>),
                                       bool>::type = false>
     static auto
     MakeADescriptor_AK0_M_AK1(const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
```
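The `DIdx`/`HIdx`/`WIdx` (and `ZIdx`/`YIdx`/`XIdx`) constants encode where the spatial lengths sit in the `[G, N, C, spatial...]` (resp. `[G, K, C, filter...]`) length arrays, and `Idx - NonSpatialDimsNum` recovers the position inside the `NDimSpatial`-sized stride/dilation/pad arrays. A tiny standalone restatement of that layout (a sketch, not library code):

```cpp
// Minimal sketch (assumption: the 3 non-spatial dims G/N/C precede the spatial ones).
constexpr int NonSpatialDimsNum = 3;

constexpr int HIdxFor(int nd) { return nd == 2 ? NonSpatialDimsNum : NonSpatialDimsNum + 1; }
constexpr int WIdxFor(int nd) { return nd == 2 ? NonSpatialDimsNum + 1 : NonSpatialDimsNum + 2; }

static_assert(HIdxFor(2) == 3 && WIdxFor(2) == 4, "2D lengths: [G, N, C, H, W]");
static_assert(HIdxFor(3) == 4 && WIdxFor(3) == 5, "3D lengths: [G, N, C, D, H, W]");
```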
```diff
@@ -100,44 +205,52 @@ struct TransformConvBwdDataToGemm_v1
         const std::array<index_t, NDimSpatial>& /* input_right_pads */,
         const std::array<index_t, NDimSpatial>& tildes)
     {
-        index_t i_ytilde = tildes[0];
-        index_t i_xtilde = tildes[1];
+        index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum];
+        index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum];
+        index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum];
 
         const index_t N = in_g_n_c_wis_lengths[1];
         const index_t K = wei_g_k_c_xs_lengths[1];
 
-        const index_t Hi = in_g_n_c_wis_lengths[3];
-        const index_t Wi = in_g_n_c_wis_lengths[4];
-
-        const index_t Ho = out_g_n_k_wos_lengths[3];
-        const index_t Wo = out_g_n_k_wos_lengths[4];
-
-        const index_t Y = wei_g_k_c_xs_lengths[3];
-        const index_t X = wei_g_k_c_xs_lengths[4];
-
-        const index_t InLeftPadH = input_left_pads[0];
-        const index_t InLeftPadW = input_left_pads[1];
-
-        const index_t ConvStrideH = conv_filter_strides[0];
-        const index_t ConvStrideW = conv_filter_strides[1];
-
-        const index_t ConvDilationH = conv_filter_dilations[0];
-        const index_t ConvDilationW = conv_filter_dilations[1];
-
-        const index_t AK0 = K / AK1;
+        const index_t Di = NDimSpatial == 3 ? in_g_n_c_wis_lengths[DIdx] : 1;
+        const index_t Hi = in_g_n_c_wis_lengths[HIdx];
+        const index_t Wi = in_g_n_c_wis_lengths[WIdx];
+
+        const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1;
+        const index_t Ho = out_g_n_k_wos_lengths[HIdx];
+        const index_t Wo = out_g_n_k_wos_lengths[WIdx];
+
+        const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1;
+        const index_t Y = wei_g_k_c_xs_lengths[YIdx];
+        const index_t X = wei_g_k_c_xs_lengths[XIdx];
+
+        const index_t InLeftPadD = input_left_pads[DIdx - NonSpatialDimsNum];
+        const index_t InLeftPadH = input_left_pads[HIdx - NonSpatialDimsNum];
+        const index_t InLeftPadW = input_left_pads[WIdx - NonSpatialDimsNum];
+
+        const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
+        const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
+        const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
+
+        const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
+        const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
+        const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
 
-        const auto out_n_ho_wo_k_grid_desc =
-            make_out_n_ho_wo_k_grid_desc<NDimSpatial, ALayout, ConvBwdDataSpecialization>(
-                N, Ho, Wo, K, out_g_n_k_wos_strides);
+        // n_do_ho_wo_k for 3d or n_ho_wo_k for 2d
+        const auto out_grid_desc =
+            make_out_grid_desc<NDimSpatial, ALayout, ConvBwdDataSpecialization>(
+                N, Do, Ho, Wo, K, out_g_n_k_wos_strides);
 
         if constexpr(ConvBwdDataSpecialization ==
                      ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
                          Filter1x1Stride1Pad0)
         {
+            const index_t AK0 = math::integer_divide_ceil(K, AK1);
             // A: output tensor
             const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
-                out_n_ho_wo_k_grid_desc,
-                make_tuple(make_pass_through_transform(N * Ho * Wo),
+                out_grid_desc,
+                make_tuple(make_pass_through_transform(N * Do * Ho * Wo),
                            make_unmerge_transform(make_tuple(AK0, AK1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
@@ -152,41 +265,57 @@ struct TransformConvBwdDataToGemm_v1
         }
         else
         {
+            const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
             const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
             const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
 
+            const auto ZTilde = ConvStrideD / GcdStrideDilationD;
             const auto YTilde = ConvStrideH / GcdStrideDilationH;
             const auto XTilde = ConvStrideW / GcdStrideDilationW;
 
+            const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
             const auto YDot = math::integer_divide_ceil(Y, YTilde);
             const auto XDot = math::integer_divide_ceil(X, XTilde);
 
+            const auto DTilde =
+                Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
             const auto HTilde =
                 Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
             const auto WTilde =
                 Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
 
             // only work on HTilde and WTilde that contribute to non-padding area of input tensor
+            const auto IDTildeSliceBegin = math::integer_divide_floor(
                math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
             const auto IHTildeSliceBegin = math::integer_divide_floor(
                 math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
             const auto IWTildeSliceBegin = math::integer_divide_floor(
                 math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
 
+            const auto IDTildeSliceEnd = math::min(
+                DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
             const auto IHTildeSliceEnd = math::min(
                 HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
             const auto IWTildeSliceEnd = math::min(
                 WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
 
+            const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
             const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
             const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
 
             // GemmK is different for each GEMM
+            const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
             const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
             const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
 
+            const index_t AK0 =
+                math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, AK1);
+
+            if constexpr(NDimSpatial == 2)
+            {
             // A: output tensor
             const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
-                out_n_ho_wo_k_grid_desc,
+                out_grid_desc,
                 make_tuple(make_pass_through_transform(N),
                            make_pad_transform(Ho, I0, I0),
                            make_pad_transform(Wo, I0, I0),
```
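The Tilde/Dot bookkeeping above is the usual backward-data decomposition: each of the `XTilde = ConvStrideW / gcd(stride, dilation)` residue classes along W gets its own GEMM over `XDot = ceil(X / XTilde)` filter taps, against a widened output window `WTilde`. A worked numeric sketch of those formulas (plain constexpr arithmetic, not library code):

```cpp
// Worked example of the Tilde/Dot bookkeeping (a sketch, not library code):
// stride 3 and dilation 2 along W, filter width X = 5, output width Wo = 10.
constexpr int ConvStrideW = 3, ConvDilationW = 2, X = 5, Wo = 10;
constexpr int GcdStrideDilationW = 1;                    // gcd(3, 2)
constexpr int XTilde = ConvStrideW / GcdStrideDilationW; // 3 residue classes -> 3 GEMMs
constexpr int XDot   = (X + XTilde - 1) / XTilde;        // ceil(5/3) = 2 taps per class
constexpr int WTilde =
    Wo + (ConvDilationW * (X - 1) + ConvStrideW - 1) / ConvStrideW; // 10 + ceil(8/3) = 13
static_assert(XTilde == 3 && XDot == 2 && WTilde == 13, "");
```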
```diff
@@ -206,7 +335,7 @@ struct TransformConvBwdDataToGemm_v1
                 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                 make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
 
-            const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc =
+            const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
                 transform_tensor_descriptor(
                     out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
                     make_tuple(make_pass_through_transform(N),
@@ -214,7 +343,7 @@ struct TransformConvBwdDataToGemm_v1
                                make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
                                make_slice_transform(XDot, I0, XDotSlice),
                                make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
-                               make_unmerge_transform(make_tuple(AK0, AK1))),
+                               make_pass_through_transform(K)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{},
@@ -226,29 +355,135 @@ struct TransformConvBwdDataToGemm_v1
                                Sequence<2>{},
                                Sequence<3>{},
                                Sequence<4>{},
-                               Sequence<5, 6>{}));
+                               Sequence<5>{}));
 
-            const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
-                out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc,
-                make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, AK0)),
-                           make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
-                           make_pass_through_transform(AK1)),
-                make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+            const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
+                out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
+                make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)),
+                           make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))),
+                make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
 
-            const auto out_gemmak0_gemmm_gemmak1_grid_desc =
+            const auto out_gemmk_gemmm_padded_grid_desc =
                 ck::tensor_operation::device::PadTensorDescriptor(
-                    out_gemmak0_gemmmraw_gemmak1_grid_desc,
-                    make_tuple(AK0, GemmMPerBlock, AK1),
-                    Sequence<false, DoPadGemmM, false>{});
+                    out_gemmk_gemmmraw_grid_desc,
+                    make_tuple(AK1, GemmMPerBlock),
+                    Sequence<true, DoPadGemmM>{});
+
+            const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
+                out_gemmk_gemmm_padded_grid_desc,
+                make_tuple(
+                    make_unmerge_transform(make_tuple(AK0, AK1)),
+                    make_pass_through_transform(out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
 
             return out_gemmak0_gemmm_gemmak1_grid_desc;
+            }
+            else if constexpr(NDimSpatial == 3)
+            {
+                // A: output tensor
+                const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
+                    out_grid_desc,
+                    make_tuple(make_pass_through_transform(N),
+                               make_pad_transform(Do, I0, I0),
+                               make_pad_transform(Ho, I0, I0),
+                               make_pad_transform(Wo, I0, I0),
+                               make_pass_through_transform(K)),
+                    make_tuple(
+                        Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
+                    make_tuple(
+                        Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
+
+                const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc =
+                    transform_tensor_descriptor(
+                        out_n_hop_wop_k_grid_desc,
+                        make_tuple(
+                            make_pass_through_transform(N),
+                            make_embed_transform(make_tuple(ZDot, DTilde),
+                                                 make_tuple(-ConvDilationD / GcdStrideDilationD, I1)),
+                            make_embed_transform(make_tuple(YDot, HTilde),
+                                                 make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
+                            make_embed_transform(make_tuple(XDot, WTilde),
+                                                 make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
+                            make_pass_through_transform(K)),
+                        make_tuple(
+                            Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
+                        make_tuple(Sequence<0>{},
+                                   Sequence<1, 2>{},
+                                   Sequence<3, 4>{},
+                                   Sequence<5, 6>{},
+                                   Sequence<7>{}));
+
+                const auto out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
+                    transform_tensor_descriptor(
+                        out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
+                        make_tuple(make_pass_through_transform(N),
+                                   make_slice_transform(ZDot, I0, ZDotSlice),
+                                   make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
+                                   make_slice_transform(YDot, I0, YDotSlice),
+                                   make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
+                                   make_slice_transform(XDot, I0, XDotSlice),
+                                   make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
+                                   make_pass_through_transform(K)),
+                        make_tuple(Sequence<0>{},
+                                   Sequence<1>{},
+                                   Sequence<2>{},
+                                   Sequence<3>{},
+                                   Sequence<4>{},
+                                   Sequence<5>{},
+                                   Sequence<6>{},
+                                   Sequence<7>{}),
+                        make_tuple(Sequence<0>{},
+                                   Sequence<1>{},
+                                   Sequence<2>{},
+                                   Sequence<3>{},
+                                   Sequence<4>{},
+                                   Sequence<5>{},
+                                   Sequence<6>{},
+                                   Sequence<7>{}));
+
+                const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
+                    out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
+                    make_tuple(
+                        make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)),
+                        make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice))),
+                    make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+                const auto out_gemmk_gemmm_padded_grid_desc =
+                    ck::tensor_operation::device::PadTensorDescriptor(
+                        out_gemmk_gemmmraw_grid_desc,
+                        make_tuple(AK1, GemmMPerBlock),
+                        Sequence<true, DoPadGemmM>{});
+
+                const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
+                    out_gemmk_gemmm_padded_grid_desc,
+                    make_tuple(
+                        make_unmerge_transform(make_tuple(AK0, AK1)),
+                        make_pass_through_transform(out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}),
+                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+                return out_gemmak0_gemmm_gemmak1_grid_desc;
+            }
+            else
+            {
+                throw std::runtime_error("wrong! only implemented for 2D and 3D now");
+            }
         }
     }
```
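The negative embed coefficients deserve a note: the `(YDot, HTilde)` embed encodes `ho = htilde - ydot * (ConvDilationH / GcdStrideDilationH)`, i.e. walking the dot index steps backwards through the output rows touched by one filter-tap residue class. A small numeric restatement of that identity (illustrative, not library code):

```cpp
// Numeric restatement of the embed coefficients (illustrative, not library code).
constexpr int ConvDilationH = 2, GcdStrideDilationH = 1;
constexpr int ydot = 1, htilde = 5;
// embed maps (ydot, htilde) -> ydot * (-ConvDilationH / GcdStrideDilationH) + htilde * 1
constexpr int ho = htilde + ydot * (-ConvDilationH / GcdStrideDilationH);
static_assert(ho == 3, "ydot = 1 steps two rows back from htilde = 5");
```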
```diff
     template <typename BLayout,
-              typename std::enable_if<NDimSpatial == 2 &&
-                                          is_same_v<BLayout, tensor_layout::convolution::GKYXC>,
+              typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
+                                          (is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
+                                           is_same_v<BLayout, tensor_layout::convolution::GKZYXC>),
                                       bool>::type = false>
     static auto
     MakeBDescriptor_BK0_N_BK1(const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
@@ -263,35 +498,40 @@ struct TransformConvBwdDataToGemm_v1
         const std::array<index_t, NDimSpatial>& /* input_right_pads */,
         const std::array<index_t, NDimSpatial>& tildes)
     {
-        index_t i_ytilde = tildes[0];
-        index_t i_xtilde = tildes[1];
+        index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum];
+        index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum];
+        index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum];
 
         const index_t N = in_g_n_c_wis_lengths[1];
         const index_t K = wei_g_k_c_xs_lengths[1];
         const index_t C = wei_g_k_c_xs_lengths[2];
 
-        const index_t Ho = out_g_n_k_wos_lengths[3];
-        const index_t Wo = out_g_n_k_wos_lengths[4];
-
-        const index_t Y = wei_g_k_c_xs_lengths[3];
-        const index_t X = wei_g_k_c_xs_lengths[4];
-
-        const index_t ConvStrideH = conv_filter_strides[0];
-        const index_t ConvStrideW = conv_filter_strides[1];
-
-        const index_t ConvDilationH = conv_filter_dilations[0];
-        const index_t ConvDilationW = conv_filter_dilations[1];
-
-        const index_t BK0 = K / BK1;
-
-        // assume packed
-        const auto wei_k_y_x_c_grid_desc =
-            make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C));
+        const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1;
+        const index_t Ho = out_g_n_k_wos_lengths[HIdx];
+        const index_t Wo = out_g_n_k_wos_lengths[WIdx];
+
+        const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1;
+        const index_t Y = wei_g_k_c_xs_lengths[YIdx];
+        const index_t X = wei_g_k_c_xs_lengths[XIdx];
+
+        const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
+        const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
+        const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
+
+        const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
+        const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
+        const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
+
+        // k_y_x_c for 2d or k_z_y_x_c for 3d
+        const auto wei_grid_desc = make_wei_grid_desc<BLayout>(K, Z, Y, X, C);
 
         if constexpr(ConvBwdDataSpecialization ==
                      ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
                          Filter1x1Stride1Pad0)
         {
+            const index_t BK0 = math::integer_divide_ceil(K, BK1);
             // B: weight tensor
             const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc =
                 transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
@@ -299,7 +539,7 @@ struct TransformConvBwdDataToGemm_v1
                            make_pass_through_transform(C)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
 
-            make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, C), make_tuple(I0, I1));
+            make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, C), make_tuple(I0, I1));
 
             const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
                 ck::tensor_operation::device::PadTensorDescriptor(
@@ -311,23 +551,33 @@ struct TransformConvBwdDataToGemm_v1
         }
         else
         {
+            const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
             const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
             const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
 
+            const auto ZTilde = ConvStrideD / GcdStrideDilationD;
             const auto YTilde = ConvStrideH / GcdStrideDilationH;
             const auto XTilde = ConvStrideW / GcdStrideDilationW;
 
+            const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
             const auto YDot = math::integer_divide_ceil(Y, YTilde);
             const auto XDot = math::integer_divide_ceil(X, XTilde);
 
             // GemmK is different for each GEMM
+            const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
             const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
             const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
 
+            const index_t BK0 =
+                math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, BK1);
+
             // B weight tensor
+            if constexpr(NDimSpatial == 2)
+            {
             const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
-                wei_k_y_x_c_grid_desc,
+                wei_grid_desc,
                 make_tuple(make_pass_through_transform(K),
                            make_embed_transform(make_tuple(YDot, YTilde),
                                                 make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
                            make_embed_transform(make_tuple(XDot, XTilde),
@@ -336,9 +586,9 @@ struct TransformConvBwdDataToGemm_v1
                 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                 make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
 
-            const auto wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor(
+            const auto wei_k_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor(
                 wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
-                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
+                make_tuple(make_pass_through_transform(K),
                            make_slice_transform(YDot, I0, YDotSlice),
                            make_slice_transform(XDot, I0, XDotSlice),
                            make_freeze_transform(i_ytilde),
@@ -350,36 +600,125 @@ struct TransformConvBwdDataToGemm_v1
                            Sequence<2>{},
                            Sequence<4>{},
                            Sequence<5>{}),
-                make_tuple(Sequence<0, 1>{},
+                make_tuple(Sequence<0>{},
                            Sequence<1>{},
                            Sequence<2>{},
                            Sequence<>{},
                            Sequence<>{},
                            Sequence<3>{}));
 
+            const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
+                wei_k_ydotslice_xdotslice_c_grid_desc,
+                make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)),
+                           make_pass_through_transform(C)),
+                make_tuple(Sequence<1, 2, 0>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+            const auto wei_gemmk_gemmn_padded_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    wei_gemmk_gemmnraw_grid_desc,
+                    make_tuple(BK1, GemmNPerBlock),
+                    Sequence<true, DoPadGemmN>{});
+
+            const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor(
+                wei_gemmk_gemmn_padded_grid_desc,
+                make_tuple(
+                    make_unmerge_transform(make_tuple(BK0, BK1)),
+                    make_pass_through_transform(wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+            return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
+            }
+            else if constexpr(NDimSpatial == 3)
+            {
+                const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc =
+                    transform_tensor_descriptor(
+                        wei_grid_desc,
+                        make_tuple(
+                            make_pass_through_transform(K),
+                            make_embed_transform(make_tuple(ZDot, ZTilde),
+                                                 make_tuple(ConvStrideD / GcdStrideDilationD, I1)),
+                            make_embed_transform(make_tuple(YDot, YTilde),
+                                                 make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
+                            make_embed_transform(make_tuple(XDot, XTilde),
+                                                 make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
+                            make_pass_through_transform(C)),
+                        make_tuple(
+                            Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
+                        make_tuple(Sequence<0>{},
+                                   Sequence<1, 2>{},
+                                   Sequence<3, 4>{},
+                                   Sequence<5, 6>{},
+                                   Sequence<7>{}));
+
+                const auto wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc =
+                    transform_tensor_descriptor(
+                        wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
+                        make_tuple(make_pass_through_transform(K),
+                                   make_slice_transform(ZDot, I0, ZDotSlice),
+                                   make_slice_transform(YDot, I0, YDotSlice),
+                                   make_slice_transform(XDot, I0, XDotSlice),
+                                   make_freeze_transform(i_ztilde),
+                                   make_freeze_transform(i_ytilde),
+                                   make_freeze_transform(i_xtilde),
+                                   make_pass_through_transform(C)),
+                        make_tuple(Sequence<0>{},
+                                   Sequence<1>{},
+                                   Sequence<3>{},
+                                   Sequence<5>{},
+                                   Sequence<2>{},
+                                   Sequence<4>{},
+                                   Sequence<6>{},
+                                   Sequence<7>{}),
+                        make_tuple(Sequence<0>{},
+                                   Sequence<1>{},
+                                   Sequence<2>{},
+                                   Sequence<3>{},
+                                   Sequence<>{},
+                                   Sequence<>{},
+                                   Sequence<>{},
+                                   Sequence<4>{}));
 
-            const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
-                wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc,
-                make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, BK0)),
-                           make_pass_through_transform(C),
-                           make_pass_through_transform(BK1)),
-                make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+                const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
+                    wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc,
+                    make_tuple(
+                        make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)),
+                        make_pass_through_transform(C)),
+                    make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}));
 
-            const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
-                ck::tensor_operation::device::PadTensorDescriptor(
-                    wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
-                    make_tuple(wei_gemmbk0_gemmnraw_gemmbk1_grid_desc.GetLength(I0),
-                               GemmNPerBlock,
-                               BK1),
-                    Sequence<false, DoPadGemmN, false>{});
+                const auto wei_gemmk_gemm_padded_grid_desc =
+                    ck::tensor_operation::device::PadTensorDescriptor(
+                        wei_gemmk_gemmnraw_grid_desc,
+                        make_tuple(BK1, GemmNPerBlock),
+                        Sequence<true, DoPadGemmN>{});
 
-            return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
+                const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor(
+                    wei_gemmk_gemm_padded_grid_desc,
+                    make_tuple(
+                        make_unmerge_transform(make_tuple(BK0, BK1)),
+                        make_pass_through_transform(wei_gemmk_gemm_padded_grid_desc.GetLength(I1))),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}),
+                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+                return wei_gemmbk0_gemm_gemmbk1_grid_desc;
+            }
+            else
+            {
+                throw std::runtime_error("wrong! only implemented for 2D and 3D now");
+            }
         }
     }
```
```diff
     template <typename CLayout,
-              typename std::enable_if<NDimSpatial == 2 &&
+              typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
                                           (is_same_v<CLayout, tensor_layout::convolution::GNHWC> ||
+                                           is_same_v<CLayout, tensor_layout::convolution::GNDHWC> ||
                                            is_same_v<CLayout, tensor_layout::convolution::NHWGC> ||
+                                           is_same_v<CLayout, tensor_layout::convolution::NDHWGC> ||
                                            is_same_v<CLayout, tensor_layout::convolution::G_NHW_C>),
                                       bool>::type = false>
     static auto
@@ -395,49 +734,57 @@ struct TransformConvBwdDataToGemm_v1
         const std::array<index_t, NDimSpatial>& input_right_pads,
         const std::array<index_t, NDimSpatial>& tildes)
     {
-        index_t i_ytilde = tildes[0];
-        index_t i_xtilde = tildes[1];
+        index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum];
+        index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum];
+        index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum];
 
         const index_t N = in_g_n_c_wis_lengths[1];
         const index_t C = wei_g_k_c_xs_lengths[2];
 
-        const index_t Hi = in_g_n_c_wis_lengths[3];
-        const index_t Wi = in_g_n_c_wis_lengths[4];
-
-        const index_t Ho = out_g_n_k_wos_lengths[3];
-        const index_t Wo = out_g_n_k_wos_lengths[4];
-
-        const index_t Y = wei_g_k_c_xs_lengths[3];
-        const index_t X = wei_g_k_c_xs_lengths[4];
-
-        const index_t InLeftPadH = input_left_pads[0];
-        const index_t InLeftPadW = input_left_pads[1];
-
-        const index_t InRightPadH = input_right_pads[0];
-        const index_t InRightPadW = input_right_pads[1];
-
-        const index_t ConvStrideH = conv_filter_strides[0];
-        const index_t ConvStrideW = conv_filter_strides[1];
-
-        const index_t ConvDilationH = conv_filter_dilations[0];
-        const index_t ConvDilationW = conv_filter_dilations[1];
+        const index_t Di = NDimSpatial == 3 ? in_g_n_c_wis_lengths[DIdx] : 1;
+        const index_t Hi = in_g_n_c_wis_lengths[HIdx];
+        const index_t Wi = in_g_n_c_wis_lengths[WIdx];
+
+        const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1;
+        const index_t Ho = out_g_n_k_wos_lengths[HIdx];
+        const index_t Wo = out_g_n_k_wos_lengths[WIdx];
+
+        const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1;
+        const index_t Y = wei_g_k_c_xs_lengths[YIdx];
+        const index_t X = wei_g_k_c_xs_lengths[XIdx];
+
+        const index_t InLeftPadD = input_left_pads[DIdx - NonSpatialDimsNum];
+        const index_t InLeftPadH = input_left_pads[HIdx - NonSpatialDimsNum];
+        const index_t InLeftPadW = input_left_pads[WIdx - NonSpatialDimsNum];
+
+        const index_t InRightPadD = input_right_pads[DIdx - NonSpatialDimsNum];
+        const index_t InRightPadH = input_right_pads[HIdx - NonSpatialDimsNum];
+        const index_t InRightPadW = input_right_pads[WIdx - NonSpatialDimsNum];
+
+        const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
+        const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
+        const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
+
+        const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
+        const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
+        const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
 
-        // assume strided
-        const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
-            make_tuple(N, Hi, Wi, C),
-            make_tuple(in_g_n_c_wis_strides[1],
-                       in_g_n_c_wis_strides[3],
-                       in_g_n_c_wis_strides[4],
-                       in_g_n_c_wis_strides[2]));
+        // n_hi_wi_c for 2d n_di_hi_wi_c for 3d
+        const auto in_grid_desc =
+            make_in_grid_desc<NDimSpatial, CLayout>(N, Di, Hi, Wi, C, in_g_n_c_wis_strides);
 
         if constexpr(ConvBwdDataSpecialization ==
                      ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
                          Filter1x1Stride1Pad0)
         {
             // C: input tensor
+            if constexpr(NDimSpatial == 2)
+            {
             const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
-                in_n_hi_wi_c_grid_desc,
+                in_grid_desc,
                 make_tuple(make_pass_through_transform(N),
                            make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
                            make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
                            make_pass_through_transform(C)),
@@ -453,7 +800,51 @@ struct TransformConvBwdDataToGemm_v1
                 make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
                 make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
 
-            const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
+            const auto in_gemmm_gemmn_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
                 in_gemmmraw_gemmnraw_grid_desc,
                 make_tuple(GemmMPerBlock, GemmNPerBlock),
                 Sequence<DoPadGemmM, DoPadGemmN>{});
 
             return in_gemmm_gemmn_grid_desc;
+            }
+            else if constexpr(NDimSpatial == 3)
+            {
+                // C: input tensor
+                const auto in_n_x_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
+                    in_grid_desc,
+                    make_tuple(
+                        make_pass_through_transform(N),
+                        make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)),
+                        make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
+                        make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
+                        make_pass_through_transform(C)),
+                    make_tuple(
+                        Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
+                    make_tuple(Sequence<0>{},
+                               Sequence<1, 2>{},
+                               Sequence<3, 4>{},
+                               Sequence<5, 6>{},
+                               Sequence<7>{}));
+
+                const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
+                    in_n_x_do_y_ho_x_wo_c_grid_desc,
+                    make_tuple(make_freeze_transform(I0),
+                               make_freeze_transform(I0),
+                               make_freeze_transform(I0),
+                               make_merge_transform(make_tuple(N, Do, Ho, Wo)),
+                               make_pass_through_transform(C)),
+                    make_tuple(Sequence<1>{},
+                               Sequence<3>{},
+                               Sequence<5>{},
+                               Sequence<0, 2, 4, 6>{},
+                               Sequence<7>{}),
+                    make_tuple(
+                        Sequence<>{}, Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
+
+                const auto in_gemmm_gemmn_grid_desc =
+                    ck::tensor_operation::device::PadTensorDescriptor(
+                        in_gemmmraw_gemmnraw_grid_desc,
+                        make_tuple(GemmMPerBlock, GemmNPerBlock),
+                        Sequence<DoPadGemmM, DoPadGemmN>{});
@@ -462,34 +853,51 @@ struct TransformConvBwdDataToGemm_v1
+            }
+            else
+            {
+                throw std::runtime_error("wrong! only implemented for 2D and 3D now");
+            }
         }
         else
         {
+            const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
             const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
             const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
 
+            const auto ZTilde = ConvStrideD / GcdStrideDilationD;
             const auto YTilde = ConvStrideH / GcdStrideDilationH;
             const auto XTilde = ConvStrideW / GcdStrideDilationW;
 
+            const auto DTilde =
+                Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
             const auto HTilde =
                 Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
             const auto WTilde =
                 Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
 
-            // only work on HTilde and WTilde that contribute to non-padding area of input tensor
+            // only work on DTilde, HTilde and WTilde that contribute to
+            // non-padding area of input tensor
+            const auto IDTildeSliceBegin = math::integer_divide_floor(
+                math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
             const auto IHTildeSliceBegin = math::integer_divide_floor(
                 math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
             const auto IWTildeSliceBegin = math::integer_divide_floor(
                 math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
 
+            const auto IDTildeSliceEnd = math::min(
+                DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
             const auto IHTildeSliceEnd = math::min(
                 HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
             const auto IWTildeSliceEnd = math::min(
                 WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
 
+            const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
             const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
             const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
 
             // C: input tensor
+            if constexpr(NDimSpatial == 2)
+            {
             const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
-                in_n_hi_wi_c_grid_desc,
+                in_grid_desc,
                 make_tuple(make_pass_through_transform(N),
                            make_pad_transform(Hi, InLeftPadH, InRightPadH),
                            make_pad_transform(Wi, InLeftPadW, InRightPadW),
@@ -497,7 +905,8 @@ struct TransformConvBwdDataToGemm_v1
                 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
 
-            const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
+            const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc =
+                transform_tensor_descriptor(
                 in_n_hip_wip_c_grid_desc,
                 make_tuple(make_pass_through_transform(N),
                            make_embed_transform(make_tuple(YTilde, HTilde),
@@ -506,7 +915,8 @@ struct TransformConvBwdDataToGemm_v1
                                                 make_tuple(ConvDilationW, ConvStrideW)),
                            make_pass_through_transform(C)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
+                make_tuple(
+                    Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
 
             const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
                 in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
@@ -536,13 +946,98 @@ struct TransformConvBwdDataToGemm_v1
                 make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
                 make_tuple(Sequence<0>{}, Sequence<1>{}));
 
-            const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
+            const auto in_gemmm_gemmn_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
                 in_gemmmraw_gemmnraw_grid_desc,
                 make_tuple(GemmMPerBlock, GemmNPerBlock),
                 Sequence<DoPadGemmM, DoPadGemmN>{});
 
             return in_gemmm_gemmn_grid_desc;
+            }
+            else if(NDimSpatial == 3)
+            {
+                const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
+                    in_grid_desc,
+                    make_tuple(make_pass_through_transform(N),
+                               make_pad_transform(Di, InLeftPadD, InRightPadD),
+                               make_pad_transform(Hi, InLeftPadH, InRightPadH),
+                               make_pad_transform(Wi, InLeftPadW, InRightPadW),
+                               make_pass_through_transform(C)),
+                    make_tuple(
+                        Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
+                    make_tuple(
+                        Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
+
+                const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
+                    transform_tensor_descriptor(
+                        in_n_dip_hip_wip_c_grid_desc,
+                        make_tuple(
+                            make_pass_through_transform(N),
+                            make_embed_transform(make_tuple(ZTilde, DTilde),
+                                                 make_tuple(ConvDilationD, ConvStrideD)),
+                            make_embed_transform(make_tuple(YTilde, HTilde),
+                                                 make_tuple(ConvDilationH, ConvStrideH)),
+                            make_embed_transform(make_tuple(XTilde, WTilde),
+                                                 make_tuple(ConvDilationW, ConvStrideW)),
+                            make_pass_through_transform(C)),
+                        make_tuple(
+                            Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
+                        make_tuple(Sequence<0>{},
+                                   Sequence<1, 2>{},
+                                   Sequence<3, 4>{},
+                                   Sequence<5, 6>{},
+                                   Sequence<7>{}));
+
+                const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
+                    transform_tensor_descriptor(
+                        in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
+                        make_tuple(make_pass_through_transform(N),
+                                   make_freeze_transform(i_ztilde),
+                                   make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
+                                   make_freeze_transform(i_ytilde),
+                                   make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
+                                   make_freeze_transform(i_xtilde),
+                                   make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
+                                   make_pass_through_transform(C)),
+                        make_tuple(Sequence<0>{},
+                                   Sequence<1>{},
+                                   Sequence<2>{},
+                                   Sequence<3>{},
+                                   Sequence<4>{},
+                                   Sequence<5>{},
+                                   Sequence<6>{},
+                                   Sequence<7>{}),
+                        make_tuple(Sequence<0>{},
+                                   Sequence<>{},
+                                   Sequence<1>{},
+                                   Sequence<>{},
+                                   Sequence<2>{},
+                                   Sequence<>{},
+                                   Sequence<3>{},
+                                   Sequence<4>{}));
+
+                const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
+                    in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
+                    make_tuple(
+                        make_merge_transform(
+                            make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
+                        make_pass_through_transform(C)),
+                    make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+                const auto in_gemmm_gemmn_grid_desc =
+                    ck::tensor_operation::device::PadTensorDescriptor(
+                        in_gemmmraw_gemmnraw_grid_desc,
+                        make_tuple(GemmMPerBlock, GemmNPerBlock),
+                        Sequence<DoPadGemmM, DoPadGemmN>{});
+
+                return in_gemmm_gemmn_grid_desc;
+            }
+            else
+            {
+                throw std::runtime_error("wrong! only implemented for 2D and 3D now");
+            }
         }
     }
 
     // for input bias
```
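These three builders are consumed per GEMM: the caller picks one `(i_ztilde, i_ytilde, i_xtilde)` residue class, builds matching A/B/C descriptors (their GemmK lengths agree by construction), and launches. A hypothetical driver loop — an assumption about usage, not code from this file:

```cpp
// Illustrative only: backward data runs one GEMM per (i_ztilde, i_ytilde, i_xtilde)
// residue class, skipping classes where no filter tap falls.
void run_bwd_data_gemms(int ZTilde, int YTilde, int XTilde, int Z, int Y, int X)
{
    for(int iz = 0; iz < ZTilde; ++iz)
        for(int iy = 0; iy < YTilde; ++iy)
            for(int ix = 0; ix < XTilde; ++ix)
            {
                const int ZDotSlice = (Z - iz + ZTilde - 1) / ZTilde; // integer_divide_ceil
                const int YDotSlice = (Y - iy + YTilde - 1) / YTilde;
                const int XDotSlice = (X - ix + XTilde - 1) / XTilde;
                if(ZDotSlice * YDotSlice * XDotSlice <= 0)
                    continue; // empty GEMM for this residue class
                // ... build A/B/C descriptors with tildes = {iz, iy, ix} and launch ...
            }
}
```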
include/ck/utility/amd_buffer_addressing.hpp

```diff
@@ -629,7 +629,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
 {
     static_assert((is_same<T, double>::value && (N == 1 || N == 2)) ||
-                      (is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
+                      (is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
                       (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
                       (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
                       (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
@@ -682,6 +682,20 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                                 dst_wave_addr_offset,
                                                 static_cast<index_t>(coherence));
         }
+        else if constexpr(N == 8)
+        {
+            vector_type<float, 8> tmp{src_thread_data};
+
+            llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType<float4_t>()[Number<0>{}],
+                                                dst_wave_buffer_resource,
+                                                dst_thread_addr_offset,
+                                                dst_wave_addr_offset,
+                                                static_cast<index_t>(coherence));
+            llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType<float4_t>()[Number<1>{}],
+                                                dst_wave_buffer_resource,
+                                                dst_thread_addr_offset,
+                                                dst_wave_addr_offset + 4 * sizeof(float),
+                                                static_cast<index_t>(coherence));
+        }
     }
     else if constexpr(is_same<T, half_t>::value)
     {
```
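The new `N == 8` branch covers `vector_type<float, 8>` stores by issuing two `fp32x4` buffer stores, the second at a `4 * sizeof(float)` byte offset. A host-side sketch of the same split, to make the indexing concrete (plain C++, not the GPU intrinsics):

```cpp
#include <array>
#include <cstring>

// Illustrative only: split an 8-float value into two float[4] halves, the way the
// device code reinterprets vector_type<float, 8> as two float4_t words.
void store_fp32x8(const std::array<float, 8>& src, float* dst)
{
    std::memcpy(dst, src.data(), 4 * sizeof(float));         // first fp32x4 store
    std::memcpy(dst + 4, src.data() + 4, 4 * sizeof(float)); // second, +16-byte offset
}
```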
include/ck/utility/amd_gemm_dpp.hpp (new file, mode 100644)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/amd_gemm_dpp.hpp"

namespace ck {
namespace dpp8 {

/// Number of lanes that can share data using DPP8 modifiers.
constexpr index_t lane_group_size = 8;

__device__ index_t get_lane_group_local_idx() { return threadIdx.x / lane_group_size; }

__device__ index_t get_thread_idx_in_lane_group() { return threadIdx.x % lane_group_size; }

} // namespace dpp8
} // namespace ck
```
include/ck/utility/amd_xdlops.hpp

```diff
@@ -364,7 +364,7 @@ struct intrin_mfma_f32_32x32x16f8f8<32, 32>
     template <class FloatC>
     __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx940__)
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
         reg_c.template AsType<float16_t>()(Number<0>{}) =
             __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(bit_cast<long>(reg_a),
@@ -396,7 +396,7 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
     template <class FloatC>
     __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx940__)
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
         reg_c.template AsType<float4_t>()(Number<0>{}) =
             __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(bit_cast<long>(reg_a),
                                                        bit_cast<long>(reg_b),
```
include/ck/utility/get_shift.hpp (new file, mode 100644)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck {

template <index_t N>
static constexpr __device__ index_t get_shift()
{
    return (get_shift<N / 2>() + 1);
};

template <>
constexpr __device__ index_t get_shift<1>()
{
    return (0);
}

} // namespace ck
```
include/ck/utility/inner_product.hpp
View file @
8c4897d1
...
...
@@ -13,13 +13,13 @@ __device__ void inner_product(const TA& a, const TB& b, TC& c);
template
<
>
__device__
void
inner_product
<
float
,
float
,
float
>
(
const
float
&
a
,
const
float
&
b
,
float
&
c
)
{
#if CK_USE_AMD_
INNER_PRODUCT
_INLINE_ASM && defined(CK_USE_AMD_V_MAC_F32)
#if CK_USE_AMD_
V_MAC
_INLINE_ASM && defined(CK_USE_AMD_V_MAC_F32)
asm
volatile
(
"
\n
\
v_mac_f32 %0, %1, %2
\n
\
"
:
"=v"
(
c
)
:
"v"
(
a
),
"v"
(
b
),
"0"
(
c
));
#elif CK_USE_AMD_
INNER_PRODUCT
_INLINE_ASM && defined(CK_USE_AMD_V_FMAC_F32)
#elif CK_USE_AMD_
V_MAC
_INLINE_ASM && defined(CK_USE_AMD_V_FMAC_F32)
asm
volatile
(
"
\n
\
v_fmac_f32 %0, %1, %2
\n
\
"
...
...
@@ -76,22 +76,26 @@ template <>
__device__ void
inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
{
#if defined(CK_USE_AMD_V_DOT2_F32_F16)
-#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
+#if CK_USE_AMD_V_DOT_INLINE_ASM
    // Use 3 x s_nop to avoid hazard (mi200 cdna2 isa page 47
    // https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf
    // ) s_nop with parameter 2 is equal to 3 x s_nop
    asm volatile("\n \
            v_dot2_f32_f16 %0, %1, %2, %0 \n \
            s_nop 2 \n \
            "
                 : "=v"(c)
                 : "v"(a), "v"(b), "0"(c));
#else
-    c = __builtin_amdgcn_sdot2(a, b, c, false);
+    c = __builtin_amdgcn_fdot2(a, b, c, false);
#endif
#else
    const vector_type<half_t, 2> a_vector{a};
    const vector_type<half_t, 2> b_vector{b};

    static_for<0, 2, 1>{}([&](auto i) {
-        c += type_convert<int32_t>(a_vector.AsType<half_t>()[i]) *
-             type_convert<int32_t>(b_vector.AsType<half_t>()[i]);
+        c += type_convert<float>(a_vector.AsType<half_t>()[i]) *
+             type_convert<float>(b_vector.AsType<half_t>()[i]);
    });
#endif
}
...
@@ -163,9 +167,13 @@ __device__ void
inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c)
{
#if defined(CK_USE_AMD_V_DOT4_I32_I8)
-#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
+#if CK_USE_AMD_V_DOT_INLINE_ASM
    // Use 3 x s_nop to avoid hazard (mi200 cdna2 isa page 47
    // https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf
    // ) s_nop with parameter 2 is equal to 3 x s_nop
    asm volatile("\n \
            v_dot4_i32_i8 %0, %1, %2, %0 \n \
            s_nop 2 \n \
            "
                 : "=v"(c)
                 : "v"(bit_cast<int32_t>(a)), "v"(bit_cast<int32_t>(b)), "0"(c));
...
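Both v_dot2_f32_f16 and v_dot4_i32_i8 fold a small dot product plus accumulate into one instruction; the fallback paths in this file spell out the same arithmetic. A scalar host-side reference for the int8 variant (a sketch of the semantics, not the GPU code path):

#include <cstdint>
#include <cstdio>

// Scalar reference for what v_dot4_i32_i8 computes on the GPU:
// c += a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3], with int8 inputs
// widened to int32 before multiplying.
int32_t dot4_i32_i8(const int8_t a[4], const int8_t b[4], int32_t c)
{
    for(int i = 0; i < 4; ++i)
        c += int32_t(a[i]) * int32_t(b[i]);
    return c;
}

int main()
{
    const int8_t a[4] = {1, -2, 3, -4};
    const int8_t b[4] = {5, 6, 7, 8};
    std::printf("%d\n", dot4_i32_i8(a, b, 100)); // 100 + 5 - 12 + 21 - 32 = 82
    return 0;
}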
include/ck/utility/inner_product_dpp8.hpp
0 → 100644
View file @
8c4897d1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "amd_gemm_dpp.hpp"
#include "data_type.hpp"
#include "type_convert.hpp"

namespace ck {
namespace dpp8 {

template <int SrcLaneIdx>
__device__ void inline_v_dot2c_dpp8_instr(const half2_t& a, const half2_t& b, float& c);

// clang-format off
template <> __device__ void inline_v_dot2c_dpp8_instr<0>(const half2_t& a, const half2_t& b, float& c){
    asm volatile("\nv_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[0, 0, 0, 0, 0, 0, 0, 0]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
template <> __device__ void inline_v_dot2c_dpp8_instr<1>(const half2_t& a, const half2_t& b, float& c){
    asm volatile("\nv_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[1, 1, 1, 1, 1, 1, 1, 1]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
template <> __device__ void inline_v_dot2c_dpp8_instr<2>(const half2_t& a, const half2_t& b, float& c){
    asm volatile("\nv_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[2, 2, 2, 2, 2, 2, 2, 2]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
template <> __device__ void inline_v_dot2c_dpp8_instr<3>(const half2_t& a, const half2_t& b, float& c){
    asm volatile("\nv_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[3, 3, 3, 3, 3, 3, 3, 3]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
template <> __device__ void inline_v_dot2c_dpp8_instr<4>(const half2_t& a, const half2_t& b, float& c){
    asm volatile("\nv_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[4, 4, 4, 4, 4, 4, 4, 4]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
template <> __device__ void inline_v_dot2c_dpp8_instr<5>(const half2_t& a, const half2_t& b, float& c){
    asm volatile("\nv_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[5, 5, 5, 5, 5, 5, 5, 5]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
template <> __device__ void inline_v_dot2c_dpp8_instr<6>(const half2_t& a, const half2_t& b, float& c){
    asm volatile("\nv_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[6, 6, 6, 6, 6, 6, 6, 6]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
template <> __device__ void inline_v_dot2c_dpp8_instr<7>(const half2_t& a, const half2_t& b, float& c){
    asm volatile("\nv_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[7, 7, 7, 7, 7, 7, 7, 7]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
// clang-format on
/**
 * Dot product of two vectors using `v_dot` instruction with DPP8 submitted as inline assembly.
 */
template <int SrcLaneIdx, bool ShareA>
__device__ void inline_v_dot2c_dpp8(const half2_t& a, const half2_t& b, float& c)
{
    static_assert(SrcLaneIdx >= 0 && SrcLaneIdx < dpp8::lane_group_size,
                  "DPP8 src broadcast lane out of range <0, 7>.");

    if constexpr(ShareA)
    {
        inline_v_dot2c_dpp8_instr<SrcLaneIdx>(a, b, c);
    }
    else
    {
        inline_v_dot2c_dpp8_instr<SrcLaneIdx>(b, a, c);
    }
}
/**
 * DPP8 intrinsics expect an integer mask; the integers below hardcode the specific
 * broadcast patterns.
 */
constexpr std::array<int, dpp8::lane_group_size> IntrinsicMaskDpp8 = {
    0,        // 0, 0, 0, 0, 0, 0, 0, 0
    2396745,  // 1, 1, 1, 1, 1, 1, 1, 1
    4793490,  // 2, 2, 2, 2, 2, 2, 2, 2
    7190235,  // 3, 3, 3, 3, 3, 3, 3, 3
    9586980,  // 4, 4, 4, 4, 4, 4, 4, 4
    11983725, // 5, 5, 5, 5, 5, 5, 5, 5
    14380470, // 6, 6, 6, 6, 6, 6, 6, 6
    16777215, // 7, 7, 7, 7, 7, 7, 7, 7
};

/**
 * Returns the DPP8 sel modifier as the integer required by the intrinsic instruction.
 */
template <int SrcLaneIdx>
constexpr int get_dpp_sel_mask_broadcast()
{
    static_assert(SrcLaneIdx >= 0 && SrcLaneIdx < dpp8::lane_group_size,
                  "DPP8 src broadcast lane out of range <0, 7>.");
    return IntrinsicMaskDpp8[SrcLaneIdx];
}
template <int SrcLaneIdx>
__device__ void intrinsic_fdot2_impl(const half2_t& a, const half2_t& b, float& c)
{
    constexpr int sel_mask = get_dpp_sel_mask_broadcast<SrcLaneIdx>();
    const half2_t val_from_other_lane =
        bit_cast<half2_t>(__builtin_amdgcn_mov_dpp8(bit_cast<int>(a), sel_mask));
    c = __builtin_amdgcn_fdot2(val_from_other_lane, b, c, false);
}

/**
 * Dot product of two vectors using `v_dot` instruction with DPP8 submitted using intrinsics.
 */
template <int SrcLaneIdx, bool ShareA>
__device__ void intrinsic_fdot2(const half2_t& a, const half2_t& b, float& c)
{
    if constexpr(ShareA)
    {
        intrinsic_fdot2_impl<SrcLaneIdx>(a, b, c);
    }
    else
    {
        intrinsic_fdot2_impl<SrcLaneIdx>(b, a, c);
    }
}
/**
 * Dot product of two input vectors `a`, `b` using `v_dot` instructions with a DPP modifier.
 *
 * The DPP modifier lets one of the vectors be shared between lanes in a lane group.
 * When `ShareA` is set, the instruction uses vector `a` from lane `SrcLaneIdx` of the same
 * lane group (8 lanes per lane group in DPP8). When `ShareA` is not set, vector `b` is shared.
 * Note that all the threads in a lane group use the same vector - a broadcast pattern.
 *
 * `SrcLaneIdx` must be in range from 0 to 7.
 */
template <typename TA, typename TB, typename TC, int SrcLaneIdx, bool ShareA>
__device__ void inner_product_dpp(const TA& a, const TB& b, TC& c)
{
#if CK_USE_AMD_V_DOT_DPP8_INLINE_ASM
    inline_v_dot2c_dpp8<SrcLaneIdx, ShareA>(a, b, c);
#else
    intrinsic_fdot2<SrcLaneIdx, ShareA>(a, b, c);
#endif
}

} // namespace dpp8
} // namespace ck
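The hardcoded sel masks pack eight 3-bit source-lane fields, lane 0 in the lowest bits, so broadcasting lane k yields k * 0x249249. A host-side check of the table above (standalone sketch):

#include <cassert>
#include <cstdio>

// Build the DPP8 sel mask: eight 3-bit fields, field i holding the source
// lane for lane i. Broadcasting lane k puts k in every field.
constexpr int dpp8_broadcast_mask(int k)
{
    int mask = 0;
    for(int lane = 0; lane < 8; ++lane)
        mask |= k << (3 * lane);
    return mask;
}

int main()
{
    assert(dpp8_broadcast_mask(1) == 2396745);  // 0x249249
    assert(dpp8_broadcast_mask(7) == 16777215); // 0xFFFFFF
    std::printf("mask(5) = %d\n", dpp8_broadcast_mask(5)); // prints 11983725
    return 0;
}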
include/ck/utility/magic_division.hpp
View file @
8c4897d1
...
...
@@ -157,4 +157,76 @@ struct MagicDivision
    }
};

struct MDiv
{
    // 1 dword -> 3 dword storage
    uint32_t divisor;
    uint32_t multiplier;
    uint32_t shift; // TODO: 8 bit is enough

    // prefer to construct on host
    __host__ __device__ MDiv(uint32_t divisor_) : divisor(divisor_)
    {
        auto tmp = MagicDivision::CalculateMagicNumbers(divisor_);

        multiplier = tmp[Number<0>{}];
        shift      = tmp[Number<1>{}];
    }

    __host__ __device__ MDiv() : divisor(0), multiplier(0), shift(0) {}

    __host__ __device__ void update(uint32_t divisor_)
    {
        divisor = divisor_;

        auto tmp = MagicDivision::CalculateMagicNumbers(divisor_);

        multiplier = tmp[Number<0>{}];
        shift      = tmp[Number<1>{}];
    }

    __host__ __device__ uint32_t div(uint32_t dividend_) const
    {
        return MagicDivision::DoMagicDivision(dividend_, multiplier, shift);
    }

    __host__ __device__ void
    divmod(uint32_t dividend_, uint32_t& quotient_, uint32_t& remainder_) const
    {
        quotient_  = div(dividend_);
        remainder_ = dividend_ - (quotient_ * divisor);
    }

    __host__ __device__ uint32_t get() const { return divisor; }
};

struct MDiv2
{
    // 1 dword -> 2 dword storage; the divisor is not stored and must be passed in at runtime
    uint32_t multiplier;
    uint32_t shift; // TODO: 8 bit is enough

    // prefer to construct on host
    __host__ __device__ MDiv2(uint32_t divisor_)
    {
        auto tmp = MagicDivision::CalculateMagicNumbers(divisor_);

        multiplier = tmp[Number<0>{}];
        shift      = tmp[Number<1>{}];
    }

    __host__ __device__ MDiv2() : multiplier(0), shift(0) {}

    __host__ __device__ uint32_t div(uint32_t dividend_) const
    {
        return MagicDivision::DoMagicDivision(dividend_, multiplier, shift);
    }

    __host__ __device__ void divmod(uint32_t dividend_,
                                    uint32_t divisor_,
                                    uint32_t& quotient_,
                                    uint32_t& remainder_) const
    {
        quotient_  = div(dividend_);
        remainder_ = dividend_ - (quotient_ * divisor_);
    }
};

} // namespace ck
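MDiv trades a runtime division for a multiply-high plus shift, using the (multiplier, shift) pair cached from MagicDivision. A self-contained sketch of one classic formulation of the trick (Granlund-Montgomery style; uses the GCC/Clang __uint128_t extension and does not necessarily match CK's exact constants):

#include <cassert>
#include <cstdint>
#include <cstdio>

// With s = ceil(log2(d)) and m = ceil(2^(32+s) / d), every n < 2^32
// satisfies n / d == (n * m) >> (32 + s). This mirrors what MDiv caches.
struct Magic
{
    uint64_t multiplier; // may need 33 bits, so stored as 64-bit here
    uint32_t shift;

    explicit Magic(uint32_t d)
    {
        shift = 0;
        while((1ull << shift) < d)
            ++shift;
        multiplier = uint64_t((((__uint128_t)1 << (32 + shift)) + d - 1) / d);
    }

    uint32_t div(uint32_t n) const
    {
        return uint32_t(((__uint128_t)n * multiplier) >> (32 + shift));
    }
};

int main()
{
    for(uint32_t d : {1u, 3u, 7u, 10u, 255u, 1000u})
    {
        Magic m(d);
        for(uint32_t n = 0; n < 100000; ++n)
            assert(m.div(n) == n / d); // multiply+shift matches plain division
    }
    std::printf("magic division matches plain division\n");
    return 0;
}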
include/ck/utility/math.hpp
View file @
8c4897d1
...
...
@@ -240,5 +240,21 @@ struct less
    __host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; }
};

template <index_t X>
__host__ __device__ constexpr auto next_power_of_two()
{
    // TODO: X must be in [2, 0x7fffffff]; 0, 1, or anything larger will fail to compile
    constexpr index_t Y = 1 << (32 - __builtin_clz(X - 1));
    return Y;
}

template <index_t X>
__host__ __device__ constexpr auto next_power_of_two(Number<X> x)
{
    // TODO: X must be in [2, 0x7fffffff]; 0, 1, or anything larger will fail to compile
    constexpr index_t Y = 1 << (32 - __builtin_clz(x.value - 1));
    return Number<Y>{};
}

} // namespace math
} // namespace ck
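The rounding trick counts the leading zeros of X - 1: for X = 33, X - 1 = 32 has 26 leading zeros in 32 bits, so the result is 1 << 6 = 64. A standalone host-side sketch without the Number<> wrapper:

#include <cstdio>

// Round X up to the next power of two via count-leading-zeros on X - 1.
// Valid for X in [2, 2^31 - 1]; X == 1 would hit __builtin_clz(0), which is UB.
constexpr int next_power_of_two(int X) { return 1 << (32 - __builtin_clz(X - 1)); }

int main()
{
    static_assert(next_power_of_two(2) == 2, "");
    static_assert(next_power_of_two(33) == 64, "");
    static_assert(next_power_of_two(64) == 64, "a power of two stays put");
    std::printf("%d\n", next_power_of_two(1000)); // prints 1024
    return 0;
}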
include/ck/utility/reduction_common.hpp
View file @
8c4897d1
...
...
@@ -25,16 +25,4 @@ struct float_equal_zero
};
};

-template <index_t N>
-static constexpr __device__ index_t get_shift()
-{
-    return (get_shift<N / 2>() + 1);
-};
-
-template <>
-constexpr __device__ index_t get_shift<1>()
-{
-    return (0);
-}
-
} // namespace ck
include/ck/utility/type_convert.hpp
View file @
8c4897d1
...
...
@@ -62,24 +62,6 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, half_t>(half_
    return type_convert<bhalf_t>(x_fp32);
}

-// convert bfp16 to int32 via fp32
-template <>
-inline __host__ __device__ constexpr int32_t type_convert<int32_t, bhalf_t>(bhalf_t x)
-{
-    float x_fp32 = type_convert<float>(x);
-
-    return static_cast<int32_t>(x_fp32);
-}
-
-// convert int32 to bfp16 via fp32
-template <>
-inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int32_t>(int32_t x)
-{
-    float x_fp32 = static_cast<float>(x);
-
-    return type_convert<bhalf_t>(x_fp32);
-}
-
// convert bfp16 to int8 via fp32
template <>
inline __host__ __device__ constexpr int8_t type_convert<int8_t, bhalf_t>(bhalf_t x)
...
include/ck/utility/workgroup_barrier.hpp
0 → 100644
View file @
8c4897d1
#pragma once

#include <hip/hip_runtime.h>
#include <stdint.h>

namespace ck {

struct workgroup_barrier
{
    __device__ workgroup_barrier(uint32_t* ptr) : base_ptr(ptr) {}

    __device__ uint32_t ld(uint32_t offset)
    {
#if 0
        float d = llvm_amdgcn_raw_buffer_load_fp32(
            amdgcn_make_buffer_resource(base_ptr),
            0,
            offset,
            AMDGCN_BUFFER_GLC);
        union cvt {
            float f32;
            uint32_t u32;
        };
        cvt x;
        x.f32 = d;
        return x.u32;
#endif
        return __atomic_load_n(base_ptr + offset, __ATOMIC_RELAXED);
    }

    __device__ void wait_eq(uint32_t offset, uint32_t value)
    {
        if(threadIdx.x == 0)
        {
            while(ld(offset) != value) {}
        }
        __syncthreads();
    }

    __device__ void wait_lt(uint32_t offset, uint32_t value)
    {
        if(threadIdx.x == 0)
        {
            while(ld(offset) < value) {}
        }
        __syncthreads();
    }

    __device__ void wait_set(uint32_t offset, uint32_t compare, uint32_t value)
    {
        if(threadIdx.x == 0)
        {
            while(atomicCAS(base_ptr + offset, compare, value) != compare) {}
        }
        __syncthreads();
    }

    // enter critical zone; assumes the buffer is zero when the kernel is launched
    __device__ void aquire(uint32_t offset) { wait_set(offset, 0, 1); }

    // exit critical zone; assumes the buffer is zero when the kernel is launched
    __device__ void release(uint32_t offset) { wait_set(offset, 1, 0); }

    __device__ void inc(uint32_t offset)
    {
        __syncthreads();
        if(threadIdx.x == 0)
        {
            atomicAdd(base_ptr + offset, 1);
        }
    }

    uint32_t* base_ptr;
};

} // namespace ck
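A typical use of workgroup_barrier is a multi-block reduction in which one block must wait for the partial results of all others. A minimal HIP sketch of that pattern (hypothetical kernel; assumes the counter buffer is zeroed before launch and glosses over memory-ordering fences):

#include <hip/hip_runtime.h>

#include "ck/utility/workgroup_barrier.hpp"

// Each block adds its partial result, then block 0 waits until every
// block has bumped the counter before reading the final sum.
__global__ void reduce_across_blocks(const float* partials, float* out, uint32_t* counter)
{
    ck::workgroup_barrier barrier(counter);

    if(threadIdx.x == 0)
        atomicAdd(out, partials[blockIdx.x]);

    barrier.inc(0); // __syncthreads(), then thread 0 bumps counter[0]

    if(blockIdx.x == 0)
    {
        barrier.wait_eq(0, gridDim.x); // spin until every block has arrived
        // *out now holds the complete sum
    }
}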
include/ck/utility/workgroup_synchronization.hpp
0 → 100644
View file @
8c4897d1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/host_utility/hip_check_error.hpp"

namespace ck {

// Initialization flag of the Barrier object; can be any value except zero
static constexpr int BarrierInitFlag = 0x7856;

// 1) Only the first thread-block in the synchronization group is supposed to call this function.
//    It is the responsibility of the user to ensure the two integer values in p_control_bits are
//    zeros before calling gms_init().
// 2) After calling gms_reset(), the two integer values in p_control_bits will be zeros, so no
//    repeated initialization of the p_control_bits buffer is required
static __device__ void gms_init(int NumWarps, int* p_control_bits)
{
    union
    {
        int two32[2];
        unsigned long one64;
    } regs;

    regs.two32[0] = BarrierInitFlag;
    regs.two32[1] = NumWarps;

    if(threadIdx.x == 0)
        atomicCAS(reinterpret_cast<unsigned long*>(p_control_bits), 0, regs.one64);
};

// all the workgroups in the synchronization group are supposed to call this function
static __device__ void gms_barrier(int* p_control_bits)
{
    constexpr int mask = warpSize - 1;

    if((threadIdx.x & mask) == 0)
    {
        // ensure the barrier object is initialized
        do
        {
            const int r0 = __atomic_load_n(&p_control_bits[0], __ATOMIC_RELAXED);

            if(r0 == BarrierInitFlag)
                break;
        } while(true);

        // go ahead toward the barrier line
        atomicSub(&p_control_bits[1], 1);

        // wait until all warps have arrived
        do
        {
            const int r1 = __atomic_load_n(&p_control_bits[1], __ATOMIC_RELAXED);

            if(r1 == 0)
                break;
        } while(true);
    };
};

// 1) Only the first thread-block in the synchronization group is supposed to call this function.
// 2) After calling gms_reset(), the two integer values in p_control_bits will be zeros, so no
//    repeated initialization of the p_control_bits buffer is required
static __device__ void gms_reset(int* p_control_bits)
{
    // reset the barrier object
    if(threadIdx.x == 0)
        (void)atomicCAS(&p_control_bits[0], BarrierInitFlag, 0);
};

} // namespace ck
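Taken together, the three gms_* calls form a grid-wide barrier on a per-warp countdown: one block arms it with the total warp count, every warp counts down and spins until zero, and one block disarms it for reuse. A hedged sketch of the calling sequence (hypothetical kernel; p_control points at two zero-initialized ints):

#include <hip/hip_runtime.h>

#include "ck/utility/workgroup_synchronization.hpp"

// Hypothetical kernel showing the intended gms_* calling sequence.
__global__ void kernel_with_grid_sync(int* p_control)
{
    const int warps_per_block = blockDim.x / warpSize;
    const int total_warps     = warps_per_block * gridDim.x;

    // phase 1: every block computes its partial work here ...

    if(blockIdx.x == 0)
        ck::gms_init(total_warps, p_control); // arm the barrier with the warp count

    ck::gms_barrier(p_control); // every warp decrements and spins until zero

    if(blockIdx.x == 0)
        ck::gms_reset(p_control); // disarm, leaving the buffer zeroed for reuse

    // phase 2: safe to consume the phase-1 results of other blocks ...
}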
include/ck/version.h.in
0 → 100644
View file @
8c4897d1
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2023 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
/* the configured version and settings for MIOpen Composable Kernel */
#ifndef CK_VERSION_H_
#define CK_VERSION_H_
// clang-format off
#define CK_VERSION @CMAKE_PROJECT_VERSION@
#define CK_VERSION_MAJOR @CMAKE_PROJECT_VERSION_MAJOR@
#define CK_VERSION_MINOR @CMAKE_PROJECT_VERSION_MINOR@
#define CK_VERSION_PATCH @CMAKE_PROJECT_VERSION_PATCH@
#define CK_COMMIT_ID @COMMIT_ID@
// clang-format on
#endif
library/include/ck/library/reference_tensor_operation/cpu/reference_avgpool_bwd.hpp
0 → 100644
View file @
8c4897d1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"

namespace ck {
namespace tensor_operation {
namespace host {

// dinput descriptor in [N, C, Di, Hi, Wi] order
// doutput descriptor in [N, C, Do, Ho, Wo] order
// physical layout is irrelevant
template <ck::index_t NDimSpatial,
          typename DInDataType,
          typename DOutDataType,
          typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceAvgPoolBwd : public device::BaseOperator
{
    // Argument
    struct Argument : public device::BaseArgument
    {
        Argument(Tensor<DInDataType>& dinput,
                 const Tensor<DOutDataType>& doutput,
                 std::vector<ck::index_t> window_spatial_lengths,
                 std::vector<ck::index_t> window_strides,
                 std::vector<ck::index_t> window_dilations,
                 std::vector<ck::index_t> dinput_left_pads,
                 std::vector<ck::index_t> dinput_right_pads)
            : dinput_{dinput},
              doutput_{doutput},
              window_spatial_lengths_{window_spatial_lengths},
              window_strides_{window_strides},
              window_dilations_{window_dilations},
              in_left_pads_{dinput_left_pads},
              in_right_pads_{dinput_right_pads}
        {
        }

        Tensor<DInDataType>& dinput_;
        const Tensor<DOutDataType>& doutput_;

        std::vector<ck::index_t> window_spatial_lengths_;
        std::vector<index_t> window_strides_;
        std::vector<index_t> window_dilations_;
        std::vector<index_t> in_left_pads_;
        std::vector<index_t> in_right_pads_;
    };
    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceAvgPoolBwd::Argument;

        template <ck::index_t NDimSpatial_,
                  typename std::enable_if<NDimSpatial_ == 1, bool>::type = false>
        float RunAvgPoolBwd(const Argument& arg)
        {
            // Let input = x, output = y
            // shape of x = [10], y = [6]
            // window_size = 5, pad = 0, stride = 1, dilation = 1
            // Forward:
            // y0 = 1/5 * (x0 + x1 + x2 + x3 + x4)
            // y1 = 1/5 * (x1 + x2 + x3 + x4 + x5)
            // ...
            // y5 = 1/5 * (x5 + x6 + x7 + x8 + x9)
            // Backward:
            // shape of dy = [6], dx = [10]
            // dx0 = 1/5 * dy0
            // dx1 = 1/5 * (dy0 + dy1)
            // dx2 = 1/5 * (dy0 + dy1 + dy2)
            // ...
            // dx4 = 1/5 * (dy0 + dy1 + dy2 + dy3 + dy4)
            // dx5 = 1/5 * (dy1 + dy2 + dy3 + dy4 + dy5)
            // ...
            // dx9 = 1/5 * dy5
            auto f_ncw = [&](auto n, auto c, auto wi) {
                std::size_t X  = arg.window_spatial_lengths_[0];
                std::size_t Wo = arg.doutput_.GetLengths()[2];

                float v_acc = 0;

                for(std::size_t x = 0; x < X; ++x)
                {
                    // Out_Position = (In_Position + pad - x * dilation) / stride
                    auto w_tmp = static_cast<ck::long_index_t>(wi) +
                                 static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
                                 static_cast<ck::long_index_t>(x * arg.window_dilations_[0]);

                    // Check the input pixel validity (in the sense of being affected by some
                    // doutput pixel)
                    if(w_tmp % arg.window_strides_[0] == 0)
                    {
                        auto wo = static_cast<ck::long_index_t>(w_tmp) /
                                  static_cast<ck::long_index_t>(arg.window_strides_[0]);

                        // Only doutput pixels in the valid range contribute gradients to this
                        // input pixel
                        if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
                        {
                            v_acc += ck::type_convert<float>(arg.doutput_(n, c, wo));
                        }
                    }
                }
                v_acc /= ck::type_convert<float>(X);
                arg.dinput_(n, c, wi) = ck::type_convert<DInDataType>(v_acc);
            };

            make_ParallelTensorFunctor(f_ncw,
                                       arg.dinput_.GetLengths()[0],
                                       arg.dinput_.GetLengths()[1],
                                       arg.dinput_.GetLengths()[2])(
                std::thread::hardware_concurrency());

            return 0;
        }
        template <ck::index_t NDimSpatial_,
                  typename std::enable_if<NDimSpatial_ == 2, bool>::type = false>
        float RunAvgPoolBwd(const Argument& arg)
        {
            auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
                std::size_t Y  = arg.window_spatial_lengths_[0];
                std::size_t X  = arg.window_spatial_lengths_[1];
                std::size_t Ho = arg.doutput_.GetLengths()[2];
                std::size_t Wo = arg.doutput_.GetLengths()[3];

                float v_acc = 0;

                for(std::size_t y = 0; y < Y; ++y)
                {
                    // Out_Position = (In_Position + pad - y * dilation) / stride
                    auto h_tmp = static_cast<ck::long_index_t>(hi) +
                                 static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
                                 static_cast<ck::long_index_t>(y * arg.window_dilations_[0]);

                    // Check the input pixel validity (in the sense of being affected by some
                    // doutput pixel)
                    if(h_tmp % arg.window_strides_[0] == 0)
                    {
                        auto ho = static_cast<ck::long_index_t>(h_tmp) /
                                  static_cast<ck::long_index_t>(arg.window_strides_[0]);

                        // Only doutput pixels in the valid range contribute gradients to this
                        // input pixel
                        if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
                        {
                            for(std::size_t x = 0; x < X; ++x)
                            {
                                auto w_tmp =
                                    static_cast<ck::long_index_t>(wi) +
                                    static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
                                    static_cast<ck::long_index_t>(x * arg.window_dilations_[1]);

                                if(w_tmp % arg.window_strides_[1] == 0)
                                {
                                    auto wo =
                                        static_cast<ck::long_index_t>(w_tmp) /
                                        static_cast<ck::long_index_t>(arg.window_strides_[1]);

                                    if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
                                    {
                                        v_acc +=
                                            ck::type_convert<float>(arg.doutput_(n, c, ho, wo));
                                    }
                                }
                            }
                        }
                    }
                }
                v_acc /= ck::type_convert<float>(Y * X);
                arg.dinput_(n, c, hi, wi) = ck::type_convert<DInDataType>(v_acc);
            };

            make_ParallelTensorFunctor(f_nchw,
                                       arg.dinput_.GetLengths()[0],
                                       arg.dinput_.GetLengths()[1],
                                       arg.dinput_.GetLengths()[2],
                                       arg.dinput_.GetLengths()[3])(
                std::thread::hardware_concurrency());

            return 0;
        }
        template <ck::index_t NDimSpatial_,
                  typename std::enable_if<NDimSpatial_ == 3, bool>::type = false>
        float RunAvgPoolBwd(const Argument& arg)
        {
            auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) {
                std::size_t Z  = arg.window_spatial_lengths_[0];
                std::size_t Y  = arg.window_spatial_lengths_[1];
                std::size_t X  = arg.window_spatial_lengths_[2];
                std::size_t Do = arg.doutput_.GetLengths()[2];
                std::size_t Ho = arg.doutput_.GetLengths()[3];
                std::size_t Wo = arg.doutput_.GetLengths()[4];

                float v_acc = 0;

                for(std::size_t z = 0; z < Z; ++z)
                {
                    // Out_Position = (In_Position + pad - z * dilation) / stride
                    auto d_tmp = static_cast<ck::long_index_t>(di) +
                                 static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
                                 static_cast<ck::long_index_t>(z * arg.window_dilations_[0]);

                    // Check the input pixel validity (in the sense of being affected by some
                    // doutput pixel)
                    if(d_tmp % arg.window_strides_[0] == 0)
                    {
                        auto do_ = static_cast<ck::long_index_t>(d_tmp) /
                                   static_cast<ck::long_index_t>(arg.window_strides_[0]);

                        // Only doutput pixels in the valid range contribute gradients to this
                        // input pixel
                        if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do)
                        {
                            for(std::size_t y = 0; y < Y; ++y)
                            {
                                auto h_tmp =
                                    static_cast<ck::long_index_t>(hi) +
                                    static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
                                    static_cast<ck::long_index_t>(y * arg.window_dilations_[1]);

                                if(h_tmp % arg.window_strides_[1] == 0)
                                {
                                    auto ho =
                                        static_cast<ck::long_index_t>(h_tmp) /
                                        static_cast<ck::long_index_t>(arg.window_strides_[1]);

                                    if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
                                    {
                                        for(std::size_t x = 0; x < X; ++x)
                                        {
                                            auto w_tmp =
                                                static_cast<ck::long_index_t>(wi) +
                                                static_cast<ck::long_index_t>(
                                                    arg.in_left_pads_[2]) -
                                                static_cast<ck::long_index_t>(
                                                    x * arg.window_dilations_[2]);

                                            if(w_tmp % arg.window_strides_[2] == 0)
                                            {
                                                auto wo = static_cast<ck::long_index_t>(w_tmp) /
                                                          static_cast<ck::long_index_t>(
                                                              arg.window_strides_[2]);

                                                if(wo >= 0 &&
                                                   ck::type_convert<std::size_t>(wo) < Wo)
                                                {
                                                    v_acc += ck::type_convert<float>(
                                                        arg.doutput_(n, c, do_, ho, wo));
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
                v_acc /= ck::type_convert<float>(Z * Y * X);
                arg.dinput_(n, c, di, hi, wi) = ck::type_convert<DInDataType>(v_acc);
            };

            make_ParallelTensorFunctor(f_ncdhw,
                                       arg.dinput_.GetLengths()[0],
                                       arg.dinput_.GetLengths()[1],
                                       arg.dinput_.GetLengths()[2],
                                       arg.dinput_.GetLengths()[3],
                                       arg.dinput_.GetLengths()[4])(
                std::thread::hardware_concurrency());

            return 0;
        }
        float Run(const Argument& arg)
        {
            if(!(arg.dinput_.GetNumOfDimension() == NDimSpatial + 2 &&
                 arg.doutput_.GetNumOfDimension() == NDimSpatial + 2))
            {
                throw std::runtime_error("wrong! inconsistent dimension");
            }

            return RunAvgPoolBwd<NDimSpatial>(arg);
        }

        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };
    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    static auto MakeArgument(Tensor<DInDataType>& dinput,
                             const Tensor<DOutDataType>& doutput,
                             std::vector<ck::index_t> window_spatial_lengths,
                             std::vector<ck::index_t> window_strides,
                             std::vector<ck::index_t> window_dilations,
                             std::vector<ck::index_t> dinput_left_pads,
                             std::vector<ck::index_t> dinput_right_pads)
    {
        if(window_spatial_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial ||
           window_dilations.size() != NDimSpatial || dinput_left_pads.size() != NDimSpatial ||
           dinput_right_pads.size() != NDimSpatial)
            throw std::runtime_error("dimension is incorrect");

        return Argument{dinput,
                        doutput,
                        window_spatial_lengths,
                        window_strides,
                        window_dilations,
                        dinput_left_pads,
                        dinput_right_pads};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceAvgPoolBwd" << std::endl;
        // clang-format on

        return str.str();
    }
};

} // namespace host
} // namespace tensor_operation
} // namespace ck
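The validity check inverts the forward index map in_pos = out_pos * stride + k * dilation - pad: an input pixel gathers gradient from out_pos = (in_pos + pad - k * dilation) / stride only when the division is exact and out_pos is in range. A standalone 1-D sketch of the same gather (toy sizes matching the comment above, not the CK API):

#include <cstdio>
#include <vector>

// 1-D average-pool backward as a gather: for each input position wi, find
// every window offset x whose output position wo is integral and in range,
// then average the collected dy values over the window size.
std::vector<float> avgpool1d_bwd(const std::vector<float>& dy,
                                 int Wi, int window, int stride, int dilation, int pad)
{
    std::vector<float> dx(Wi, 0.f);
    for(int wi = 0; wi < Wi; ++wi)
    {
        float acc = 0.f;
        for(int x = 0; x < window; ++x)
        {
            const int w_tmp = wi + pad - x * dilation;
            if(w_tmp % stride == 0)
            {
                const int wo = w_tmp / stride;
                if(wo >= 0 && wo < int(dy.size()))
                    acc += dy[wo];
            }
        }
        dx[wi] = acc / float(window);
    }
    return dx;
}

int main()
{
    // x has 10 elements, window = 5, stride = 1 -> dy has 6 elements.
    std::vector<float> dy(6, 1.f);
    auto dx = avgpool1d_bwd(dy, 10, /*window=*/5, /*stride=*/1, /*dilation=*/1, /*pad=*/0);
    for(float v : dx)
        std::printf("%.0f ", v * 5); // contributions per dx: 1 2 3 4 5 5 4 3 2 1
    std::printf("\n");
    return 0;
}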
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp
View file @
8c4897d1
...
...
@@ -125,7 +125,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
                arg.in_element_op_(v_in, v_acc);

-               arg.input_(g, n, c, wi) = ck::type_convert<InDataType>(v_acc);
+               arg.input_(g, n, c, wi) = ck::type_convert<InDataType>(v_in);
            };

            make_ParallelTensorFunctor(f_ncw,
...
@@ -201,7 +201,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
                arg.in_element_op_(v_in, v_acc);

-               arg.input_(g, n, c, hi, wi) = ck::type_convert<InDataType>(v_acc);
+               arg.input_(g, n, c, hi, wi) = ck::type_convert<InDataType>(v_in);
            };

            make_ParallelTensorFunctor(f_nchw,
...
@@ -299,7 +299,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
                arg.in_element_op_(v_in, v_acc);

-               arg.input_(g, n, c, di, hi, wi) = ck::type_convert<InDataType>(v_acc);
+               arg.input_(g, n, c, di, hi, wi) = ck::type_convert<InDataType>(v_in);
            };

            make_ParallelTensorFunctor(f_ncdhw,
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
View file @
8c4897d1
...
...
@@ -92,11 +92,11 @@ struct ReferenceGemm : public device::BaseOperator
                    ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
            }

-           AccDataType v_c;
+           CDataType v_c;

            arg.c_element_op_(v_c, v_acc);

-           arg.c_m_n_(m, n) = ck::type_convert<CDataType>(v_c);
+           arg.c_m_n_(m, n) = v_c;
        };

        make_ParallelTensorFunctor(
...
library/include/ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp
View file @
8c4897d1
...
...
@@ -39,6 +39,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
                 Tensor<IndexDataType>& out_indices,
                 const std::vector<ck::index_t>& window_spatial_lengths,
                 const std::vector<ck::index_t>& window_strides,
+                const std::vector<ck::index_t>& window_dilations,
                 const std::vector<ck::index_t>& in_left_pads,
                 const std::vector<ck::index_t>& /*in_right_pads*/)
            : in_(in),
...
@@ -46,6 +47,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
              out_indices_(out_indices),
              window_spatial_lengths_(window_spatial_lengths),
              window_strides_(window_strides),
+             window_dilations_(window_dilations),
              in_left_pads_(in_left_pads),
              reduceLength_(1)
        {
...
@@ -58,6 +60,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
        Tensor<IndexDataType>& out_indices_;
        const std::vector<ck::index_t>& window_spatial_lengths_;
        const std::vector<ck::index_t>& window_strides_;
+       const std::vector<ck::index_t>& window_dilations_;
        const std::vector<ck::index_t>& in_left_pads_;
        int reduceLength_;
    };
...
@@ -85,14 +88,17 @@ struct ReferencePoolingFwd : public device::BaseOperator
                for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
                {
-                   ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0];
+                   ck::index_t di = do_ * arg.window_strides_[0] +
+                                    z * arg.window_dilations_[0] - arg.in_left_pads_[0];
                    for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y)
                    {
-                       ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1];
+                       ck::index_t hi = ho * arg.window_strides_[1] +
+                                        y * arg.window_dilations_[1] - arg.in_left_pads_[1];
                        for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x)
                        {
-                           ck::index_t wi = wo * arg.window_strides_[2] + x - arg.in_left_pads_[2];
+                           ck::index_t wi = wo * arg.window_strides_[2] +
+                                            x * arg.window_dilations_[2] - arg.in_left_pads_[2];
                            if(di >= 0 &&
                               di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
                               hi >= 0 &&
...
@@ -136,14 +142,17 @@ struct ReferencePoolingFwd : public device::BaseOperator
                for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
                {
-                   ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0];
+                   ck::index_t di = do_ * arg.window_strides_[0] +
+                                    z * arg.window_dilations_[0] - arg.in_left_pads_[0];
                    for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y)
                    {
-                       ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1];
+                       ck::index_t hi = ho * arg.window_strides_[1] +
+                                        y * arg.window_dilations_[1] - arg.in_left_pads_[1];
                        for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x)
                        {
-                           ck::index_t wi = wo * arg.window_strides_[2] + x - arg.in_left_pads_[2];
+                           ck::index_t wi = wo * arg.window_strides_[2] +
+                                            x * arg.window_dilations_[2] - arg.in_left_pads_[2];
                            if(di >= 0 &&
                               di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
                               hi >= 0 &&
...
@@ -202,10 +211,12 @@ struct ReferencePoolingFwd : public device::BaseOperator
                for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y)
                {
-                   ck::index_t hi = ho * arg.window_strides_[0] + y - arg.in_left_pads_[0];
+                   ck::index_t hi = ho * arg.window_strides_[0] +
+                                    y * arg.window_dilations_[0] - arg.in_left_pads_[0];
                    for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x)
                    {
-                       ck::index_t wi = wo * arg.window_strides_[1] + x - arg.in_left_pads_[1];
+                       ck::index_t wi = wo * arg.window_strides_[1] +
+                                        x * arg.window_dilations_[1] - arg.in_left_pads_[1];
                        if(hi >= 0 &&
                           hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
                           wi >= 0 &&
...
@@ -308,6 +319,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
                            Tensor<IndexDataType>& out_indices,
                            const std::vector<ck::index_t>& window_spatial_lengths,
                            const std::vector<ck::index_t>& window_strides,
+                           const std::vector<ck::index_t>& window_dilations,
                            const std::vector<ck::index_t>& in_left_pads,
                            const std::vector<ck::index_t>& in_right_pads)
    {
...
@@ -316,6 +328,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
                        out_indices,
                        window_spatial_lengths,
                        window_strides,
+                       window_dilations,
                        in_left_pads,
                        in_right_pads};
    }
...
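Every hunk in this file applies the standard dilated-window index map in_pos = out_pos * stride + k * dilation - pad, so window element k now skips dilation - 1 input pixels. A tiny standalone illustration of the coordinates one output pixel reads (toy values, not the CK API):

#include <cstdio>

int main()
{
    // Input positions visited by output position `out` of a pooling window.
    const int stride = 2, dilation = 3, pad = 1, window = 3, out = 4;
    for(int k = 0; k < window; ++k)
    {
        const int in = out * stride + k * dilation - pad; // 7, 10, 13
        std::printf("window element %d reads input %d\n", k, in);
    }
    return 0;
}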