gaoqiong / composable_kernel · Commits · 81c942cd

Unverified commit 81c942cd, authored Jul 08, 2021 by Chao Liu, committed by GitHub on Jul 08, 2021

Deprecate static kernel (#42)

* deprecate static kernels

parent b8b2d0a6
Changes: 55
Showing 20 changed files with 51 additions and 3429 deletions (+51 −3429)
composable_kernel/include/tensor_operation/threadwise_direct_convolution.hpp  +0  -228
composable_kernel/include/tensor_operation/threadwise_gemm.hpp  +0  -165
composable_kernel/include/tensor_operation/threadwise_generic_tensor_op.hpp  +0  -20
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp  +0  -191
composable_kernel/include/tensor_operation/xdlops_gemm.hpp  +0  -1
composable_kernel/include/utility/amd_buffer_addressing.hpp  +0  -1042
composable_kernel/include/utility/amd_dlop.hpp  +42  -0
composable_kernel/include/utility/common_header.hpp  +1  -1
composable_kernel/include/utility/config.nvidia.hpp.in  +0  -54
composable_kernel/include/utility/float_type.nvidia.hpp.in  +0  -180
composable_kernel/include/utility/in_memory_operation.amd.hpp.in  +0  -241
composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in  +0  -109
composable_kernel/include/utility/synchronization.nvidia.hpp.in  +0  -13
composable_kernel/src/kernel_wrapper/gridwise_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.cpp  +0  -8
composable_kernel/src/kernel_wrapper/gridwise_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.cpp  +0  -7
composable_kernel/src/kernel_wrapper/gridwise_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.cpp  +0  -8
driver/conv_bwd_data_driver.cpp  +0  -299
driver/conv_driver.cpp  +0  -780
driver/conv_driver_v2.cpp  +8  -8
driver/include/conv_common.hpp  +0  -74
composable_kernel/include/tensor_operation/threadwise_direct_convolution.hpp  (deleted, 100644 → 0)

#ifndef CK_THREADWISE_DIRECT_CONVOLUTION_HPP
#define CK_THREADWISE_DIRECT_CONVOLUTION_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "threadwise_tensor_slice_copy.hpp"

namespace ck {

// optimized for the scenario where p_in, p_wei, p_out are in register
template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>
__device__ void threadwise_direct_convolution_1(InDesc,
                                                TInWei* const __restrict__ p_in,
                                                WeiDesc,
                                                TInWei* const __restrict__ p_wei,
                                                OutDesc,
                                                TOut* __restrict__ p_out)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};
    constexpr auto out_desc = OutDesc{};

#if 0
    if(blockIdx.x == 0 && get_thread_local_1d_id() == 0)
    {
        print_ConstantTensorDescriptor(in_desc, "threadwise_direct_convolution: in_desc: ");
        print_ConstantTensorDescriptor(wei_desc, "threadwise_direct_convolution: wei_desc: ");
        print_ConstantTensorDescriptor(out_desc, "threadwise_direct_convolution: out_desc: ");
    }
#endif

    for(index_t n = 0; n < out_desc.GetLength(I0); ++n)
    {
        for(index_t k = 0; k < out_desc.GetLength(I1); ++k)
        {
            for(index_t ho = 0; ho < out_desc.GetLength(I2); ++ho)
            {
                for(index_t wo = 0; wo < out_desc.GetLength(I3); ++wo)
                {
                    for(index_t c = 0; c < wei_desc.GetLength(I1); ++c)
                    {
                        for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
                        {
                            for(index_t x = 0; x < wei_desc.GetLength(I3); ++x)
                            {
                                const index_t hi = ho + y;
                                const index_t wi = wo + x;

                                const index_t in_index =
                                    in_desc.GetOffsetFromMultiIndex(n, c, hi, wi);
                                const index_t wei_index =
                                    wei_desc.GetOffsetFromMultiIndex(k, c, y, x);
                                const index_t out_index =
                                    out_desc.GetOffsetFromMultiIndex(n, k, ho, wo);

                                fused_multiply_accumulate(
                                    p_out[out_index], p_wei[wei_index], p_in[in_index]);
                            }
                        }
                    }
                }
            }
        }
    }
}

// Optimized for the scenario where p_in and p_wei are in LDS and p_out is in register.
// Copy in and wei into register before doing convolution.
template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>
__device__ void threadwise_direct_convolution_2(InDesc,
                                                TInWei* const __restrict__ p_in,
                                                WeiDesc,
                                                TInWei* const __restrict__ p_wei,
                                                OutDesc,
                                                TOut* __restrict__ p_out)
{
    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};
    constexpr auto out_desc = OutDesc{};

    constexpr auto in_reg_desc  = make_ConstantTensorDescriptor_packed(in_desc.GetLengths());
    constexpr auto wei_reg_desc = make_ConstantTensorDescriptor_packed(wei_desc.GetLengths());

    // register
    TInWei p_in_reg[in_reg_desc.GetElementSpace()];
    TInWei p_wei_reg[wei_reg_desc.GetElementSpace()];

    // copy input tensor into register
    threadwise_tensor_slice_copy(
        in_desc, p_in, in_reg_desc, p_in_reg, in_reg_desc.GetLengths(), Number<1>{});

    // copy weight tensor into register
    threadwise_tensor_slice_copy(
        wei_desc, p_wei, wei_reg_desc, p_wei_reg, wei_reg_desc.GetLengths(), Number<1>{});

    // do convolution
    threadwise_direct_convolution_1(
        in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
}

// optimized for the scenario where p_in and p_wei are in LDS and p_out is in register
// break down a non-1x1 convolution into a sequence of 1x1 convolutions,
// load 1x1 weight into register, and do 1x1 convolution in register.
template <class Data, class InDesc, class WeiDesc, class OutDesc>
__device__ void threadwise_direct_convolution_3(InDesc,
                                                Data* const __restrict__ p_in,
                                                WeiDesc,
                                                Data* const __restrict__ p_wei,
                                                OutDesc,
                                                Data* __restrict__ p_out)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};
    constexpr auto out_desc = OutDesc{};

    constexpr auto in_reg_desc =
        make_ConstantTensorDescriptor(Sequence<in_desc.GetLength(I0),
                                               in_desc.GetLength(I1),
                                               out_desc.GetLength(I2),
                                               out_desc.GetLength(I3)>{});

    constexpr auto wei_reg_desc = make_ConstantTensorDescriptor(
        Sequence<wei_desc.GetLength(I0), wei_desc.GetLength(I1), 1, 1>{});

    Data p_in_reg[in_reg_desc.GetElementSpace()];
    Data p_wei_reg[wei_reg_desc.GetElementSpace()];

    constexpr index_t in_w_new_read = 1;

    constexpr auto in_desc_reg_new_read =
        make_ConstantTensorDescriptor(Sequence<in_reg_desc.GetLength(I0),
                                               in_reg_desc.GetLength(I1),
                                               in_reg_desc.GetLength(I2),
                                               in_w_new_read>{});

#if 0
    // this version reuses old input data in register, and reads new data from LDS
    // loop over vertical direction
    for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
    {
        // read first input
        threadwise_4d_tensor_copy(in_desc,
                                  p_in + in_desc.GetOffsetFromMultiIndex(0, 0, y, 0),
                                  in_reg_desc,
                                  p_in_reg,
                                  in_reg_desc.GetLengths());

        // read first 1x1 weight
        threadwise_4d_tensor_copy(wei_desc,
                                  p_wei + wei_desc.GetOffsetFromMultiIndex(0, 0, y, 0),
                                  wei_reg_desc,
                                  p_wei_reg,
                                  wei_reg_desc.GetLengths());

        // do first 1x1 conv
        threadwise_direct_convolution_1(
            in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);

        // loop over horizontal direction
        for(index_t x = 1; x < wei_desc.GetLength(I3); ++x)
        {
            // read new weight
            threadwise_4d_tensor_copy(wei_desc,
                                      p_wei + wei_desc.GetOffsetFromMultiIndex(0, 0, y, x),
                                      wei_reg_desc,
                                      p_wei_reg,
                                      wei_reg_desc.GetLengths());

            // shift old input to the left
            threadwise_4d_tensor_shift_down(in_reg_desc, p_in_reg, I3, Number<in_w_new_read>{});

            // read new input
            threadwise_4d_tensor_copy(
                in_desc,
                p_in + in_desc.GetOffsetFromMultiIndex(0, 0, y, x + in_reg_desc.GetLength(I3) - 1),
                in_reg_desc,
                p_in_reg +
                    in_reg_desc.GetOffsetFromMultiIndex(
                        0, 0, 0, in_reg_desc.GetLength(I3) - in_w_new_read),
                in_desc_reg_new_read.GetLengths());

            // do 1x1 conv
            threadwise_direct_convolution_1(
                in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
        }
    }
#elif 1
    // this version reads all input from LDS when the filter moves
    // loop over vertical direction
    for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
    {
        // loop over horizontal direction
        for(index_t x = 0; x < wei_desc.GetLength(I3); ++x)
        {
            // read new weight
            threadwise_4d_tensor_copy(wei_desc,
                                      p_wei + wei_desc.GetOffsetFromMultiIndex(0, 0, y, x),
                                      wei_reg_desc,
                                      p_wei_reg,
                                      wei_reg_desc.GetLengths());

            // read new input
            threadwise_4d_tensor_copy(in_desc,
                                      p_in + in_desc.GetOffsetFromMultiIndex(0, 0, y, x),
                                      in_reg_desc,
                                      p_in_reg,
                                      in_reg_desc.GetLengths());

            // do 1x1 conv
            threadwise_direct_convolution_1(
                in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
        }
    }
#endif
}

} // namespace ck
#endif
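For reference, the loop nest in threadwise_direct_convolution_1 above computes an ordinary NCHW direct convolution with unit stride and no padding. A minimal standalone sketch of the same arithmetic on plain float arrays (assumed packed NCHW/KCYX layouts, not the CK descriptor types) would look like this:

// Standalone illustration (assumed packed layouts):
// out[n][k][ho][wo] += sum over c, y, x of wei[k][c][y][x] * in[n][c][ho+y][wo+x]
#include <cstddef>

void direct_conv_nchw(const float* in, const float* wei, float* out,
                      int N, int C, int Hi, int Wi, int K, int Y, int X)
{
    const int Ho = Hi - Y + 1; // output height for unit stride, no padding
    const int Wo = Wi - X + 1; // output width

    for(int n = 0; n < N; ++n)
        for(int k = 0; k < K; ++k)
            for(int ho = 0; ho < Ho; ++ho)
                for(int wo = 0; wo < Wo; ++wo)
                    for(int c = 0; c < C; ++c)
                        for(int y = 0; y < Y; ++y)
                            for(int x = 0; x < X; ++x)
                            {
                                const std::size_t in_idx =
                                    ((std::size_t(n) * C + c) * Hi + (ho + y)) * Wi + (wo + x);
                                const std::size_t wei_idx =
                                    ((std::size_t(k) * C + c) * Y + y) * X + x;
                                const std::size_t out_idx =
                                    ((std::size_t(n) * K + k) * Ho + ho) * Wo + wo;

                                out[out_idx] += wei[wei_idx] * in[in_idx];
                            }
}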
composable_kernel/include/tensor_operation/threadwise_gemm.hpp  (deleted, 100644 → 0)

#ifndef CK_THREADWISE_GEMM_HPP
#define CK_THREADWISE_GEMM_HPP

#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "math.hpp"

namespace ck {

template <typename Float, class Matrix>
__device__ void threadwise_matrix_set_zero(Matrix, Float* __restrict__ p_thread)
{
    for(index_t i = 0; i < Matrix::NRow(); ++i)
    {
        for(index_t j = 0; j < Matrix::NCol(); ++j)
        {
            const index_t id = Matrix::CalculateOffset(i, j);

            p_thread[id] = Float(0);
        }
    }
}

template <typename SrcMatrix,
          typename DstMatrix,
          index_t NSliceRow,
          index_t NSliceCol,
          index_t DataPerAccess>
struct ThreadwiseMatrixSliceCopy
{
    __device__ constexpr ThreadwiseMatrixSliceCopy()
    {
        static_assert(SrcMatrix::RowStride() % DataPerAccess == 0 &&
                          DstMatrix::RowStride() % DataPerAccess == 0,
                      "wrong! wrong alignment");
        static_assert(NSliceCol % DataPerAccess == 0,
                      "wrong! should be NSliceCol % DataPerAccess == 0");
    }

    template <typename Data>
    __device__ static void Run(const Data* p_src, Data* p_dst)
    {
        using vector_t = typename vector_type<Data, DataPerAccess>::type;

        for(index_t i = 0; i < NSliceRow; ++i)
        {
            for(index_t j = 0; j < NSliceCol; j += DataPerAccess)
            {
                const index_t src_index = SrcMatrix::CalculateOffset(i, j);
                const index_t dst_index = DstMatrix::CalculateOffset(i, j);

                *reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
                    *reinterpret_cast<const vector_t*>(&p_src[src_index]);
            }
        }
    }
};

// C += transpose(A) * B
// Element of matrix can be vectorized data
template <typename MatrixA, typename MatrixB, typename MatrixC>
struct ThreadwiseGemmTransANormalBNormalC
{
    __device__ constexpr ThreadwiseGemmTransANormalBNormalC()
    {
        static_assert(MatrixA::NRow() == MatrixB::NRow() &&
                          MatrixA::NCol() == MatrixC::NRow() &&
                          MatrixB::NCol() == MatrixC::NCol(),
                      "wrong!");
    }

    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
    {
        constexpr index_t M = MatrixC::NRow();
        constexpr index_t N = MatrixC::NCol();
        constexpr index_t K = MatrixA::NRow(); // A is transposed

        for(index_t k = 0; k < K; ++k)
        {
            for(index_t m = 0; m < M; ++m)
            {
                for(index_t n = 0; n < N; ++n)
                {
                    const index_t aindex = MatrixA::CalculateOffset(k, m); // A is transposed
                    const index_t bindex = MatrixB::CalculateOffset(k, n);
                    const index_t cindex = MatrixC::CalculateOffset(m, n);

                    p_c[cindex] +=
                        inner_product_with_conversion<FloatC>{}(p_a[aindex], p_b[bindex]);
                }
            }
        }
    }

#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
    {
        constexpr index_t M = MatrixC::NRow();
        constexpr index_t N = MatrixC::NCol();
        constexpr index_t K = MatrixA::NRow(); // A is transposed

        static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet");

        for(index_t k = 0; k < K; ++k)
        {
            for(index_t m = 0; m < M; ++m)
            {
                const index_t aindex = MatrixA::CalculateOffset(k, m); // A is transposed

                static_if<N == 2>{}([&](auto) {
                    const index_t bindex_0 = MatrixB::CalculateOffset(k, 0);
                    const index_t bindex_1 = MatrixB::CalculateOffset(k, 1);

                    const index_t cindex_0 = MatrixC::CalculateOffset(m, 0);
                    const index_t cindex_1 = MatrixC::CalculateOffset(m, 1);

                    amd_assembly_outer_product_1x2(
                        p_a[aindex], p_b[bindex_0], p_b[bindex_1], p_c[cindex_0], p_c[cindex_1]);
                });

                static_if<N == 4>{}([&](auto) {
                    const index_t bindex_0 = MatrixB::CalculateOffset(k, 0);
                    const index_t bindex_1 = MatrixB::CalculateOffset(k, 1);
                    const index_t bindex_2 = MatrixB::CalculateOffset(k, 2);
                    const index_t bindex_3 = MatrixB::CalculateOffset(k, 3);

                    const index_t cindex_0 = MatrixC::CalculateOffset(m, 0);
                    const index_t cindex_1 = MatrixC::CalculateOffset(m, 1);
                    const index_t cindex_2 = MatrixC::CalculateOffset(m, 2);
                    const index_t cindex_3 = MatrixC::CalculateOffset(m, 3);

                    amd_assembly_outer_product_1x4(p_a[aindex],
                                                   p_b[bindex_0],
                                                   p_b[bindex_1],
                                                   p_b[bindex_2],
                                                   p_b[bindex_3],
                                                   p_c[cindex_0],
                                                   p_c[cindex_1],
                                                   p_c[cindex_2],
                                                   p_c[cindex_3]);
                });
            }
        }
    }
#endif

    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
    {
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
        constexpr bool has_amd_asm =
            is_same<FloatC, float>{} &&
            ((is_same<FloatA, float>{} && is_same<FloatB, float>{}) ||
             (is_same<FloatA, half2_t>{} && is_same<FloatB, half2_t>{}) ||
             (is_same<FloatA, half4_t>{} && is_same<FloatB, half4_t>{}));

        static_if<has_amd_asm>{}([&](auto fwd) { Run_amd_asm(p_a, p_b, fwd(p_c)); })
            .Else([&](auto) { Run_source(p_a, p_b, p_c); });
#else
        Run_source(p_a, p_b, p_c);
#endif
    }
};

} // namespace ck
#endif
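For reference, Run_source above accumulates C += transpose(A) * B, with A stored K-major (K x M), B as K x N, and C as M x N. A minimal host-side sketch of the same accumulation on plain row-major arrays (hypothetical layout, not the CK matrix descriptors) is:

// Standalone illustration: C (MxN) += A^T * B, where A is stored as KxM and B as KxN, row-major.
void gemm_trans_a(const float* a, const float* b, float* c, int M, int N, int K)
{
    for(int k = 0; k < K; ++k)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                c[m * N + n] += a[k * M + m] * b[k * N + n];
}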
composable_kernel/include/tensor_operation/threadwise_generic_tensor_op.hpp  (deleted, 100644 → 0)

#ifndef CK_THREADWISE_GENERIC_TENSOR_OP_HPP
#define CK_THREADWISE_GENERIC_TENSOR_OP_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor_deprecated.hpp"

namespace ck {

template <class Float, class TDesc>
__device__ void threadwise_generic_tensor_set_zero(TDesc, Float* __restrict__ p)
{
    static_ford<decltype(TDesc::GetLengths())>{}([&](auto multi_id) {
        constexpr index_t offset = TDesc::GetOffsetFromMultiIndex(multi_id);

        p[offset] = static_cast<Float>(0);
    });
}

} // namespace ck
#endif
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp  (deleted, 100644 → 0)

#ifndef CK_THREADWISE_GENERIC_TENSOR_SLICE_COPY_HPP
#define CK_THREADWISE_GENERIC_TENSOR_SLICE_COPY_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_coordinate.hpp"

namespace ck {

// This threadwise copy allows vector access of src and dst.
// It allows the vector size to be different on src and dst.
// The dimension of vector access should be the same on src and dst.
// The dimension access order should be the same on src and dst.
// Will do valid mapping check on src data: read 0 if src data has an invalid mapping.
// Will do valid mapping check on dst data: no write if dst data has an invalid mapping.
template <typename SrcDesc,
          typename DstDesc,
          typename SliceLengths,
          typename SrcDstDimAccessOrder,
          index_t SrcDstVectorReadWriteDim,
          index_t SrcDataPerRead,
          index_t DstDataPerWrite,
          AddressSpace SrcAddressSpace     = AddressSpace::Generic,
          AddressSpace DstAddressSpace     = AddressSpace::Generic,
          InMemoryDataOperation DstInMemOp = InMemoryDataOperation::Set,
          index_t SrcDataStride            = 1,
          index_t DstDataStride            = 1>
struct ThreadwiseGenericTensorSliceCopy_v4r2
{
    static constexpr index_t nDim = SliceLengths::Size();
    using Index                   = MultiIndex<nDim>;

    using SrcCoord = typename TensorCoordinate<SrcDesc>::type;
    using DstCoord = typename TensorCoordinate<DstDesc>::type;

    __device__ constexpr ThreadwiseGenericTensorSliceCopy_v4r2(const Index& src_slice_origin,
                                                               const Index& dst_slice_origin)
        : mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
    {
        static_assert(nDim == SrcDesc::GetNumOfDimension() &&
                          nDim == DstDesc::GetNumOfDimension() &&
                          nDim == SliceLengths::Size() && nDim == SrcDstDimAccessOrder::Size(),
                      "wrong! # of dimensions not the same");

        static_assert(is_valid_sequence_map<SrcDstDimAccessOrder>{}, "wrong! map is not valid");

        static_assert(SliceLengths{}[SrcDstVectorReadWriteDim] %
                              math::lcm(SrcDataPerRead, DstDataPerWrite) ==
                          0,
                      "wrong! cannot evenly divide");

        // TODO:: sanity-check if vectorized memory read/write is allowed on src and dst
    }

    __device__ constexpr ThreadwiseGenericTensorSliceCopy_v4r2()
        : ThreadwiseGenericTensorSliceCopy_v4r2(make_zero_multi_index<nDim>(),
                                                make_zero_multi_index<nDim>())
    {
    }

    __device__ void SetSrcSliceOrigin(SrcCoord src_slice_origin)
    {
        mSrcSliceOrigin = src_slice_origin;
    }

    __device__ void SetDstSliceOrigin(DstCoord dst_slice_origin)
    {
        mDstSliceOrigin = dst_slice_origin;
    }

    template <typename SrcData, typename DstData>
    __device__ void Run(const SrcData* p_src, DstData* p_dst) const
    {
        constexpr auto vector_access_dim = Number<SrcDstVectorReadWriteDim>{};

        constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
        constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};

        constexpr auto long_vector_size = Number<math::lcm(SrcDataPerRead, DstDataPerWrite)>{};

        constexpr auto long_vector_access_lengths = SliceLengths::Modify(
            vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);

        ford<decltype(long_vector_access_lengths), SrcDstDimAccessOrder>{}(
            [&](auto long_vector_access_id) {
                // data id w.r.t slicing-window
                auto long_vector_data_begin_id = long_vector_access_id;
                long_vector_data_begin_id(vector_access_dim) =
                    long_vector_size * long_vector_access_id[vector_access_dim];

                // buffer to hold a src long-vector
                SrcData p_src_long_vector[long_vector_size];

                // load data from src to the long-vector buffer
                static_for<0, long_vector_size / src_data_per_access, 1>{}([&](auto i) {
                    auto scalar_id               = make_zero_multi_index<nDim>();
                    scalar_id(vector_access_dim) = i * src_data_per_access;

                    const index_t buffer_offset = i * src_data_per_access;

                    const auto src_coord =
                        mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);

                    // Check src data's valid mapping situation, only check the first data in
                    // this src vector. It's the user's responsibility to make sure all data in
                    // the src vector has the same valid/invalid mapping situation.
                    transfer_data<SrcData,
                                  SrcDataPerRead,
                                  SrcAddressSpace,
                                  AddressSpace::Vgpr,
                                  InMemoryDataOperation::Set,
                                  SrcDataStride,
                                  1>(p_src,
                                     src_coord.GetOffset(),
                                     src_coord.IsOffsetValidAssumingUpperIndexIsValid(),
                                     SrcDesc::GetElementSpace(),
                                     p_src_long_vector,
                                     buffer_offset,
                                     true,
                                     long_vector_size);
                });

                // SrcData to DstData conversion
                DstData p_dst_long_vector[long_vector_size];

                static_for<0, long_vector_size, 1>{}([&](auto i) {
                    p_dst_long_vector[i] = type_convert<DstData>{}(p_src_long_vector[i]);
                });

                // store data from the long-vector buffer to dst
                static_for<0, long_vector_size / dst_data_per_access, 1>{}([&](auto i) {
                    auto scalar_id               = make_zero_multi_index<nDim>();
                    scalar_id(vector_access_dim) = i * dst_data_per_access;

                    const index_t buffer_offset = i * dst_data_per_access;

                    const auto dst_coord =
                        mDstSliceOrigin + (long_vector_data_begin_id + scalar_id);

                    // Check dst data's valid mapping situation, only check the first data in
                    // this dst vector. It's the user's responsibility to make sure all data in
                    // the dst vector has the same valid/invalid mapping situation.
                    transfer_data<DstData,
                                  DstDataPerWrite,
                                  AddressSpace::Vgpr,
                                  DstAddressSpace,
                                  DstInMemOp,
                                  1,
                                  DstDataStride>(p_dst_long_vector,
                                                 buffer_offset,
                                                 true,
                                                 long_vector_size,
                                                 p_dst,
                                                 dst_coord.GetOffset(),
                                                 dst_coord.IsOffsetValidAssumingUpperIndexIsValid(),
                                                 DstDesc::GetElementSpace());
                });
            });
    }

    template <typename T, bool PositiveDirection>
    __device__ void MoveSrcSliceWindow(const T& step_sizes_,
                                       integral_constant<bool, PositiveDirection>)
    {
        const auto step_sizes = to_multi_index(step_sizes_);

        static_if<PositiveDirection>{}([&](auto) {
            mSrcSliceOrigin += to_multi_index(step_sizes);
        }).Else([&](auto) { mSrcSliceOrigin -= step_sizes; });
    }

    template <typename T, bool PositiveDirection>
    __device__ void MoveDstSliceWindow(const T& step_sizes_,
                                       integral_constant<bool, PositiveDirection>)
    {
        const auto step_sizes = to_multi_index(step_sizes_);

        static_if<PositiveDirection>{}([&](auto) {
            mDstSliceOrigin += step_sizes;
        }).Else([&](auto) { mDstSliceOrigin -= step_sizes; });
    }

    private:
    SrcCoord mSrcSliceOrigin;
    DstCoord mDstSliceOrigin;
};

} // namespace ck
#endif
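The key idea in Run above is the "long vector" whose length is lcm(SrcDataPerRead, DstDataPerWrite): one long vector is filled by several source-width reads and drained by several destination-width writes, so each side keeps its preferred access width. A small host-side sketch of that staging pattern on plain arrays (hypothetical widths, no address-space or valid-mapping handling) is:

// Standalone illustration of the lcm-sized staging buffer used by the slice copy.
#include <cstring>
#include <numeric> // std::lcm (C++17)

template <int SrcWidth, int DstWidth>
void staged_copy(const float* src, float* dst, int count)
{
    constexpr int long_vec = std::lcm(SrcWidth, DstWidth); // "long vector" length
    float buffer[long_vec];                                 // register-like staging buffer

    // assumes count % long_vec == 0, mirroring the "cannot evenly divide" static_assert
    for(int base = 0; base < count; base += long_vec)
    {
        for(int i = 0; i < long_vec; i += SrcWidth) // src-sized reads
            std::memcpy(&buffer[i], &src[base + i], SrcWidth * sizeof(float));

        for(int i = 0; i < long_vec; i += DstWidth) // dst-sized writes
            std::memcpy(&dst[base + i], &buffer[i], DstWidth * sizeof(float));
    }
}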
composable_kernel/include/tensor_operation/xdlops_gemm.hpp

@@ -2,7 +2,6 @@
#define CK_XDLOPS_GEMM_HPP
#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "math.hpp"
#include "amd_xdlops.hpp"
composable_kernel/include/utility/amd_buffer_addressing.hpp  (deleted, 100644 → 0)

#ifndef CK_AMD_BUFFER_ADDRESSING_HPP
#define CK_AMD_BUFFER_ADDRESSING_HPP

#include "float_type.hpp"
#include "amd_buffer_addressing_v2.hpp"

namespace ck {

template <typename T>
union BufferResource
{
    // 128 bit SGPRs to supply buffer resource in buffer instructions
    // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
    int32x4_t data;
    T* address[2];
    int32_t range[4];
    int32_t config[4];
};

__device__ float __llvm_amdgcn_buffer_load_f32(
    int32x4_t srsrc, index_t vindex, index_t offset, bool glc, bool slc)
    __asm("llvm.amdgcn.buffer.load.f32");

__device__ float2_t __llvm_amdgcn_buffer_load_f32x2(
    int32x4_t srsrc, index_t vindex, index_t offset, bool glc, bool slc)
    __asm("llvm.amdgcn.buffer.load.v2f32");

__device__ float4_t __llvm_amdgcn_buffer_load_f32x4(
    int32x4_t srsrc, index_t vindex, index_t offset, bool glc, bool slc)
    __asm("llvm.amdgcn.buffer.load.v4f32");

__device__ half_t __llvm_amdgcn_raw_buffer_load_f16(
    int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.load.f16");

__device__ ushort __llvm_amdgcn_raw_buffer_load_bf16(
    int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.load.bf16");

__device__ void __llvm_amdgcn_buffer_store_f32(
    float vdata, int32x4_t srsrc, index_t vindex, index_t offset, bool glc, bool slc)
    __asm("llvm.amdgcn.buffer.store.f32");

__device__ void __llvm_amdgcn_buffer_store_f32x2(
    float2_t vdata, int32x4_t srsrc, index_t vindex, index_t offset, bool glc, bool slc)
    __asm("llvm.amdgcn.buffer.store.v2f32");

__device__ void __llvm_amdgcn_buffer_store_f32x4(
    float4_t vdata, int32x4_t srsrc, index_t vindex, index_t offset, bool glc, bool slc)
    __asm("llvm.amdgcn.buffer.store.v4f32");

__device__ void __llvm_amdgcn_raw_buffer_store_f16(
    half_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.store.f16");

__device__ void __llvm_amdgcn_raw_buffer_store_bf16(
    ushort vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.store.bf16");

#if CK_USE_AMD_BUFFER_ATOMIC_FADD
#if CK_HIP_VERSION_FLAT >= 3010020405
// starting ROCm-3.10, the return type becomes float
__device__ float
#else
__device__ void
#endif
__llvm_amdgcn_buffer_atomic_add_f32(
    float vdata, int32x4_t rsrc, index_t vindex, index_t offset, bool slc)
    __asm("llvm.amdgcn.buffer.atomic.fadd.f32");
#endif

// buffer_load requires:
//   1) p_src_wave must be in global memory space
//   2) p_src_wave to be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t VectorSize>
__device__ typename vector_type<T, VectorSize>::type
amd_buffer_load(const T* p_src_wave,
                index_t src_thread_data_offset,
                bool src_thread_data_valid,
                index_t src_elemenst_space);

// buffer_store requires:
//   1) p_src_thread must be in vgpr space, p_dst_thread must be global memory
//   2) p_dst_thread to be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t VectorSize>
__device__ void amd_buffer_store(const T* p_src_thread,
                                 T* p_dst_wave,
                                 index_t dst_thread_data_offset,
                                 bool dst_thread_data_valid,
                                 index_t dst_data_range);

// buffer_atomic requires:
//   1) p_src_thread must be in vgpr space, p_dst_thread must be global memory
//   2) p_dst_thread to be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t VectorSize>
__device__ void amd_buffer_atomic_add(const T* p_src_thread,
                                      T* p_dst_wave,
                                      index_t dst_thread_data_offset,
                                      bool dst_thread_data_valid,
                                      index_t dst_data_range);

template <>
__device__ float amd_buffer_load<float, 1>(const float* p_src_wave,
    index_t src_thread_data_offset, bool src_thread_data_valid, index_t src_data_range)
{
    BufferResource<float> src_wave_buffer_resource;

    // wavewise base address (64 bit)
    src_wave_buffer_resource.address[0] = const_cast<float*>(p_src_wave);
    // wavewise range (32 bit)
    src_wave_buffer_resource.range[2] = src_data_range * sizeof(float);
    // wavewise setting (32 bit)
    src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;

    index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);

#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
    uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;

    return __llvm_amdgcn_buffer_load_f32(
        src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);
#else
    float tmp = __llvm_amdgcn_buffer_load_f32(
        src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);

    return src_thread_data_valid ? tmp : float(0);
#endif
}

template <>
__device__ float2_t amd_buffer_load<float, 2>(const float* p_src_wave,
    index_t src_thread_data_offset, bool src_thread_data_valid, index_t src_data_range)
{
    BufferResource<float> src_wave_buffer_resource;

    // wavewise base address (64 bit)
    src_wave_buffer_resource.address[0] = const_cast<float*>(p_src_wave);
    // wavewise range (32 bit)
    src_wave_buffer_resource.range[2] = src_data_range * sizeof(float);
    // wavewise setting (32 bit)
    src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;

    index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);

#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
    uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;

    return __llvm_amdgcn_buffer_load_f32x2(
        src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);
#else
    float2_t tmp = __llvm_amdgcn_buffer_load_f32x2(
        src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);

    return src_thread_data_valid ? tmp : float2_t(0);
#endif
}

template <>
__device__ float4_t amd_buffer_load<float, 4>(const float* p_src_wave,
    index_t src_thread_data_offset, bool src_thread_data_valid, index_t src_data_range)
{
    BufferResource<float> src_wave_buffer_resource;

    // wavewise base address (64 bit)
    src_wave_buffer_resource.address[0] = const_cast<float*>(p_src_wave);
    // wavewise range (32 bit)
    src_wave_buffer_resource.range[2] = src_data_range * sizeof(float);
    // wavewise setting (32 bit)
    src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;

    index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);

#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
    uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;

    return __llvm_amdgcn_buffer_load_f32x4(
        src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);
#else
    float4_t tmp = __llvm_amdgcn_buffer_load_f32x4(
        src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);

    return src_thread_data_valid ? tmp : float4_t(0);
#endif
}

template <>
__device__ half_t amd_buffer_load<half_t, 1>(const half_t* p_src_wave,
    index_t src_thread_data_offset, bool src_thread_data_valid, index_t src_data_range)
{
    BufferResource<half_t> src_wave_buffer_resource;

    // wavewise base address (64 bit)
    src_wave_buffer_resource.address[0] = const_cast<half_t*>(p_src_wave);
    // wavewise range (32 bit)
    src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t);
    // wavewise setting (32 bit)
    src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;

    index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t);

#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
    uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;

    // current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and
    // everything is passed to Voffset
    return __llvm_amdgcn_raw_buffer_load_f16(
        src_wave_buffer_resource.data, src_addr_shift + src_thread_addr_offset, 0, 0);
#else
    half_t zero(0);

    // current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and
    // everything is passed to Voffset
    return src_thread_data_valid
               ? __llvm_amdgcn_raw_buffer_load_f16(
                     src_wave_buffer_resource.data, src_thread_addr_offset, 0, 0)
               : zero;
#endif
}

template <>
__device__ half2_t amd_buffer_load<half_t, 2>(const half_t* p_src_wave,
    index_t src_thread_data_offset, bool src_thread_data_valid, index_t src_data_range)
{
    BufferResource<half_t> src_wave_buffer_resource;

    // wavewise base address (64 bit)
    src_wave_buffer_resource.address[0] = const_cast<half_t*>(p_src_wave);
    // wavewise range (32 bit)
    src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t);
    // wavewise setting (32 bit)
    src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;

    index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t);

#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
    uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;

    float dst_out_tmp = __llvm_amdgcn_buffer_load_f32(
        src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);

    return *reinterpret_cast<half2_t*>(&dst_out_tmp);
#else
    half2_t zeros(0);

    float dst_out_tmp = __llvm_amdgcn_buffer_load_f32(
        src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);

    return src_thread_data_valid ? *reinterpret_cast<half2_t*>(&dst_out_tmp) : zeros;
#endif
}

template <>
__device__ half4_t amd_buffer_load<half_t, 4>(const half_t* p_src_wave,
    index_t src_thread_data_offset, bool src_thread_data_valid, index_t src_data_range)
{
    BufferResource<half_t> src_wave_buffer_resource;

    // wavewise base address (64 bit)
    src_wave_buffer_resource.address[0] = const_cast<half_t*>(p_src_wave);
    // wavewise range (32 bit)
    src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t);
    // wavewise setting (32 bit)
    src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;

    index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t);

#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
    uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;

    float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2(
        src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);

    return *reinterpret_cast<half4_t*>(&dst_out_tmp);
#else
    half4_t zeros(0);

    float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2(
        src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);

    return src_thread_data_valid ? *reinterpret_cast<half4_t*>(&dst_out_tmp) : zeros;
#endif
}

template <>
__device__ half8_t amd_buffer_load<half_t, 8>(const half_t* p_src_wave,
    index_t src_thread_data_offset, bool src_thread_data_valid, index_t src_data_range)
{
    BufferResource<half_t> src_wave_buffer_resource;

    // wavewise base address (64 bit)
    src_wave_buffer_resource.address[0] = const_cast<half_t*>(p_src_wave);
    // wavewise range (32 bit)
    src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t);
    // wavewise setting (32 bit)
    src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;

    index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t);

#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
    uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;

    float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4(
        src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);

    return *reinterpret_cast<half8_t*>(&dst_out_tmp);
#else
    half8_t zeros(0);

    float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4(
        src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);

    return src_thread_data_valid ? *reinterpret_cast<half8_t*>(&dst_out_tmp) : zeros;
#endif
}

template <>
__device__ ushort amd_buffer_load<ushort, 1>(const ushort* p_src_wave,
    index_t src_thread_data_offset, bool src_thread_data_valid, index_t src_data_range)
{
    BufferResource<ushort> src_wave_buffer_resource;

    // wavewise base address (64 bit)
    src_wave_buffer_resource.address[0] = const_cast<ushort*>(p_src_wave);
    // wavewise range (32 bit)
    src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort);
    // wavewise setting (32 bit)
    src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;

    index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort);

#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
    uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;

    // current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and
    // everything is passed to Voffset
    return __llvm_amdgcn_raw_buffer_load_bf16(
        src_wave_buffer_resource.data, src_addr_shift + src_thread_addr_offset, 0, 0);
#else
    ushort zero(0);

    // current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and
    // everything is passed to Voffset
    return src_thread_data_valid
               ? __llvm_amdgcn_raw_buffer_load_bf16(
                     src_wave_buffer_resource.data, src_thread_addr_offset, 0, 0)
               : zero;
#endif
}

template <>
__device__ ushort2_t amd_buffer_load<ushort, 2>(const ushort* p_src_wave,
    index_t src_thread_data_offset, bool src_thread_data_valid, index_t src_data_range)
{
    BufferResource<ushort> src_wave_buffer_resource;

    // wavewise base address (64 bit)
    src_wave_buffer_resource.address[0] = const_cast<ushort*>(p_src_wave);
    // wavewise range (32 bit)
    src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort);
    // wavewise setting (32 bit)
    src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;

    index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort);

#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
    uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;

    float dst_out_tmp = __llvm_amdgcn_buffer_load_f32(
        src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);

    return *reinterpret_cast<ushort2_t*>(&dst_out_tmp);
#else
    ushort2_t zeros(0);

    float dst_out_tmp = __llvm_amdgcn_buffer_load_f32(
        src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);

    return src_thread_data_valid ? *reinterpret_cast<ushort2_t*>(&dst_out_tmp) : zeros;
#endif
}

template <>
__device__ ushort4_t amd_buffer_load<ushort, 4>(const ushort* p_src_wave,
    index_t src_thread_data_offset, bool src_thread_data_valid, index_t src_data_range)
{
    BufferResource<ushort> src_wave_buffer_resource;

    // wavewise base address (64 bit)
    src_wave_buffer_resource.address[0] = const_cast<ushort*>(p_src_wave);
    // wavewise range (32 bit)
    src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort);
    // wavewise setting (32 bit)
    src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;

    index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort);

#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
    uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;

    float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2(
        src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);

    return *reinterpret_cast<ushort4_t*>(&dst_out_tmp);
#else
    ushort4_t zeros(0);

    float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2(
        src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);

    return src_thread_data_valid ? *reinterpret_cast<ushort4_t*>(&dst_out_tmp) : zeros;
#endif
}

template <>
__device__ ushort8_t amd_buffer_load<ushort, 8>(const ushort* p_src_wave,
    index_t src_thread_data_offset, bool src_thread_data_valid, index_t src_data_range)
{
    BufferResource<ushort> src_wave_buffer_resource;

    // wavewise base address (64 bit)
    src_wave_buffer_resource.address[0] = const_cast<ushort*>(p_src_wave);
    // wavewise range (32 bit)
    src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort);
    // wavewise setting (32 bit)
    src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;

    index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort);

#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
    uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;

    float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4(
        src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);

    return *reinterpret_cast<ushort8_t*>(&dst_out_tmp);
#else
    ushort8_t zeros(0);

    float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4(
        src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);

    return src_thread_data_valid ? *reinterpret_cast<ushort8_t*>(&dst_out_tmp) : zeros;
#endif
}

template <>
__device__ void amd_buffer_store<float, 1>(const float* p_src_thread, float* p_dst_wave,
    index_t dst_thread_data_offset, bool dst_thread_data_valid, index_t dst_data_range)
{
    BufferResource<float> dst_wave_buffer_resource;

    // wavewise base address (64 bit)
    dst_wave_buffer_resource.address[0] = p_dst_wave;
    // wavewise range (32 bit)
    dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float);
    // wavewise setting (32 bit)
    dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;

    index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);

#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;

    __llvm_amdgcn_buffer_store_f32(*p_src_thread,
        dst_wave_buffer_resource.data, 0, dst_addr_shift + dst_thread_addr_offset, false, false);
#else
    if(dst_thread_data_valid)
    {
        __llvm_amdgcn_buffer_store_f32(
            *p_src_thread, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false);
    }
#endif
}

template <>
__device__ void amd_buffer_store<float, 2>(const float* p_src_thread, float* p_dst_wave,
    index_t dst_thread_data_offset, bool dst_thread_data_valid, index_t dst_data_range)
{
    BufferResource<float> dst_wave_buffer_resource;

    // wavewise base address (64 bit)
    dst_wave_buffer_resource.address[0] = p_dst_wave;
    // wavewise range (32 bit)
    dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float);
    // wavewise setting (32 bit)
    dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;

    index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);

#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;

    __llvm_amdgcn_buffer_store_f32x2(*reinterpret_cast<const float2_t*>(p_src_thread),
        dst_wave_buffer_resource.data, 0, dst_addr_shift + dst_thread_addr_offset, false, false);
#else
    if(dst_thread_data_valid)
    {
        __llvm_amdgcn_buffer_store_f32x2(*reinterpret_cast<const float2_t*>(p_src_thread),
            dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false);
    }
#endif
}

template <>
__device__ void amd_buffer_store<float, 4>(const float* p_src_thread, float* p_dst_wave,
    index_t dst_thread_data_offset, bool dst_thread_data_valid, index_t dst_data_range)
{
    BufferResource<float> dst_wave_buffer_resource;

    // wavewise base address (64 bit)
    dst_wave_buffer_resource.address[0] = p_dst_wave;
    // wavewise range (32 bit)
    dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float);
    // wavewise setting (32 bit)
    dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;

    index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);

#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;

    __llvm_amdgcn_buffer_store_f32x4(*reinterpret_cast<const float4_t*>(p_src_thread),
        dst_wave_buffer_resource.data, 0, dst_addr_shift + dst_thread_addr_offset, false, false);
#else
    if(dst_thread_data_valid)
    {
        __llvm_amdgcn_buffer_store_f32x4(*reinterpret_cast<const float4_t*>(p_src_thread),
            dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false);
    }
#endif
}

template <>
__device__ void amd_buffer_store<half_t, 1>(const half_t* p_src_thread, half_t* p_dst_wave,
    index_t dst_thread_data_offset, bool dst_thread_data_valid, index_t dst_data_range)
{
    BufferResource<half_t> dst_wave_buffer_resource;

    // wavewise base address (64 bit)
    dst_wave_buffer_resource.address[0] = p_dst_wave;
    // wavewise range (32 bit)
    dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t);
    // wavewise setting (32 bit)
    dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;

    index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t);

#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;

    // current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and
    // everything is passed to Voffset
    __llvm_amdgcn_raw_buffer_store_f16(*p_src_thread,
        dst_wave_buffer_resource.data, dst_addr_shift + dst_thread_addr_offset, 0, 0);
#else
    if(dst_thread_data_valid)
    {
        // current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and
        // everything is passed to Voffset
        __llvm_amdgcn_raw_buffer_store_f16(
            *p_src_thread, dst_wave_buffer_resource.data, dst_thread_addr_offset, 0, 0);
    }
#endif
}

template <>
__device__ void amd_buffer_store<half_t, 2>(const half_t* p_src_thread, half_t* p_dst_wave,
    index_t dst_thread_data_offset, bool dst_thread_data_valid, index_t dst_data_range)
{
    BufferResource<half_t> dst_wave_buffer_resource;

    // wavewise base address (64 bit)
    dst_wave_buffer_resource.address[0] = p_dst_wave;
    // wavewise range (32 bit)
    dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t);
    // wavewise setting (32 bit)
    dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;

    index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t);

    const float* p_src_tmp = reinterpret_cast<const float*>(p_src_thread);

#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;

    __llvm_amdgcn_buffer_store_f32(*p_src_tmp,
        dst_wave_buffer_resource.data, 0, dst_addr_shift + dst_thread_addr_offset, false, false);
#else
    if(dst_thread_data_valid)
    {
        __llvm_amdgcn_buffer_store_f32(
            *p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false);
    }
#endif
}

template <>
__device__ void amd_buffer_store<half_t, 4>(const half_t* p_src_thread, half_t* p_dst_wave,
    index_t dst_thread_data_offset, bool dst_thread_data_valid, index_t dst_data_range)
{
    BufferResource<half_t> dst_wave_buffer_resource;

    // wavewise base address (64 bit)
    dst_wave_buffer_resource.address[0] = p_dst_wave;
    // wavewise range (32 bit)
    dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t);
    // wavewise setting (32 bit)
    dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;

    index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t);

    const float2_t* p_src_tmp = reinterpret_cast<const float2_t*>(p_src_thread);

#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;

    __llvm_amdgcn_buffer_store_f32x2(*p_src_tmp,
        dst_wave_buffer_resource.data, 0, dst_addr_shift + dst_thread_addr_offset, false, false);
#else
    if(dst_thread_data_valid)
    {
        __llvm_amdgcn_buffer_store_f32x2(
            *p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false);
    }
#endif
}

template <>
__device__ void amd_buffer_store<half_t, 8>(const half_t* p_src_thread, half_t* p_dst_wave,
    index_t dst_thread_data_offset, bool dst_thread_data_valid, index_t dst_data_range)
{
    BufferResource<half_t> dst_wave_buffer_resource;

    // wavewise base address (64 bit)
    dst_wave_buffer_resource.address[0] = p_dst_wave;
    // wavewise range (32 bit)
    dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t);
    // wavewise setting (32 bit)
    dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;

    index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t);

    const float4_t* p_src_tmp = reinterpret_cast<const float4_t*>(p_src_thread);

#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;

    __llvm_amdgcn_buffer_store_f32x4(*p_src_tmp,
        dst_wave_buffer_resource.data, 0, dst_addr_shift + dst_thread_addr_offset, false, false);
#else
    if(dst_thread_data_valid)
    {
        __llvm_amdgcn_buffer_store_f32x4(
            *p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false);
    }
#endif
}

template <>
__device__ void amd_buffer_store<ushort, 1>(const ushort* p_src_thread, ushort* p_dst_wave,
    index_t dst_thread_data_offset, bool dst_thread_data_valid, index_t dst_data_range)
{
    BufferResource<ushort> dst_wave_buffer_resource;

    // wavewise base address (64 bit)
    dst_wave_buffer_resource.address[0] = p_dst_wave;
    // wavewise range (32 bit)
    dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort);
    // wavewise setting (32 bit)
    dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;

    index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort);

#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;

    __llvm_amdgcn_raw_buffer_store_bf16(*p_src_thread,
        dst_wave_buffer_resource.data, dst_addr_shift + dst_thread_addr_offset, 0, 0);
#else
    if(dst_thread_data_valid)
    {
        __llvm_amdgcn_raw_buffer_store_bf16(
            *p_src_thread, dst_wave_buffer_resource.data, dst_thread_addr_offset, 0, 0);
    }
#endif
}

template <>
__device__ void amd_buffer_store<ushort, 2>(const ushort* p_src_thread, ushort* p_dst_wave,
    index_t dst_thread_data_offset, bool dst_thread_data_valid, index_t dst_data_range)
{
    BufferResource<ushort> dst_wave_buffer_resource;

    // wavewise base address (64 bit)
    dst_wave_buffer_resource.address[0] = p_dst_wave;
    // wavewise range (32 bit)
    dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort);
    // wavewise setting (32 bit)
    dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;

    index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort);

    const float* p_src_tmp = reinterpret_cast<const float*>(p_src_thread);

#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;

    __llvm_amdgcn_buffer_store_f32(*p_src_tmp,
        dst_wave_buffer_resource.data, 0, dst_addr_shift + dst_thread_addr_offset, false, false);
#else
    if(dst_thread_data_valid)
    {
        __llvm_amdgcn_buffer_store_f32(
            *p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false);
    }
#endif
}

template <>
__device__ void amd_buffer_store<ushort, 4>(const ushort* p_src_thread, ushort* p_dst_wave,
    index_t dst_thread_data_offset, bool dst_thread_data_valid, index_t dst_data_range)
{
    BufferResource<ushort> dst_wave_buffer_resource;

    // wavewise base address (64 bit)
    dst_wave_buffer_resource.address[0] = p_dst_wave;
    // wavewise range (32 bit)
    dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort);
    // wavewise setting (32 bit)
    dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;

    index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort);

    const float2_t* p_src_tmp = reinterpret_cast<const float2_t*>(p_src_thread);

#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;

    __llvm_amdgcn_buffer_store_f32x2(*p_src_tmp,
        dst_wave_buffer_resource.data, 0, dst_addr_shift + dst_thread_addr_offset, false, false);
#else
    if(dst_thread_data_valid)
    {
        __llvm_amdgcn_buffer_store_f32x2(
            *p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false);
    }
#endif
}

template <>
__device__ void amd_buffer_store<ushort, 8>(const ushort* p_src_thread, ushort* p_dst_wave,
    index_t dst_thread_data_offset, bool dst_thread_data_valid, index_t dst_data_range)
{
    BufferResource<ushort> dst_wave_buffer_resource;

    // wavewise base address (64 bit)
    dst_wave_buffer_resource.address[0] = p_dst_wave;
    // wavewise range (32 bit)
    dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort);
    // wavewise setting (32 bit)
    dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;

    index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort);

    const float4_t* p_src_tmp = reinterpret_cast<const float4_t*>(p_src_thread);

#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;

    __llvm_amdgcn_buffer_store_f32x4(*p_src_tmp,
        dst_wave_buffer_resource.data, 0, dst_addr_shift + dst_thread_addr_offset, false, false);
#else
    if(dst_thread_data_valid)
    {
        __llvm_amdgcn_buffer_store_f32x4(
            *p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false);
    }
#endif
}

#if CK_USE_AMD_BUFFER_ATOMIC_FADD
template <>
__device__ void amd_buffer_atomic_add<float, 1>(const float* p_src_thread, float* p_dst_wave,
    index_t dst_thread_data_offset, bool dst_thread_data_valid, index_t dst_data_range)
{
    BufferResource<float> dst_wave_buffer_resource;

    // wavewise base address (64 bit)
    dst_wave_buffer_resource.address[0] = p_dst_wave;
    // wavewise range (32 bit)
    dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float);
    // wavewise setting (32 bit)
    dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;

    index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);

#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;

    __llvm_amdgcn_buffer_atomic_add_f32(*p_src_thread,
        dst_wave_buffer_resource.data, 0, dst_addr_shift + dst_thread_addr_offset, false);
#else
    if(dst_thread_data_valid)
    {
        __llvm_amdgcn_buffer_atomic_add_f32(
            *p_src_thread, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false);
    }
#endif
}

template <>
__device__ void amd_buffer_atomic_add<float, 2>(const float* p_src_thread, float* p_dst_wave,
    index_t dst_thread_data_offset, bool dst_thread_data_valid, index_t dst_data_range)
{
    BufferResource<float> dst_wave_buffer_resource;

    // wavewise base address (64 bit)
    dst_wave_buffer_resource.address[0] = p_dst_wave;
    // wavewise range (32 bit)
    dst_wave_buffer_resource.range[2] = dst_data_range;
    // wavewise setting (32 bit)
    dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;

    index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);

#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;

    for(index_t i = 0; i < 2; ++i)
    {
        __llvm_amdgcn_buffer_atomic_add_f32(p_src_thread[i],
            dst_wave_buffer_resource.data,
            0,
            dst_addr_shift + dst_thread_addr_offset + i * sizeof(float),
            false);
    }
#else
    if(dst_thread_data_valid)
    {
        for(index_t i = 0; i < 2; ++i)
        {
            __llvm_amdgcn_buffer_atomic_add_f32(p_src_thread[i],
                dst_wave_buffer_resource.data,
                0,
                dst_thread_addr_offset + i * sizeof(float),
                false);
        }
    }
#endif
}

template <>
__device__ void amd_buffer_atomic_add<float, 4>(const float* p_src_thread, float* p_dst_wave,
    index_t dst_thread_data_offset, bool dst_thread_data_valid, index_t dst_data_range)
{
    BufferResource<float> dst_wave_buffer_resource;

    // wavewise base address (64 bit)
    dst_wave_buffer_resource.address[0] = p_dst_wave;
    // wavewise range (32 bit)
    dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float);
    // wavewise setting (32 bit)
    dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;

    index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);

#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;

    for(index_t i = 0; i < 4; ++i)
    {
        __llvm_amdgcn_buffer_atomic_add_f32(p_src_thread[i],
            dst_wave_buffer_resource.data,
            0,
            dst_addr_shift + dst_thread_addr_offset + i * sizeof(float),
            false);
    }
#else
    if(dst_thread_data_valid)
    {
        for(index_t i = 0; i < 4; ++i)
        {
            __llvm_amdgcn_buffer_atomic_add_f32(p_src_thread[i],
                dst_wave_buffer_resource.data,
                0,
                dst_thread_addr_offset + i * sizeof(float),
                false);
        }
    }
#endif
}
#endif // CK_USE_AMD_BUFFER_ATOMIC_FADD

} // namespace ck
#endif
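Functionally, every specialization above follows the same contract: if the thread's offset maps to valid data the buffer instruction reads or writes it, otherwise the load yields zero and the store is skipped; the 0x7fffffff "offset trick" just pushes invalid accesses past the declared buffer range so the hardware's range check drops them. A host-side sketch of the intended semantics, ignoring the GCN intrinsics (hypothetical names, plain arrays), is:

// Standalone illustration of the guarded-access semantics, not the GCN buffer intrinsics.
float guarded_load(const float* wave_base, int data_offset, bool valid, int data_range)
{
    // valid offsets read the element; invalid ones return 0,
    // matching the "read 0 on invalid mapping" contract of amd_buffer_load
    return (valid && data_offset < data_range) ? wave_base[data_offset] : 0.0f;
}

void guarded_store(float value, float* wave_base, int data_offset, bool valid, int data_range)
{
    // invalid offsets are dropped, matching "no write on invalid mapping"
    if(valid && data_offset < data_range)
        wave_base[data_offset] = value;
}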
composable_kernel/include/utility/amd_dlop.hpp
View file @
81c942cd
...
...
@@ -23,6 +23,48 @@ amd_inner_product_dlop<float, float, float>(const float& a, const float& b, floa
#endif
}
template <>
__device__ void amd_inner_product_dlop<float2_t, float2_t, float>(const float2_t& a,
                                                                  const float2_t& b,
                                                                  float& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    amd_inner_product_dlop(vector_type<float, 2>{a}.AsType<float>()[I0],
                           vector_type<float, 2>{b}.AsType<float>()[I0],
                           c);

    amd_inner_product_dlop(vector_type<float, 2>{a}.AsType<float>()[I1],
                           vector_type<float, 2>{b}.AsType<float>()[I1],
                           c);
}

template <>
__device__ void amd_inner_product_dlop<float4_t, float4_t, float>(const float4_t& a,
                                                                  const float4_t& b,
                                                                  float& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I0],
                           vector_type<float, 4>{b}.AsType<float>()[I0],
                           c);

    amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I1],
                           vector_type<float, 4>{b}.AsType<float>()[I1],
                           c);

    amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I2],
                           vector_type<float, 4>{b}.AsType<float>()[I2],
                           c);

    amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I3],
                           vector_type<float, 4>{b}.AsType<float>()[I3],
                           c);
}

#if CK_USE_AMD_DLOP
template <>
__device__ void
...
...
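The float2_t / float4_t specializations above simply unroll the vector inner product into per-lane calls of the scalar amd_inner_product_dlop, all accumulating into the same c. As a rough host-side equivalent of what the float4_t version computes (plain C++, not the device intrinsic):

#include <array>
#include <cstdio>

// c += a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3], one lane at a time,
// mirroring the four scalar calls in the specialization above.
void inner_product_accumulate(const std::array<float, 4>& a, const std::array<float, 4>& b, float& c)
{
    for(int i = 0; i < 4; ++i)
    {
        c += a[i] * b[i];
    }
}

int main()
{
    float c = 0.0f;
    inner_product_accumulate({1, 2, 3, 4}, {5, 6, 7, 8}, c);
    std::printf("c = %f\n", c); // 70.000000
    return 0;
}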
composable_kernel/include/utility/common_header.hpp
View file @
81c942cd
...
...
@@ -13,7 +13,6 @@
#include "functional2.hpp"
#include "functional3.hpp"
#include "functional4.hpp"
#include "in_memory_operation.hpp"
#include "integral_constant.hpp"
#include "math.hpp"
#include "number.hpp"
...
...
@@ -25,6 +24,7 @@
#include "type.hpp"
#include "utility.hpp"
#include "magic_division.hpp"
#include "amd_buffer_addressing_v2.hpp"
#include "static_buffer.hpp"
#include "dynamic_buffer.hpp"
...
...
composable_kernel/include/utility/config.nvidia.hpp.in
deleted
100644 → 0
View file @
b8b2d0a6
#ifndef CK_CONFIG_NVIDIA_HPP
#define CK_CONFIG_NVIDIA_HPP
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <nvToolsExt.h>
// index type: unsigned or signed
#define CK_UNSIGNED_INDEX_TYPE 0
// device backend
#define CK_DEVICE_BACKEND_NVIDIA 1
// disable AMD inline asm and intrinsic
#define CK_USE_AMD_INLINE_ASM 0
#define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 0
#define CK_USE_AMD_BUFFER_ADDRESSING 0
#define CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC 0
#define CK_USE_AMD_XDLOPS 0
#define CK_USE_AMD_XDLOPS_INLINE_ASM 0
#define CK_USE_AMD_XDLOPS_EMULATE 0
// experimental implementation
#define CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE 0
#define CK_EXPERIMENTAL_TENSOR_COORDINATE_USE_CALCULATE_OFFSET_DIFF 0
#define CK_EXPERIMENTAL_THREADWISE_COPY_V4R2_USE_OPTIMIZED_ADDRESS_CACLULATION 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
namespace ck {
enum AddressSpace
{
Generic,
Global,
Lds,
Vgpr
};
enum InMemoryDataOperation
{
Set,
AtomicAdd
};
#if CK_UNSIGNED_INDEX_TYPE
using index_t = uint32_t;
#else
using index_t = int32_t;
#endif
} // namespace ck
#endif
composable_kernel/include/utility/float_type.nvidia.hpp.in
deleted
100644 → 0
View file @
b8b2d0a6
#ifndef CK_FLOAT_TYPE_NVIDIA_HPP
#define CK_FLOAT_TYPE_NVIDIA_HPP
#include "number.hpp"
namespace ck {
// For some reason, CUDA needs this definition; otherwise the compiler won't generate
// optimal load and store instructions, and the kernel would produce wrong results,
// indicating the compiler failed to generate correct instructions.
// float
using float2_t = float2;
using float4_t = float4;
// float
typedef float float32_t __attribute__((ext_vector_type(32)));
// bfloat16
typedef ushort ushort2_t __attribute__((ext_vector_type(2)));
typedef ushort ushort4_t __attribute__((ext_vector_type(4)));
typedef ushort ushort8_t __attribute__((ext_vector_type(8)));
// fp16
using half_t = half;
using half2_t = half2;
using half4_t = float2;
template <class T, index_t N>
struct vector_type
{
typedef struct
{
T scalar[N];
} type;
};
template <>
struct vector_type<float, 1>
{
using type = float;
template <index_t I>
__host__ __device__ static void SetScalar(type& v, float s, Number<I>)
{
static_assert(I < 1, "wrong");
*(reinterpret_cast<float*>(&v) + I) = s;
}
};
template <>
struct vector_type<float, 2>
{
using type = float2_t;
union DataType
{
type vector;
float scalar[2];
};
template <index_t I>
__host__ __device__ static void SetScalar(type& v, float s, Number<I>)
{
static_assert(I < 2, "wrong");
*(reinterpret_cast<float*>(&v) + I) = s;
}
__host__ __device__ static type Pack(float s0, float s1)
{
DataType data;
data.scalar[0] = s0;
data.scalar[1] = s1;
return data.vector;
}
};
template <>
struct vector_type<float, 4>
{
using type = float4_t;
__host__ __device__ static constexpr index_t GetSize() { return 4; }
template <index_t I>
__host__ __device__ static void SetScalar(type& v, float s, Number<I>)
{
static_assert(I < 4, "wrong");
*(reinterpret_cast<float*>(&v) + I) = s;
}
};
template <>
struct vector_type<half_t, 1>
{
using type = half_t;
template <index_t I>
__host__ __device__ static void SetScalar(type& v, half_t s, Number<I>)
{
static_assert(I < 1, "wrong");
*(reinterpret_cast<half_t*>(&v) + I) = s;
}
};
template <>
struct vector_type<half_t, 2>
{
using type = half2_t;
union DataType
{
type vector;
half_t scalar[2];
};
template <index_t I>
__host__ __device__ static void SetScalar(type& v, half_t s, Number<I>)
{
static_assert(I < 2, "wrong");
*(reinterpret_cast<half_t*>(&v) + I) = s;
}
__host__ __device__ static type Pack(half_t s0, half_t s1)
{
DataType data;
data.scalar[0] = s0;
data.scalar[1] = s1;
return data.vector;
}
};
// data type conversion
template <typename T>
struct type_convert
{
template <typename X>
__device__ T operator()(const X& x) const
{
return static_cast<T>(x);
}
};
template <typename T>
struct inner_product_with_conversion
{
static constexpr auto convert = type_convert<T>();
__device__ T operator()(float a, float b) const { return convert(a) * convert(b); }
__device__ T operator()(half2_t a, half2_t b) const
{
const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);
T acc = 0;
for(index_t v = 0; v < 2; ++v)
{
acc += convert(p_a_half[v]) * convert(p_b_half[v]);
}
return acc;
}
__device__ T operator()(half4_t a, half4_t b) const
{
const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);
T acc = 0;
for(index_t v = 0; v < 4; ++v)
{
acc += convert(p_a_half[v]) * convert(p_b_half[v]);
}
return acc;
}
};
} // namespace ck
#endif
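The inner_product_with_conversion functor above converts every lane to the accumulator type T before multiplying and summing, which is what lets narrow inputs accumulate in a wider type without overflow. A free-standing sketch of the same convert-then-accumulate pattern (host-only C++; the dot_with_conversion name and the int8/int32 pairing are chosen only for illustration, mirroring the int8 path in the drivers):

#include <cstdint>
#include <cstdio>

template <typename T, typename X, int N>
T dot_with_conversion(const X (&a)[N], const X (&b)[N])
{
    T acc = 0;
    for(int v = 0; v < N; ++v)
    {
        // convert first, then multiply-accumulate in the wider type
        acc += static_cast<T>(a[v]) * static_cast<T>(b[v]);
    }
    return acc;
}

int main()
{
    int8_t a[4] = {100, 100, 100, 100};
    int8_t b[4] = {100, 100, 100, 100};

    // 4 * 100 * 100 = 40000 would overflow int8_t, but accumulates fine in int32_t.
    std::printf("%d\n", dot_with_conversion<int32_t>(a, b));
    return 0;
}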
composable_kernel/include/utility/in_memory_operation.amd.hpp.in
deleted
100644 → 0
View file @
b8b2d0a6
#ifndef CK_IN_MEMORY_OPERATION_AMD_HPP
#define CK_IN_MEMORY_OPERATION_AMD_HPP
#include "float_type.hpp"
#if CK_USE_AMD_BUFFER_ADDRESSING
#include "amd_buffer_addressing.hpp"
#include "amd_buffer_addressing_v2.hpp"
#endif
namespace ck {
template <typename T>
__device__ void atomic_add_impl(T* p_dst, T src)
{
atomicAdd(p_dst, src);
}
// atomicAdd for float does not support vector type
template <>
__device__ void atomic_add_impl<float2_t>(float2_t* p_dst, float2_t src)
{
float* p_dst_float = reinterpret_cast<float*>(p_dst);
const float* p_src_float = reinterpret_cast<const float*>(&src);
for(index_t i = 0; i < 2; ++i)
{
atomicAdd(&(p_dst_float[i]), p_src_float[i]);
}
}
template <>
__device__ void atomic_add_impl<float4_t>(float4_t* p_dst, float4_t src)
{
float* p_dst_float = reinterpret_cast<float*>(p_dst);
const float* p_src_float = reinterpret_cast<const float*>(&src);
for(index_t i = 0; i < 4; ++i)
{
atomicAdd(&(p_dst_float[i]), p_src_float[i]);
}
}
template <typename T, index_t DataPerAccess>
struct SetData
{
using vector_t = typename vector_type<T, DataPerAccess>::type;
// This version exists only for compatibility; avoid using it if possible
template <AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace>
__device__ void Run(const T* p_src,
index_t src_offset,
bool src_valid,
index_t /* src_range */,
T* p_dst,
index_t dst_offset,
bool dst_valid,
index_t /* dst_range */) const
{
if(dst_valid)
{
if(src_valid)
{
#if 0
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
*reinterpret_cast<const vector_t*>(&p_src[src_offset]);
#else
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
*reinterpret_cast<const vector_t*>(&p_src[0x3fffffff & src_offset]);
#endif
}
else
{
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) = 0;
}
}
}
#if CK_USE_AMD_BUFFER_ADDRESSING
// buffer_load requires:
//   1) p_src_thread must be in global memory space, p_dst_thread must be in vgpr
//   2) p_src_thread must be a wavewise pointer.
// It is the user's responsibility to make sure that is true.
template <>
__device__ void Run<AddressSpace::Global, AddressSpace::Vgpr>(const T* p_src,
index_t src_offset,
bool src_valid,
index_t src_range,
T* p_dst,
index_t dst_offset,
bool dst_valid,
index_t /* dst_range */) const
{
if(dst_valid)
{
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
amd_buffer_load_v2<T, DataPerAccess>(p_src, src_offset, src_valid, src_range);
}
}
// buffer_store requires:
//   1) p_src_thread must be in vgpr space, p_dst_thread must be in global memory
//   2) p_dst_thread must be a wavewise pointer.
// It is the user's responsibility to make sure that is true.
template <>
__device__ void Run<AddressSpace::Vgpr, AddressSpace::Global>(const T* p_src,
index_t src_offset,
bool src_valid,
index_t /* src_range */,
T* p_dst,
index_t dst_offset,
bool dst_valid,
index_t dst_range) const
{
const auto zeros = vector_t(0);
amd_buffer_store_v2<T, DataPerAccess>(
src_valid ? *reinterpret_cast<const vector_t*>(&(p_src[src_offset])) : zeros,
p_dst,
dst_offset,
dst_valid,
dst_range);
}
#endif
};
template <typename T, index_t DataPerAccess>
struct AtomicAddData
{
using vector_t = typename vector_type<T, DataPerAccess>::type;
// This version exists only for compatibility; avoid using it if possible
template <AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace>
__device__ void Run(const T* p_src,
index_t src_offset,
bool src_valid,
index_t /* src_range */,
T* p_dst,
index_t dst_offset,
bool dst_valid,
index_t /* dst_range */) const
{
if(src_valid && dst_valid)
{
atomic_add_impl(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
*reinterpret_cast<const vector_t*>(&p_src[src_offset]));
}
}
#if CK_USE_AMD_BUFFER_ADDRESSING && CK_USE_AMD_BUFFER_ATOMIC_FADD
// buffer_atomic requires:
//   1) p_src_thread must be in vgpr space, p_dst_thread must be in global memory
//   2) p_dst_thread must be a wavewise pointer.
// It is the user's responsibility to make sure that is true.
template <>
__device__ void Run<AddressSpace::Vgpr, AddressSpace::Global>(const T* p_src,
index_t src_offset,
bool src_valid,
index_t /* src_range */,
T* p_dst,
index_t dst_offset,
bool dst_valid,
index_t dst_range) const
{
const auto zeros = vector_t(0);
amd_buffer_atomic_add<T, DataPerAccess>(
src_valid ? &(p_src[src_offset]) : &zeros, p_dst, dst_offset, dst_valid, dst_range);
}
#endif
};
template <typename T,
index_t DataPerAccess,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace,
InMemoryDataOperation DstInMemOp,
index_t SrcDataStride = 1,
index_t DstDataStride = 1>
__device__ void transfer_data(const T* p_src,
index_t src_offset,
bool src_valid,
index_t src_range,
T* p_dst,
index_t dst_offset,
bool dst_valid,
index_t dst_range)
{
static_assert(DstInMemOp == InMemoryDataOperation::Set ||
DstInMemOp == InMemoryDataOperation::AtomicAdd,
"wrong! InMemoryDataOperation not supported!");
// keep it simple, don't use static_if here, otherwise compiler will do weird things
if constexpr(SrcDataStride == 1 && DstDataStride == 1)
{
if constexpr(DstInMemOp == InMemoryDataOperation::Set)
{
SetData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
p_src, src_offset, src_valid, src_range, p_dst, dst_offset, dst_valid, dst_range);
}
else if constexpr(DstInMemOp == InMemoryDataOperation::AtomicAdd)
{
AtomicAddData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
p_src, src_offset, src_valid, src_range, p_dst, dst_offset, dst_valid, dst_range);
}
}
else
{
#pragma unroll
for(index_t i = 0; i < DataPerAccess; ++i)
{
if constexpr(DstInMemOp == InMemoryDataOperation::Set)
{
SetData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
p_src,
src_offset + i * SrcDataStride,
src_valid,
src_range,
p_dst,
dst_offset + i * DstDataStride,
dst_valid,
dst_range);
}
else if constexpr(DstInMemOp == InMemoryDataOperation::AtomicAdd)
{
AtomicAddData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
p_src,
src_offset + i * SrcDataStride,
src_valid,
src_range,
p_dst,
dst_offset + i * DstDataStride,
dst_valid,
dst_range);
}
}
}
}
} // namespace ck
#endif
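transfer_data above is a compile-time dispatcher: the address spaces and the destination memory operation are template parameters, so the Set versus AtomicAdd choice costs no runtime branch and the right intrinsic specialization is picked per instantiation. A stripped-down, host-only sketch of that dispatch pattern (my simplification, not the library API):

#include <cstdio>

enum class MemOp { Set, AtomicAdd };

// Compile-time selection of the destination operation, mirroring the
// "if constexpr(DstInMemOp == ...)" dispatch in transfer_data above.
template <typename T, MemOp DstInMemOp>
void transfer_element(const T* p_src, int src_offset, T* p_dst, int dst_offset)
{
    if constexpr(DstInMemOp == MemOp::Set)
    {
        p_dst[dst_offset] = p_src[src_offset];
    }
    else
    {
        p_dst[dst_offset] += p_src[src_offset]; // the device code uses atomicAdd / buffer_atomic here
    }
}

int main()
{
    float src[2] = {1.0f, 2.0f};
    float dst[2] = {10.0f, 10.0f};

    transfer_element<float, MemOp::Set>(src, 0, dst, 0);       // dst[0] = 1
    transfer_element<float, MemOp::AtomicAdd>(src, 1, dst, 1); // dst[1] = 12

    std::printf("%f %f\n", dst[0], dst[1]);
    return 0;
}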
composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in
deleted
100644 → 0
View file @
b8b2d0a6
#ifndef CK_IN_MEMORY_OPERATION_NVIDIA_HPP
#define CK_IN_MEMORY_OPERATION_NVIDIA_HPP
namespace ck {
template <typename T>
__device__ void atomic_add_impl(T* p_dst, T src)
{
atomicAdd(p_dst, src);
}
// atomicAdd for float does not support vector type
template <>
__device__ void atomic_add_impl<float2_t>(float2_t* p_dst, float2_t src)
{
float* p_dst_float = reinterpret_cast<float*>(p_dst);
const float* p_src_float = reinterpret_cast<const float*>(&src);
for(index_t i = 0; i < 2; ++i)
{
atomicAdd(&(p_dst_float[i]), p_src_float[i]);
}
}
template <>
__device__ void atomic_add_impl<float4_t>(float4_t* p_dst, float4_t src)
{
float* p_dst_float = reinterpret_cast<float*>(p_dst);
const float* p_src_float = reinterpret_cast<const float*>(&src);
for(index_t i = 0; i < 4; ++i)
{
atomicAdd(&(p_dst_float[i]), p_src_float[i]);
}
}
template <typename T, index_t DataPerAccess>
struct SetData
{
using vector_t = typename vector_type<T, DataPerAccess>::type;
template <AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace>
__device__ void Run(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) const
{
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
*reinterpret_cast<const vector_t*>(&p_src[src_offset]);
}
};
template <typename T, index_t DataPerAccess>
struct AtomicAddData
{
using vector_t = typename vector_type<T, DataPerAccess>::type;
template <AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace>
__device__ void Run(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) const
{
atomic_add_impl(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
*reinterpret_cast<const vector_t*>(&p_src[src_offset]));
}
};
template <typename T,
index_t DataPerAccess,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace,
InMemoryDataOperation DstInMemOp,
index_t SrcDataStride = 1,
index_t DstDataStride = 1>
__device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
{
static_assert(DstInMemOp == InMemoryDataOperation::Set ||
DstInMemOp == InMemoryDataOperation::AtomicAdd,
"wrong! InMemoryDataOperation not supported!");
// keep it simple, don't use static_if here, otherwise compiler will do weird things
if(SrcDataStride == 1 && DstDataStride == 1)
{
// TODO: use static_if::ElseIf
static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
SetData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
p_src, src_offset, p_dst, dst_offset);
});
static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
AtomicAddData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
p_src, src_offset, p_dst, dst_offset);
});
}
else
{
for(index_t i = 0; i < DataPerAccess; i++)
{
// TODO: use static_if::ElseIf
static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
SetData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
p_src, src_offset + i * SrcDataStride, p_dst, dst_offset + i * DstDataStride);
});
static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
AtomicAddData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
p_src, src_offset + i * SrcDataStride, p_dst, dst_offset + i * DstDataStride);
});
}
}
}
} // namespace ck
#endif
composable_kernel/include/utility/synchronization.nvidia.hpp.in
deleted
100644 → 0
View file @
b8b2d0a6
#ifndef CK_SYNCHRONIZATION_NVIDIA_HPP
#define CK_SYNCHRONIZATION_NVIDIA_HPP
#include "config.hpp"
namespace ck {
__device__ void block_sync_lds() { __syncthreads(); }
__device__ void block_sync_lds_vmem() { __syncthreads(); }
} // namespace ck
#endif
composable_kernel/src/kernel_wrapper/gridwise_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.cpp
deleted
100644 → 0
View file @
b8b2d0a6
extern "C" __global__ void
gridwise_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer(
    const void* const __restrict__ p_in_global,
    const void* const __restrict__ p_wei_global,
    void* const __restrict__ p_out_global)
{
};
composable_kernel/src/kernel_wrapper/gridwise_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.cpp
deleted
100644 → 0
View file @
b8b2d0a6
extern "C" __global__ void
gridwise_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
    const void* const __restrict__ p_in_global,
    const void* const __restrict__ p_wei_global,
    void* const __restrict__ p_out_global)
{
};
composable_kernel/src/kernel_wrapper/gridwise_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.cpp
deleted
100644 → 0
View file @
b8b2d0a6
extern "C" __global__ void
gridwise_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
    const void* const __restrict__ p_in_global,
    const void* const __restrict__ p_wei_global,
    void* const __restrict__ p_out_global)
{
};
driver/conv_bwd_data_driver.cpp
deleted
100644 → 0
View file @
b8b2d0a6
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "conv_common.hpp"
#include "host_conv_bwd_data.hpp"
#include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp"
int main(int argc, char* argv[])
{
    using namespace launcher;

#if 1
    // 1x1 filter, 14x14 image
    constexpr index_t N = 1, C = 256, HI = 1, WI = 128, K = 16, Y = 1, X = 1;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using LeftPads    = Sequence<0, 0>; using RightPads     = Sequence<0, 0>;
#elif 0
    constexpr index_t N = 64, C = 256, HI = 56, WI = 56, K = 256, Y = 1, X = 1;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using LeftPads    = Sequence<0, 0>; using RightPads     = Sequence<0, 0>;
#elif 0
    // 3x3, 34x34
    constexpr index_t N = 64, C = 256, HI = 34, WI = 34, K = 256, Y = 3, X = 3;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using LeftPads    = Sequence<0, 0>; using RightPads     = Sequence<0, 0>;
#elif 0
    // 3x3, 28x28
    constexpr index_t N = 128, C = 128, HI = 28, WI = 28, K = 128, Y = 3, X = 3;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using LeftPads    = Sequence<1, 1>; using RightPads     = Sequence<1, 1>;
#elif 0
    // 1x1 filter, 8x8 image
    constexpr index_t N = 256, C = 1024, HI = 8, WI = 8, K = 1024, Y = 1, X = 1;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using LeftPads    = Sequence<0, 0>; using RightPads     = Sequence<0, 0>;
#elif 0
    // 1x1 filter, 7x7 image
    constexpr index_t N = 128, C = 1024, HI = 7, WI = 7, K = 1024, Y = 1, X = 1;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using LeftPads    = Sequence<0, 0>; using RightPads     = Sequence<0, 0>;
#elif 1
    // 1x1 filter, 14x14 image
    constexpr index_t N = 128, C = 512, HI = 14, WI = 14, K = 128, Y = 1, X = 1;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using LeftPads    = Sequence<0, 0>; using RightPads     = Sequence<0, 0>;
#elif 0
    // 1x1 filter, 28x28 image
    constexpr index_t N = 128, C = 128, HI = 28, WI = 28, K = 128, Y = 1, X = 1;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using LeftPads    = Sequence<0, 0>; using RightPads     = Sequence<0, 0>;
#elif 0
    // 1x1 filter, 17x17 input
    constexpr index_t N = 128, C = 1024, HI = 17, WI = 17, K = 1024, Y = 1, X = 1;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using LeftPads    = Sequence<0, 0>; using RightPads     = Sequence<0, 0>;
#elif 0
    // 5x5 filter, 2x2 pad, 7x7 input
    constexpr index_t N = 128, C = 1024, HI = 7, WI = 7, K = 1024, Y = 5, X = 5;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using LeftPads    = Sequence<2, 2>; using RightPads     = Sequence<2, 2>;
#elif 0
    // 1x7 filter, 0x3 pad, 17x17 input
    constexpr index_t N = 128, C = 128, HI = 17, WI = 17, K = 128, Y = 1, X = 7;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using LeftPads    = Sequence<0, 3>; using RightPads     = Sequence<0, 3>;
#elif 0
    // 7x1 filter, 3x0 pad, 17x17 input
    constexpr index_t N = 128, C = 256, HI = 17, WI = 17, K = 1024, Y = 7, X = 1;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using LeftPads    = Sequence<3, 0>; using RightPads     = Sequence<3, 0>;
#elif 1
    // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
    constexpr index_t N = 128, C = 256, HI = 35, WI = 35, K = 1280, Y = 3, X = 3;
    using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<2, 2>;
    using LeftPads    = Sequence<0, 0>; using RightPads     = Sequence<0, 0>;
#endif

    constexpr auto in_nchw_desc  = make_native_tensor_descriptor_packed(Sequence<N, C, HI, WI>{});
    constexpr auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence<K, C, Y, X>{});
    constexpr auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor(
        in_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, LeftPads{}, RightPads{});

    ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
    ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
    ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
    print_array("LeftPads", LeftPads{});
    print_array("LeftPads", LeftPads{});
    print_array("RightPads", RightPads{});
    print_array("ConvStrides", ConvStrides{});
    print_array("ConvDilations", ConvDilations{});

    Tensor<float> in_nchw_device(make_HostTensorDescriptor(in_nchw_desc));
    Tensor<float> in_nchw_host(make_HostTensorDescriptor(in_nchw_desc));
    Tensor<float> wei_kcyx(make_HostTensorDescriptor(wei_kcyx_desc));
    Tensor<float> out_nkhw(make_HostTensorDescriptor(out_nkhw_desc));

    std::size_t num_thread = std::thread::hardware_concurrency();

    if(argc != 3)
    {
        printf("arg1: do_verification, arg2: nrepeat\n");
        exit(1);
    }

    bool do_verification = atoi(argv[1]);
    std::size_t nrepeat  = atoi(argv[2]);

    if(do_verification)
    {
#if 0
        wei_kcyx.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
        out_nkhw.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
#else
        wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
        out_nkhw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#endif
    }

#if 0
    device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
#elif 0
    device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
#elif 0
    device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
#elif 1
    device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk
#endif
    (in_nchw_desc, in_nchw_device, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw,
     ConvStrides{}, ConvDilations{}, LeftPads{}, RightPads{}, nrepeat);

    if(do_verification)
    {
        host_direct_convolution_backward_data(in_nchw_host, wei_kcyx, out_nkhw,
                                              ConvStrides{}, ConvDilations{}, LeftPads{}, RightPads{});

        check_error(in_nchw_host, in_nchw_device);

#if 0
        LogRange(std::cout << "out_nkhw : ", out_nkhw.mData, ",") << std::endl;
        LogRange(std::cout << "wei_kcyx : ", wei_kcyx.mData, ",") << std::endl;
        LogRange(std::cout << "in_nchw_host : ", in_nchw_host.mData, ",") << std::endl;
        LogRange(std::cout << "in_nchw_device : ", in_nchw_device.mData, ",") << std::endl;
#endif
    }
}
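For reference, this deleted backward-data driver took two positional arguments; a typical invocation (the binary name depends on the build and is shown only as an illustration) would have been something like "conv_bwd_data_driver 1 5", i.e. verify against the host reference and time 5 repeats.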
driver/conv_driver.cpp
deleted
100644 → 0
View file @
b8b2d0a6
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor_generator.hpp"
#include "conv_common.hpp"
#include "host_conv.hpp"
#include "device_tensor.hpp"
#include "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
int main(int argc, char* argv[])
{
    using namespace ck;

    if(argc != 5)
    {
        printf("arg1: do_verification, arg2: do_log, arg3: init_method, arg4: nrepeat\n");
        exit(1);
    }

    const bool do_verification = atoi(argv[1]);
    const bool do_log          = atoi(argv[2]);
    const int init_method      = atoi(argv[3]);
    const int nrepeat          = atoi(argv[4]);

#if 0
    constexpr index_t N = 256, C = 256, HI = 16, WI = 16, K = 256, Y = 1, X = 1;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<0, 0>; using InRightPads   = Sequence<0, 0>;
#elif 0
    constexpr index_t N = 1, C = 16, HI = 1080, WI = 1920, K = 16, Y = 1, X = 1;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<1, 1>; using InRightPads   = Sequence<1, 1>;
#elif 0
    constexpr index_t N = 1, C = 16, Hi = 540, Wi = 960, K = 16, Y = 1, X = 1;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<0, 0>; using InRightPads   = Sequence<0, 0>;
#elif 0
    constexpr index_t N = 1, C = 16, Hi = 270, Wi = 480, K = 16, Y = 1, X = 1;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<0, 0>; using InRightPads   = Sequence<0, 0>;
#elif 0
    constexpr index_t N = 1, C = 16, Hi = 1080, Wi = 1920, K = 16, Y = 3, X = 3;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<1, 1>; using InRightPads   = Sequence<1, 1>;
#elif 0
    constexpr index_t N = 1, C = 1, Hi = 1024, Wi = 2048, K = 4, Y = 3, X = 3;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<1, 1>; using InRightPads   = Sequence<1, 1>;
#elif 0
    constexpr index_t N = 1, C = 16, Hi = 540, Wi = 960, K = 16, Y = 3, X = 3;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<1, 1>; using InRightPads   = Sequence<1, 1>;
#elif 0
    constexpr index_t N = 1, C = 16, Hi = 270, Wi = 480, K = 16, Y = 3, X = 3;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<1, 1>; using InRightPads   = Sequence<1, 1>;
#elif 0
    // 3x3, 36x36, stride 2
    constexpr index_t N = 128, C = 192, Hi = 37, Wi = 37, K = 384, Y = 3, X = 3;
    using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<0, 0>; using InRightPads   = Sequence<0, 0>;
#elif 0
    // 3x3, 35x35, stride 2
    constexpr index_t N = 128, C = 192, Hi = 35, Wi = 35, K = 384, Y = 3, X = 3;
    using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<0, 0>; using InRightPads   = Sequence<0, 0>;
#elif 0
    // 3x3, 71x71
    constexpr index_t N = 128, C = 192, HI = 71, WI = 71, K = 256, Y = 3, X = 3;
    using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<1, 1>; using InRightPads   = Sequence<1, 1>;
#elif 0
    // 1x1, 8x8
    constexpr index_t N = 128, C = 1536, Hi = 8, Wi = 8, K = 256, Y = 1, X = 1;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<0, 0>; using InRightPads   = Sequence<0, 0>;
#elif 0
    // 1x1, 73x73
    constexpr index_t N = 128, C = 160, Hi = 73, Wi = 73, K = 64, Y = 1, X = 1;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<0, 0>; using InRightPads   = Sequence<0, 0>;
#elif 0
    // 3x3, 35x35
    constexpr index_t N = 128, C = 96, Hi = 35, Wi = 35, K = 128, Y = 3, X = 3;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<1, 1>; using InRightPads   = Sequence<1, 1>;
#elif 0
    // 3x3, 71x71
    constexpr index_t N = 128, C = 192, Hi = 71, Wi = 71, K = 192, Y = 3, X = 3;
    using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<1, 1>; using InRightPads   = Sequence<1, 1>;
#elif 0
    // 7x1, 17x17
    constexpr index_t N = 128, C = 128, Hi = 17, Wi = 17, K = 128, Y = 7, X = 1;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<3, 0>; using InRightPads   = Sequence<3, 0>;
#elif 1
    // 1x7, 17x17
    constexpr index_t N = 128, C = 128, Hi = 17, Wi = 17, K = 128, Y = 1, X = 7;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<0, 3>; using InRightPads   = Sequence<0, 3>;
#elif 0
    // 3x3, 299x299 stride=2
    constexpr index_t N = 128, C = 3, Hi = 299, Wi = 299, K = 32, Y = 3, X = 3;
    using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<0, 0>; using InRightPads   = Sequence<0, 0>;
#elif 0
    // 3x3, 147x147
    constexpr index_t N = 128, C = 128, Hi = 147, Wi = 147, K = 128, Y = 3, X = 3;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<1, 1>; using InRightPads   = Sequence<1, 1>;
#elif 0
    // 3x3, 149x149
    constexpr index_t N = 128, C = 32, Hi = 149, Wi = 149, K = 32, Y = 3, X = 3;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<0, 0>; using InRightPads   = Sequence<0, 0>;
#elif 0
    // 3x3, 17x17, stride 2
    constexpr index_t N = 128, C = 192, Hi = 17, Wi = 17, K = 192, Y = 3, X = 3;
    using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<0, 0>; using InRightPads   = Sequence<0, 0>;
#elif 0
    // 1x1, 35x35
    constexpr index_t N = 128, C = 384, Hi = 35, Wi = 35, K = 96, Y = 1, X = 1;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<0, 0>; using InRightPads   = Sequence<0, 0>;
#elif 0
    // 3x3, 35x35, stride 2
    constexpr index_t N = 128, C = 288, Hi = 35, Wi = 35, K = 384, Y = 3, X = 3;
    using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<0, 0>; using InRightPads   = Sequence<0, 0>;
#elif 0
    // 1x3, 8x8
    constexpr index_t N = 128, C = 384, Hi = 8, Wi = 8, K = 448, Y = 1, X = 3;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<0, 1>; using InRightPads   = Sequence<0, 1>;
#elif 0
    // 3x1, 8x8
    constexpr index_t N = 128, C = 448, Hi = 8, Wi = 8, K = 512, Y = 3, X = 1;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<1, 0>; using InRightPads   = Sequence<1, 0>;
#elif 0
    // 3x3, 147x147
    constexpr index_t N = 128, C = 64, Hi = 147, Wi = 147, K = 96, Y = 3, X = 3;
    using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<0, 0>; using InRightPads   = Sequence<0, 0>;
#elif 0
    // 7x1, 73x73
    constexpr index_t N = 128, C = 64, Hi = 73, Wi = 73, K = 64, Y = 7, X = 1;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<3, 0>; using InRightPads   = Sequence<3, 0>;
#elif 0
    // 3x3, 73x73
    constexpr index_t N = 128, C = 64, Hi = 73, Wi = 73, K = 96, Y = 3, X = 3;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<0, 0>; using InRightPads   = Sequence<0, 0>;
#elif 0
    // 1x1, 14x14, stride 2
    constexpr index_t N = 256, C = 1024, Hi = 14, Wi = 14, K = 2048, Y = 1, X = 1;
    using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<0, 0>; using InRightPads   = Sequence<0, 0>;
#elif 0
    // 1x1, 14x14
    constexpr index_t N = 256, C = 1024, Hi = 14, Wi = 14, K = 256, Y = 1, X = 1;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<0, 0>; using InRightPads   = Sequence<0, 0>;
#elif 0
    // 1x1, 14x14, stride 2
    constexpr index_t N = 128, C = 1024, Hi = 14, Wi = 14, K = 512, Y = 1, X = 1;
    using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<0, 0>; using InRightPads   = Sequence<0, 0>;
#elif 1
    // 3x3, 28x28
    constexpr index_t N = 128, C = 128, Hi = 28, Wi = 28, K = 128, Y = 3, X = 3;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<1, 1>; using InRightPads   = Sequence<1, 1>;
#elif 1
    // 3x3, 14x14
    constexpr index_t N = 128, C = 256, Hi = 14, Wi = 14, K = 256, Y = 3, X = 3;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<1, 1>; using InRightPads   = Sequence<1, 1>;
#elif 0
    // 1x1, 56x56, stride 2
    constexpr index_t N = 128, C = 256, Hi = 56, Wi = 56, K = 128, Y = 1, X = 1;
    using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<0, 0>; using InRightPads   = Sequence<0, 0>;
#elif 0
    // 7x7, 230x230 stride=2
    constexpr index_t N = 128, C = 3, Hi = 230, Wi = 230, K = 64, Y = 7, X = 7;
    using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<0, 0>; using InRightPads   = Sequence<0, 0>;
#elif 0
    // 1x1, 28x28, stride = 2
    constexpr index_t N = 128, C = 512, Hi = 28, Wi = 28, K = 1024, Y = 1, X = 1;
    using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<0, 0>; using InRightPads   = Sequence<0, 0>;
#elif 0
    // 1x1, 28x28, stride 2
    constexpr index_t N = 128, C = 512, Hi = 28, Wi = 28, K = 256, Y = 1, X = 1;
    using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<0, 0>; using InRightPads   = Sequence<0, 0>;
#elif 1
    // 1x1, 7x7
    constexpr index_t N = 128, C = 512, Hi = 7, Wi = 7, K = 2048, Y = 1, X = 1;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<0, 0>; using InRightPads   = Sequence<0, 0>;
#elif 0
    // 3x3, 7x7
    constexpr index_t N = 128, C = 512, Hi = 7, Wi = 7, K = 512, Y = 3, X = 3;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<1, 1>; using InRightPads   = Sequence<1, 1>;
#elif 0
    // 1x1, 56x56
    constexpr index_t N = 128, C = 64, Hi = 56, Wi = 56, K = 64, Y = 1, X = 1;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<0, 0>; using InRightPads   = Sequence<0, 0>;
#elif 0
    // 3x3, 56x56
    constexpr index_t N = 128, C = 64, Hi = 56, Wi = 56, K = 64, Y = 3, X = 3;
    using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
    using InLeftPads  = Sequence<1, 1>; using InRightPads   = Sequence<1, 1>;
#endif

    constexpr index_t YEff = (Y - 1) * ConvDilations{}[0] + 1;
    constexpr index_t XEff = (X - 1) * ConvDilations{}[1] + 1;

    constexpr index_t Ho = (Hi + InLeftPads{}[0] + InRightPads{}[0] - YEff) / ConvStrides{}[0] + 1;
    constexpr index_t Wo = (Wi + InLeftPads{}[1] + InRightPads{}[1] - XEff) / ConvStrides{}[1] + 1;

#if 1
    constexpr index_t in_vector_size = 1;
    using in_data_t  = typename vector_type<float, in_vector_size>::type;
    using acc_data_t = float;
    using out_data_t = float;
#elif 1
    using in_data_t = half_t;
    constexpr index_t in_vector_size = 1;
    using acc_data_t = float;
    using out_data_t = half_t;
#elif 0
    constexpr index_t in_vector_size = 1;
    using in_data_t  = typename vector_type<float, in_vector_size>::type;
    using acc_data_t = float;
    using out_data_t = int8_t;
#elif 1
    constexpr index_t in_vector_size = 16;
    using in_data_t  = typename vector_type<int8_t, in_vector_size>::type;
    using acc_data_t = int32_t;
    using out_data_t = int8_t;
#endif

    Tensor<in_data_t> in_nchw(HostTensorDescriptor(std::initializer_list<index_t>{N, C, Hi, Wi}));
    Tensor<in_data_t> wei_kcyx(HostTensorDescriptor(std::initializer_list<index_t>{K, C, Y, X}));
    Tensor<out_data_t> out_nkhw_host(HostTensorDescriptor(std::initializer_list<index_t>{N, K, Ho, Wo}));
    Tensor<out_data_t> out_nkhw_device(HostTensorDescriptor(std::initializer_list<index_t>{N, K, Ho, Wo}));

    ostream_HostTensorDescriptor(in_nchw.mDesc, std::cout << "in_nchw_desc: ");
    ostream_HostTensorDescriptor(wei_kcyx.mDesc, std::cout << "wei_kcyx_desc: ");
    ostream_HostTensorDescriptor(out_nkhw_host.mDesc, std::cout << "out_nkhw_desc: ");
    print_array("InLeftPads", InLeftPads{});
    print_array("InRightPads", InRightPads{});
    print_array("ConvStrides", ConvStrides{});
    print_array("ConvDilations", ConvDilations{});

    std::size_t num_thread = std::thread::hardware_concurrency();

    if(do_verification)
    {
        switch(init_method)
        {
        case 0:
            in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
            wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
            break;
        case 1:
            in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
            wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
            break;
        case 2:
            in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
            wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
            break;
        case 3:
            in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
            wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
            break;
        default:
            in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);

            auto gen_wei = [](auto... is) {
                return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
            };
            wei_kcyx.GenerateTensorValue(gen_wei, num_thread);
        }
    }

    constexpr auto in_nchw_desc  = make_native_tensor_descriptor_packed(Sequence<N, C, Hi, Wi>{});
    constexpr auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence<K, C, Y, X>{});
    constexpr auto out_nkhw_desc = make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{});

#if 1
    device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw(
        in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device,
        ConvStrides{}, ConvDilations{}, InLeftPads{}, InRightPads{}, nrepeat);
#elif 0
    device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
        in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device,
        ConvStrides{}, ConvDilations{}, InLeftPads{}, InRightPads{}, nrepeat);
#elif 0
    device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
        in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device,
        ConvStrides{}, ConvDilations{}, InLeftPads{}, InRightPads{}, nrepeat);
#endif

    if(do_verification)
    {
        host_direct_convolution(in_nchw, wei_kcyx, out_nkhw_host,
                                ConvStrides{}, ConvDilations{}, InLeftPads{}, InRightPads{});

        check_error(out_nkhw_host, out_nkhw_device);

        if(do_log)
        {
            LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
            LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
            LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
            LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
        }
    }
}
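Likewise, the deleted forward driver expected four positional arguments; an illustrative invocation (the binary name again depends on the build) would be "conv_driver 1 0 3 5": verify the result, skip logging, use init_method 3 (random values in [-5, 5] for both input and weights), and time 5 repeats.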
driver/conv_driver_v2.cpp
View file @
81c942cd
...
...
@@ -24,16 +24,16 @@
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 1
#define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4_NHWC 0
#define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V4R5_NCHW 0
#define USE_CONV_FWD_V4R4_NCHW 1
#define USE_CONV_FWD_V4R4_NHWC 1
#define USE_CONV_FWD_V4R4R2_NHWC 1
#define USE_CONV_FWD_V4R5_NCHW 1
#define USE_CONV_FWD_V4R5R2_NCHW 1
#define USE_CONV_FWD_V5R1_NCHW 0
#define USE_CONV_FWD_V4R4_XDL_NCHW 0
#define USE_CONV_FWD_V4R4R2_XDL_NHWC 0
#define USE_CONV_FWD_V4R4R3_XDL_NHWC 0
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0
#define USE_CONV_FWD_V4R4_XDL_NCHW 1
#define USE_CONV_FWD_V4R4R2_XDL_NHWC 1
#define USE_CONV_FWD_V4R4R3_XDL_NHWC 1
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1
enum ConvForwardAlgo
{
...
...
driver/include/conv_common.hpp
View file @
81c942cd
#ifndef CONV_COMMON_HPP
#define CONV_COMMON_HPP
#include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor.hpp"
enum ConvTensorLayout
...
...
@@ -13,53 +12,6 @@ enum ConvTensorLayout
NHWCc
};
template <class InDesc,
          class WeiDesc,
          class ConvStrides,
          class ConvDilations,
          class LeftPads,
          class RightPads>
constexpr auto get_convolution_output_default_4d_tensor_descriptor(
    InDesc, WeiDesc, ConvStrides, ConvDilations, LeftPads, RightPads)
{
    using namespace ck;

    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    static_assert(in_desc.GetNumOfDimension() == 4, "input nDim is not 4");
    static_assert(wei_desc.GetNumOfDimension() == 4, "weight nDim is not 4");
    static_assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1),
                  "input & weight dimension not consistent");

    constexpr index_t N  = in_desc.GetLength(I0);
    constexpr index_t Hi = in_desc.GetLength(I2);
    constexpr index_t Wi = in_desc.GetLength(I3);

    constexpr index_t K = wei_desc.GetLength(I0);
    constexpr index_t Y = wei_desc.GetLength(I2);
    constexpr index_t X = wei_desc.GetLength(I3);

    constexpr index_t LeftPadH = LeftPads{}.Get(I0);
    constexpr index_t LeftPadW = LeftPads{}.Get(I1);

    constexpr index_t RightPadH = RightPads{}.Get(I0);
    constexpr index_t RightPadW = RightPads{}.Get(I1);

    constexpr index_t YEff = (Y - 1) * ConvDilations{}[0] + 1;
    constexpr index_t XEff = (X - 1) * ConvDilations{}[1] + 1;

    constexpr index_t Ho = (Hi + LeftPadH + RightPadH - YEff) / ConvStrides{}[0] + 1;
    constexpr index_t Wo = (Wi + LeftPadW + RightPadW - XEff) / ConvStrides{}[1] + 1;

    return make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{});
}
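The output spatial size above follows the usual convolution formula: with the effective filter span YEff = (Y - 1) * DilationH + 1, the height is Ho = (Hi + LeftPadH + RightPadH - YEff) / StrideH + 1, and likewise for the width. A standalone compile-time check of that arithmetic (illustrative only, using configurations that appear in the drivers above):

#include <cstdio>

constexpr int conv_out_size(int in, int left_pad, int right_pad, int filter, int dilation, int stride)
{
    const int eff = (filter - 1) * dilation + 1; // effective filter span
    return (in + left_pad + right_pad - eff) / stride + 1;
}

// 3x3 filter, pad 1, stride 1 preserves a 28x28 image.
static_assert(conv_out_size(28, 1, 1, 3, 1, 1) == 28, "28x28 case");
// 3x3 filter, no pad, stride 2 maps a 35x35 image to 17x17.
static_assert(conv_out_size(35, 0, 0, 3, 1, 2) == 17, "35x35 stride-2 case");

int main()
{
    std::printf("%d %d\n", conv_out_size(28, 1, 1, 3, 1, 1), conv_out_size(35, 0, 0, 3, 1, 2));
    return 0;
}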
template <typename... InDesc, typename... WeiDesc, typename ConvStrides,
...
...
@@ -131,30 +83,4 @@ calculate_convolution_flops(const InDesc& in_desc, const WeiDesc& wei_desc, cons
    return std::size_t(2) * N * K * Ho * Wo * C * Y * X;
}

template <class Float, class InDesc, class WeiDesc, class OutDesc>
constexpr std::size_t calculate_convolution_memory_size(Float, InDesc, WeiDesc, OutDesc)
{
    using namespace ck;

    constexpr auto wei_desc = WeiDesc{};
    constexpr auto out_desc = OutDesc{};

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr index_t N  = out_desc.GetLength(I0);
    constexpr index_t K  = out_desc.GetLength(I1);
    constexpr index_t Ho = out_desc.GetLength(I2);
    constexpr index_t Wo = out_desc.GetLength(I3);

    constexpr index_t C = wei_desc.GetLength(I1);
    constexpr index_t Y = wei_desc.GetLength(I2);
    constexpr index_t X = wei_desc.GetLength(I3);

    return sizeof(Float) *
           (InDesc::GetElementSpace() + WeiDesc::GetElementSpace() + OutDesc::GetElementSpace());
}
#endif
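calculate_convolution_flops counts one multiply and one add per MAC, i.e. 2 * N * K * Ho * Wo * C * Y * X, while calculate_convolution_memory_size simply sums the element spaces of the input, weight, and output tensors times sizeof(Float). A quick back-of-the-envelope check of the FLOP count for the 3x3, 28x28, N = 128, C = K = 128 configuration used in the drivers (standalone, illustrative only):

#include <cstddef>
#include <cstdio>

int main()
{
    const std::size_t N = 128, C = 128, K = 128, Ho = 28, Wo = 28, Y = 3, X = 3;
    const std::size_t flops = std::size_t(2) * N * K * Ho * Wo * C * Y * X;

    // roughly 29.6 GFLOP per forward pass for this layer
    std::printf("%zu FLOPs (~%.1f GFLOP)\n", flops, flops / 1.0e9);
    return 0;
}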