gaoqiong / composable_kernel · Commits

Commit afc7d431
Authored Apr 24, 2022 by carlushuang

    avx2 gemm now works for single thread

Parent: 07af8343

Showing 13 changed files with 2308 additions and 962 deletions (+2308, -962)
include/ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp                                        +160   -85
include/ck/tensor_operation/cpu/device/convolution_forward_specialization_cpu.hpp                    +13    -0
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp                     +174   -36
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp                                          +248   -424
include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp                                      +108   -63
include/ck/tensor_operation/cpu/thread/threadwise_gemm_param.hpp                                     +3     -3
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2.hpp                     +109   -0
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp      +1084  -0
library/include/ck/library/host_tensor/device.hpp                                                    +5     -1
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp  +30  -46
test/convnd_fwd_cpu/conv2d_fwd_cpu.cpp                                                               +363   -296
test/cpu_threadwise_transfer/cpu_threadwise_transfer.cpp                                             +2     -0
test/cpu_ukernel/cpu_gemm_uk.cpp                                                                     +9     -8
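The files below wire a cache-blocked GEMM driver around AVX2 micro-kernels (6x16 and 4x24 register tiles). As a minimal stand-alone sketch of that structure — hypothetical plain-C++ names, not the ck:: templates in this commit — one thread walks M/N/K blocks and decides per K block whether to initialize or accumulate C, which is the same role the new is_accumulate_c / accmulate_c flag plays below:

#include <algorithm>

// Row-major A (MxK), row-major B (KxN); C is initialized on the first K block
// and accumulated on the later ones, mirroring the "i_kc != 0" logic in the diff.
void blocked_gemm(const float* A, const float* B, float* C,
                  int M, int N, int K,
                  int MPerBlock, int NPerBlock, int KPerBlock)
{
    for(int mc = 0; mc < M; mc += MPerBlock)      // block loop order "MNK"
        for(int nc = 0; nc < N; nc += NPerBlock)
            for(int kc = 0; kc < K; kc += KPerBlock)
            {
                const int mc_size = std::min(M - mc, MPerBlock);
                const int nc_size = std::min(N - nc, NPerBlock);
                const int kc_size = std::min(K - kc, KPerBlock);

                // micro-kernel region: in the real code this is the AVX2
                // 6x16 / 4x24 register tile; here a scalar reference.
                for(int i = 0; i < mc_size; ++i)
                    for(int j = 0; j < nc_size; ++j)
                    {
                        float acc = (kc == 0) ? 0.0f : C[(mc + i) * N + (nc + j)];
                        for(int k = 0; k < kc_size; ++k)
                            acc += A[(mc + i) * K + (kc + k)] * B[(kc + k) * N + (nc + j)];
                        C[(mc + i) * N + (nc + j)] = acc;
                    }
            }
}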
include/ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp  (view file @ afc7d431)
...
@@ -13,21 +13,10 @@ namespace cpu {

 template <typename FloatA,
           typename FloatB,
           typename FloatC,
-          typename AccDataType,
           typename ABlockDesc,
           typename BBlockDesc,
-          typename CBlockDesc,
-          typename ABlockSliceLengths,
-          typename BBlockSliceLengths,
-          typename CBlockSliceLengths,
-          typename AThreadSliceLength,
-          typename BThreadSliceLength,
-          ck::index_t AThreadLoopOverDim, // thread slice loop over on block slice. 1d is enough for
-                                          // now
-          ck::index_t BThreadLoopOverDim,
+          typename CDesc,
           ck::index_t KPerBlock,
...
@@ -47,24 +36,14 @@ struct BlockwiseGemmAvx2_MxN

     static constexpr index_t nDimA = ABlockDesc::GetNumOfDimension();
     static constexpr index_t nDimB = BBlockDesc::GetNumOfDimension();
-    static constexpr index_t nDimC = CBlockDesc::GetNumOfDimension();
+    static constexpr index_t nDimC = CDesc::GetNumOfDimension();

     using IndexA = MultiIndex<nDimA>;
     using IndexB = MultiIndex<nDimB>;
     using IndexC = MultiIndex<nDimC>;

     using ACoord = decltype(make_tensor_coordinate(ABlockDesc{}, IndexA{}));
     using BCoord = decltype(make_tensor_coordinate(BBlockDesc{}, IndexB{}));
-    using CCoord = decltype(make_tensor_coordinate(CBlockDesc{}, IndexC{}));
-
-#if 0
-    constexpr BlockwiseGemmAvx2_MxN(const ABlockDesc & a_block_desc, const IndexA& a_thread_origin,
-                const BBlockDesc & b_block_desc, const IndexB& b_thread_origin)
-        : a_thread_coord_(make_tensor_coordinate(a_block_desc, a_thread_origin)),
-          b_thread_coord_(make_tensor_coordinate(b_block_desc, b_thread_origin)),
-    {
-    }
-#endif
+    using CCoord = decltype(make_tensor_coordinate(CDesc{}, IndexC{}));

     template <typename TensorDesc>
     constexpr auto GetLeadingElement(const TensorDesc& desc)
...
@@ -84,79 +63,175 @@ struct BlockwiseGemmAvx2_MxN
         }
     }

-    template <typename ABlockBuffer, typename BBlockBuffer, typename CBlockBuffer>
+    ck::index_t GetALeadingElement(const ABlockDesc& a_block_desc) const
+    {
+        return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+    }
+
+    ck::index_t GetBLeadingElement(const BBlockDesc& b_block_desc) const
+    {
+        if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
+                                   ck::tensor_layout::gemm::RowMajor>::value)
+        {
+            // K * N
+            return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+        }
+        else
+        {
+            // N/8 * K * 8
+            return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}] *
+                   b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
+        }
+    }
+
+    ck::index_t GetCLeadingElement(const CDesc& c_desc) const
+    {
+        return c_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+    }
+
+    ck::index_t GetMPerBlock(const ABlockDesc& a_block_desc) const
+    {
+        if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
+                                   ck::tensor_layout::gemm::RowMajor>::value)
+        {
+            // M * K
+            return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
+        }
+        else
+        {
+            // K * M
+            return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+        }
+    }
+
+    ck::index_t GetKPerBlock(const ABlockDesc& a_block_desc) const
+    {
+        if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
+                                   ck::tensor_layout::gemm::RowMajor>::value)
+        {
+            // M * K
+            return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+        }
+        else
+        {
+            // K * M
+            return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
+        }
+    }
+
+    ck::index_t GetNPerBlock(const BBlockDesc& b_block_desc) const
+    {
+        if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
+                                   ck::tensor_layout::gemm::RowMajor>::value)
+        {
+            // K * N
+            return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+        }
+        else
+        {
+            // N/8 * K * 8
+            return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}] *
+                   b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
+        }
+    }
+
+    ck::index_t GetABlockStartOffset(const ABlockDesc& a_block_desc,
+                                     const index_t i_m,
+                                     const index_t) const
+    {
+        if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
+                                   ck::tensor_layout::gemm::RowMajor>::value)
+        {
+            return i_m * a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+        }
+        else
+        {
+            return i_m;
+        }
+    }
+
+    ck::index_t GetBBlockStartOffset(const BBlockDesc& b_block_desc,
+                                     const index_t,
+                                     const index_t i_n) const
+    {
+        if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
+                                   ck::tensor_layout::gemm::RowMajor>::value)
+        {
+            // K * N
+            return i_n;
+        }
+        else
+        {
+            // N/8 * K * 8
+            return i_n * b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+        }
+    }
+
+    ck::index_t GetCBlockStartOffset(const CDesc& c_desc,
+                                     const index_t i_m,
+                                     const index_t i_n) const
+    {
+        return i_m * c_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}] + i_n;
+    }
+
+    template <typename ABlockBuffer, typename BBlockBuffer, typename CBuffer>
     void Run(const ABlockDesc& a_block_desc,
              const ABlockBuffer& a_block_buf,
-             const IndexA& a_origin,
+             const IndexA& /* a_origin */,
              const BBlockDesc& b_block_desc,
              const BBlockBuffer& b_block_buf,
-             const IndexB& b_origin,
-             const CBlockDesc& c_block_desc,
-             CBlockBuffer& c_block_buf,
-             const IndexC& c_origin) const
+             const IndexB& /* b_origin */,
+             const CDesc& c_desc,
+             CBuffer& c_buf,
+             const IndexC& /* c_origin */,
+             bool is_accumulate_c = true) const
     {
-        constexpr auto m_n_block_length = ck::Sequence<ABlockSliceLengths::At(AThreadLoopOverDim),
-                                                       BBlockSliceLengths::At(BThreadLoopOverDim)>{};
-
-        constexpr auto m_n_thread_length = ck::Sequence<AThreadSliceLength::At(AThreadLoopOverDim),
-                                                        BThreadSliceLength::At(BThreadLoopOverDim)>{};
-
-        // printf("lda:%d, ldb:%d, ldc:%d\n", lda, ldb, ldc);
-
-        constexpr auto m_n_access_length = m_n_block_length / m_n_thread_length;
-
-        constexpr auto ordered_m_n_access_length =
-            container_reorder_given_new2old(m_n_access_length, ThreadMNAccessOrder{});
+        auto lda = GetALeadingElement(a_block_desc) * sizeof(FloatA);
+        auto ldb = GetBLeadingElement(b_block_desc) * sizeof(FloatB);
+        auto ldc = GetCLeadingElement(c_desc) * sizeof(FloatC);
+
+        const auto k_per_block = GetKPerBlock(a_block_desc);
+        const auto m_per_block = GetMPerBlock(a_block_desc);
+        const auto n_per_block = GetNPerBlock(b_block_desc);
+
+        const auto m_per_thread = ThreadwiseGemm_Dispatch::ThreadMaxMr;
+        const auto n_per_thread = ThreadwiseGemm_Dispatch::ThreadMaxNr;
+
+        ck::cpu::ThreadwiseGemmParam param;
+        param.Kr          = k_per_block;
+        param.lda         = lda;
+        param.ldb         = ldb;
+        param.ldc         = ldc;
+        param.alpha       = 1.0f; // TODO
+        param.accmulate_c = is_accumulate_c ? 1 : 0;

         if constexpr (std::is_same<ThreadMNAccessOrder, ck::Sequence<0, 1>>::value)
         {
-            constexpr auto a_block_idx_zeros =
-                typename uniform_sequence_gen<nDimA, 0>::type{}; // starting point of the block
-            constexpr auto b_block_idx_zeros = typename uniform_sequence_gen<nDimB, 0>::type{};
-
-            constexpr auto lda = GetLeadingElement(a_block_desc) * sizeof(FloatA);
-            constexpr auto ldb = GetLeadingElement(b_block_desc) * sizeof(FloatB);
-            constexpr auto ldc = GetLeadingElement(c_block_desc) * sizeof(FloatC);
-
-            ck::cpu::ThreadwiseGemmParam param;
-            param.Kr    = KPerBlock;
-            param.lda   = lda;
-            param.ldb   = ldb;
-            param.ldc   = ldc;
-            param.alpha = 1.0f; // TODO
-
-            static_ford<decltype(ordered_m_n_access_length)>{}([&](auto ordered_idx) {
-                constexpr auto origin_m_n_idx = ordered_idx.ReorderGivenOld2New(ThreadMNAccessOrder{});
-
-                constexpr auto current_m_idx =
-                    origin_m_n_idx.At(0) * AThreadSliceLength::At(AThreadLoopOverDim);
-                constexpr auto current_n_idx =
-                    origin_m_n_idx.At(1) * BThreadSliceLength::At(BThreadLoopOverDim);
-
-                constexpr auto current_mr =
-                    ck::math::min(m_n_block_length.At(0) - current_m_idx, m_n_thread_length.At(0));
-                constexpr auto current_nr =
-                    ck::math::min(m_n_block_length.At(1) - current_n_idx, m_n_thread_length.At(1));
-
-                constexpr auto a_block_idx = a_block_idx_zeros.Modify(AThreadLoopOverDim, current_m_idx);
-                constexpr auto a_block_coord =
-                    make_tensor_coordinate(a_block_desc, to_multi_index(a_origin + a_block_idx));
-
-                constexpr auto b_block_idx = b_block_idx_zeros.Modify(BThreadLoopOverDim, current_n_idx);
-                constexpr auto b_block_coord =
-                    make_tensor_coordinate(b_block_desc, to_multi_index(b_origin + b_block_idx));
-
-                constexpr auto c_block_coord =
-                    make_tensor_coordinate(c_block_desc, to_multi_index(c_origin + origin_m_n_idx));
-
-                param.p_a = &a_block_buf.p_data_[a_block_coord.GetOffset()];
-                param.p_b = &b_block_buf.p_data_[b_block_coord.GetOffset()];
-                param.p_c = &c_block_buf.p_data_[c_block_coord.GetOffset()];
-
-                ThreadwiseGemm_Dispatch::Run(&param, current_mr, current_nr);
-            });
+            for(ck::index_t i_m = 0; i_m < m_per_block; i_m += m_per_thread)
+            {
+                auto current_mr = ck::math::min(m_per_block - i_m, m_per_thread);
+                param.p_a       = &a_block_buf.p_data_[GetABlockStartOffset(a_block_desc, i_m, 0)];
+
+                // printf("YYYY: %d, i_m:%d, current_mr:%d, %d, %p\n",__LINE__, i_m, current_mr,
+                //     GetABlockStartOffset(a_block_desc, i_m, 0), param.p_a);fflush(stdout);
+
+                for(ck::index_t i_n = 0; i_n < n_per_block; i_n += n_per_thread)
+                {
+                    auto current_nr = ck::math::min(n_per_block - i_n, n_per_thread);
+
+                    param.p_b = &b_block_buf.p_data_[GetBBlockStartOffset(b_block_desc, 0, i_n)];
+                    param.p_c = &c_buf.p_data_[GetCBlockStartOffset(c_desc, i_m, i_n)];
+
+                    // printf("YYYY: %d, i_n:%d, current_nr:%d, %d, %p, C:%d, %p\n",__LINE__, i_n,
+                    //     current_nr, GetBBlockStartOffset(b_block_desc, 0, i_n), param.p_b,
+                    //     GetCBlockStartOffset(c_desc, i_m, i_n),
+                    //     param.p_c);fflush(stdout);
+
+                    ThreadwiseGemm_Dispatch::Run(&param, current_mr, current_nr);
+                }
+            }
         }
     }
 };
...
...
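The blockwise Run above converts leading dimensions to bytes (lda/ldb/ldc) and computes the start of each micro-tile from row/column indices. A minimal sketch of that row-major offset arithmetic, with hypothetical names (not part of the file above):

#include <cstddef>

// The (i_m, i_n) micro-tile of a row-major C starts i_m full rows down and
// i_n elements across, mirroring GetCBlockStartOffset in the hunk above.
inline float* c_tile_ptr(float* p_c, std::ptrdiff_t ldc_elems,
                         std::ptrdiff_t i_m, std::ptrdiff_t i_n)
{
    return p_c + i_m * ldc_elems + i_n;
}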
include/ck/tensor_operation/cpu/device/convolution_forward_specialization_cpu.hpp  (view file @ afc7d431)
...
@@ -14,6 +14,19 @@ enum ConvolutionForwardSpecialization_t
     OddC,
 };

+enum ConvolutionForwardGemmKSpecialization_t
+{
+    DefaultGemmKLoop,
+    NHWC_GemmKLoopOverC, // not merge c*y*x, and c % k_per_block == 0
+};
+
+enum ConvolutionForwardBlockLoopOverSpecialization_t
+{
+    DefaultBlockLoopOver,
+    LoopOver_MNK,
+    LoopOver_MKN,
+};
+
 } // namespace device
 } // namespace cpu
 } // namespace tensor_operation
...
...
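A short, hedged sketch of what the new NHWC_GemmKLoopOverC specialization implies (the helper name is hypothetical): instead of looping GEMM-K over the merged c*y*x dimension, the loop walks the channel dimension C directly, so C must tile evenly by the K block size — this is exactly the constraint the device op below adds to its argument check.

// Mirrors the "arg.Conv_C_ % KPerBlock == 0" requirement added in
// device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp below.
bool gemm_k_loop_over_c_supported(int conv_c, int k_per_block)
{
    return (conv_c % k_per_block) == 0;
}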
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp  (view file @ afc7d431)
...
@@ -13,6 +13,8 @@
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "gridwise_gemm_avx2.hpp"
+#include "threadwise_gemm_avx2.hpp"
+#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"

 namespace ck {
 namespace tensor_operation {
...
@@ -23,20 +25,21 @@ namespace device {

 template <typename InDataType,
           typename WeiDataType,
           typename OutDataType,
           typename AccDataType,
           typename InElementwiseOperation,
           typename WeiElementwiseOperation,
           typename OutElementwiseOperation,
           ConvolutionForwardSpecialization_t ConvForwardSpecialization,
+          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
+          ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
           ck::index_t NumDimSpatial,
           ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
           ck::index_t NPerBlock,
           ck::index_t KPerBlock,
-          typename ThreadwiseGemm_Dispatch>
-          // bool IsGemmMPadded,
-          // bool IsGemmNPadded,
-          // bool IsGemmKPadded
-          >
+          ck::index_t MPerThread,
+          ck::index_t NPerThread,
+          bool UseALocalBuffer,
+          bool UseBLocalBuffer,
+          bool UseCLocalBuffer>
 struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
     : public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
 {
...
@@ -60,18 +63,89 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};

+    static constexpr bool NonTemporalStore = false;
+
+    static constexpr auto GetBlockMNKAccessOrder()
+    {
+        if constexpr (BlockLoopOverSpecialization == DefaultBlockLoopOver ||
+                      BlockLoopOverSpecialization == LoopOver_MNK)
+            return ck::Sequence<0, 1, 2>{};
+        else if constexpr (BlockLoopOverSpecialization == LoopOver_MKN)
+            return ck::Sequence<0, 2, 1>{};
+    }
+
+    using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
+
+    static constexpr auto GetThreadwiseGemm_Dispatch()
+    {
+        if constexpr (MPerThread == 4 && NPerThread == 24)
+        {
+            return ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InDataType,
+                                                                 WeiDataType,
+                                                                 OutDataType,
+                                                                 ck::tensor_layout::gemm::RowMajor,
+                                                                 ck::tensor_layout::gemm::ColumnMajor,
+                                                                 NonTemporalStore>{};
+        }
+        else if constexpr (MPerThread == 6 && NPerThread == 16)
+        {
+            return ck::cpu::ThreadwiseGemmAvx2_MxN_6x16_Dispatch<InDataType,
+                                                                 WeiDataType,
+                                                                 OutDataType,
+                                                                 ck::tensor_layout::gemm::RowMajor,
+                                                                 ck::tensor_layout::gemm::ColumnMajor,
+                                                                 NonTemporalStore>{};
+        }
+        else
+        {
+            // static_assert(false, "invalid Mr/Nr");
+        }
+    }
+
+    using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
+
+    static constexpr auto GetInputBlockDescriptor()
+    {
+        return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
+    }
+
+    static constexpr auto GetWeightBlockDescriptor()
+    {
+        return make_naive_tensor_descriptor_packed(
+            make_tuple(math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
+                       KPerBlock,
+                       ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
+    }
+
+    static constexpr auto GetOutputBlockDescriptor()
+    {
+        return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
+    }
+
     static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
     {
+        ck::index_t gemm_n_padded =
+            math::integer_least_multiple(gemm_n, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
+
         const auto wei_gemm_n_k_grid_desc =
             make_naive_tensor_descriptor_packed(make_tuple(gemm_n, gemm_k));

-        const auto wei_gemm_n0_k_n1_grid_desc = transform_tensor_descriptor(
-            wei_gemm_n_k_grid_desc,
-            ck::make_tuple(ck::make_unmerge_transform(
-                               ck::make_tuple(wei_gemm_n_k_grid_desc.GetLength(I0) /
-                                                  ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
-                                              ThreadwiseGemm_Dispatch::MatrixBMinVectorSize)),
-                           ck::make_pass_through_transform(wei_gemm_n_k_grid_desc.GetLength(I1))),
+        const auto wei_gemm_padn_k_grid_desc = transform_tensor_descriptor(
+            wei_gemm_n_k_grid_desc,
+            make_tuple(make_right_pad_transform(gemm_n, gemm_n_padded - gemm_n),
+                       make_pass_through_transform(gemm_k)),
             ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
-            ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
+            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
+
+        const auto wei_gemm_n0_k_n1_grid_desc = transform_tensor_descriptor(
+            wei_gemm_padn_k_grid_desc,
+            ck::make_tuple(ck::make_unmerge_transform(
+                               ck::make_tuple(wei_gemm_padn_k_grid_desc.GetLength(I0) /
+                                                  ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
+                                              ThreadwiseGemm_Dispatch::MatrixBMinVectorSize)),
+                           ck::make_pass_through_transform(wei_gemm_padn_k_grid_desc.GetLength(I1))),
+            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
+            ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
...
@@ -409,6 +483,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
                                std::multiplies<ck::index_t>());
     }

+    static index_t GetGemmN(ck::index_t K)
+    {
+        // return ck::math::integer_least_multiple(K,
+        //     ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
+        return K;
+    }
+
     static auto MakeABCGridDescriptor(ck::index_t N,
                                       ck::index_t K,
                                       ck::index_t C,
...
@@ -423,7 +504,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
         using namespace ck;

         const index_t GemmM = GetGemmM(N, output_spatial_lengths);
-        const index_t GemmN = K;
+        const index_t GemmN = GetGemmN(K);
         const index_t GemmK = GetGemmK(C, filter_spatial_lengths);

         // A:
...
@@ -474,13 +555,44 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
     using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
     using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;

-    static constexpr bool UseCLocalBuffer = true;
-    // static constexpr bool UseCLocalBuffer = false;
+    using AThreadwiseCopy =
+        ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC<
+            InDataType,
+            InDataType,
+            AGridDesc,
+            decltype(GetInputBlockDescriptor()),
+            InElementwiseOperation,
+            false,
+            ConvForwardSpecialization,
+            GemmKSpecialization>;
+
+    using BThreadwiseCopy =
+        ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC<
+            WeiDataType,
+            WeiDataType,
+            BGridDesc,
+            decltype(GetWeightBlockDescriptor()),
+            WeiElementwiseOperation,
+            false,
+            ConvForwardSpecialization,
+            GemmKSpecialization>;
+
+    using CThreadwiseCopy =
+        ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN<
+            OutDataType,
+            OutDataType,
+            CGridDesc,
+            decltype(GetOutputBlockDescriptor()),
+            OutElementwiseOperation,
+            !UseCLocalBuffer,
+            ConvForwardSpecialization,
+            GemmKSpecialization>;

     using GridwiseGemm =
         ck::cpu::GridwiseGemmAvx2_MxN<InDataType,  // InDataType,
                                       WeiDataType, // WeiDataType,
                                       OutDataType, // OutDataType,
                                       AccDataType, // AccDataType,
                                       AGridDesc,   // AGridDesc,
                                       BGridDesc,   // BGridDesc,
                                       CGridDesc,   // CGridDesc,
...
@@ -491,8 +603,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
                                       NPerBlock,               // NPerBlock,
                                       KPerBlock,               // KPerBlock,
                                       ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
-                                      ck::Sequence<0, 1, 2>,   // BlockMNKAccessOrder,
+                                      AThreadwiseCopy,         // AThreadwiseCopy
+                                      BThreadwiseCopy,         // BThreadwiseCopy
+                                      CThreadwiseCopy,         // CThreadwiseCopy
+                                      BlockMNKAccessOrder,     // BlockMNKAccessOrder,
                                       ck::Sequence<0, 1>,      // ThreadMNAccessOrder
+                                      UseALocalBuffer,         // UseALocalBuffer
+                                      UseBLocalBuffer,         // UseBLocalBuffer
                                       UseCLocalBuffer          // UseCLocalBuffer
                                       >;
...
@@ -580,6 +697,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
                 throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
             }

+            memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
+
             const auto kernel = ck::cpu::kernel_gemm_avx_mxn<GridwiseGemm,
                                                              InDataType,
                                                              WeiDataType,
...
@@ -591,21 +710,24 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
                                                              BElementwiseOperation,
                                                              CElementwiseOperation>;

-            float ave_time = launch_and_time_cpu_kernel(kernel,
-                                                        nrepeat,
-                                                        arg.p_a_grid_,
-                                                        arg.p_b_grid_,
-                                                        arg.p_c_grid_,
-                                                        arg.a_grid_desc_,
-                                                        arg.b_grid_desc_,
-                                                        arg.c_grid_desc_,
-                                                        arg.a_element_op_,
-                                                        arg.b_element_op_,
-                                                        arg.c_element_op_);
+            float ave_time = 0;
+
+            if(nrepeat != 1)
+                ave_time = launch_and_time_cpu_kernel(kernel,
+                                                      nrepeat,
+                                                      arg.p_a_grid_,
+                                                      arg.p_b_grid_,
+                                                      arg.p_c_grid_,
+                                                      arg.a_grid_desc_,
+                                                      arg.b_grid_desc_,
+                                                      arg.c_grid_desc_,
+                                                      arg.a_element_op_,
+                                                      arg.b_element_op_,
+                                                      arg.c_element_op_);

             // TODO: this is for benchmark purpose, so last time we clear c buffer and calculate the
             // result
-            memset(arg.p_c_grid_, 0, arg.a_grid_desc_.GetElementSpaceSize());
+            memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());

             launch_cpu_kernel(kernel,
                               arg.p_a_grid_,
...
@@ -659,6 +781,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
             }
         }

+        if constexpr (GemmKSpecialization ==
+                      ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
+        {
+            if(!(arg.Conv_C_ % KPerBlock == 0))
+                return false;
+        }
+
         // Gridwise GEMM size
         return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
     }
...
@@ -748,16 +877,25 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
     std::string GetTypeString() const override
     {
         auto str = std::stringstream();

+        auto string_local_buffer = [](bool is_local_buffer) {
+            if(is_local_buffer)
+                return "L";
+            else
+                return "G";
+        };
+
         // clang-format off
-        str << "DeviceConv" << std::to_string(NumDimSpatial)
-            << "DFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
-            << "<" << MPerBlock << ", " << NPerBlock << ", " << KPerBlock << ">";
+        str << "DeviceConv" << std::to_string(NumDimSpatial)
+            << "DFwdAvx2_NHWC_KYXC"
+            << "_FS" << static_cast<int>(ConvForwardSpecialization)
+            << "_KS" << static_cast<int>(GemmKSpecialization)
+            << "_BS" << static_cast<int>(BlockLoopOverSpecialization)
+            << "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock
+            << "_TT" << MPerThread << "x" << NPerThread
+            << "_A" << string_local_buffer(UseALocalBuffer)
+            << "_B" << string_local_buffer(UseBLocalBuffer)
+            << "_C" << string_local_buffer(UseCLocalBuffer);
         // clang-format on

         return str.str();
...
...
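The new GetWeightTensorDescriptor above first right-pads the GEMM N extent to a multiple of the B-matrix vector size (8 floats, one AVX2 ymm register) before splitting it into N/8 x K x 8. A small, hedged sketch of that padding rule with hypothetical names:

#include <cstdint>

// Round x up to the next multiple of align (assumes align > 0).
inline int64_t least_multiple(int64_t x, int64_t align) { return ((x + align - 1) / align) * align; }

// e.g. gemm_n = 30 with an 8-wide B vector pads to 32, matching the
// make_right_pad_transform(gemm_n, gemm_n_padded - gemm_n) step above.
inline int64_t padded_gemm_n(int64_t gemm_n, int64_t matrix_b_min_vector_size = 8)
{
    return least_multiple(gemm_n, matrix_b_min_vector_size);
}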
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp  (view file @ afc7d431)
...
@@ -7,7 +7,9 @@
 #include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_avx2.hpp"
 #include "threadwise_tensor_slice_transfer_avx2.hpp"
+#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
 #include "dynamic_buffer_cpu.hpp"
+#include <unistd.h>

 namespace ck {
 namespace cpu {
...
@@ -46,7 +48,6 @@ void kernel_gemm_avx_mxn(const FloatA* __restrict__ p_a_grid,
 template <typename FloatA,
           typename FloatB,
           typename FloatC,
-          typename AccDataType,
           typename AGridDesc,
           typename BGridDesc,
           typename CGridDesc,
...
@@ -57,334 +58,92 @@ template <typename FloatA,
           ck::index_t NPerBlock,
           ck::index_t KPerBlock,
           typename ThreadwiseGemm_Dispatch,
+          typename AThreadwiseCopy,
+          typename BThreadwiseCopy,
+          typename CThreadwiseCopy,
           typename BlockMNKAccessOrder, // how we accss gemm MNK to better fit in cache
           typename ThreadMNAccessOrder, // how we acces gemm MN to utilize micro kernel
           bool UseALocalBuffer,
           bool UseBLocalBuffer,
           bool UseCLocalBuffer // if true, will allocate a buffer and write to it in kernel, then
-                               // copy back to block buffer. if false, will write to C directly
+                               // copy back to block buffer (need CThreadwiseCopy).
+                               // if false, will write to C directly (no need CThreadwiseCopy)
          >
 struct GridwiseGemmAvx2_MxN
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};
     static constexpr auto I4 = Number<4>{};
     static constexpr auto I5 = Number<5>{};
     static constexpr auto I6 = Number<6>{};
     static constexpr auto I7 = Number<7>{};

     // static constexpr auto Avx2RegisterVector = 8; // 8 floats
     static constexpr index_t MemAlignmentByte = 32; // 256bit

-    static constexpr auto GetABlockDescriptor()
+    static auto GetABlockDescriptor(const ck::index_t m_per_blk, const ck::index_t k_per_blk)
     {
         if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
                                    ck::tensor_layout::gemm::RowMajor>::value)
         {
             // A : M, K
-            constexpr auto a_block_desc_m_k =
-                make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
+            auto a_block_desc_m_k =
+                make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, k_per_blk));
             return a_block_desc_m_k;
         }
         else
         {
             // A : K, M
-            constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_packed(
-                make_tuple(KPerBlock,
-                           math::integer_least_multiple(MPerBlock,
-                                                        ThreadwiseGemm_Dispatch::MatrixAMinVectorSize)));
+            auto a_block_desc_k_m = make_naive_tensor_descriptor_packed(
+                make_tuple(k_per_blk,
+                           math::integer_least_multiple(m_per_blk,
+                                                        ThreadwiseGemm_Dispatch::MatrixAMinVectorSize)));
             return a_block_desc_k_m;
         }
     }

-    static constexpr auto GetBBlockDescriptor()
+    static auto GetBBlockDescriptor(const ck::index_t k_per_blk, const ck::index_t n_per_blk)
     {
+        // n_per_blk should be 8x
         if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
                                    ck::tensor_layout::gemm::RowMajor>::value)
         {
             // B : K, N
-            constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_packed(
-                make_tuple(KPerBlock,
-                           math::integer_least_multiple(NPerBlock,
-                                                        ThreadwiseGemm_Dispatch::MatrixBMinVectorSize)));
+            auto b_block_desc_k_n =
+                make_naive_tensor_descriptor_packed(make_tuple(k_per_blk, n_per_blk));
             return b_block_desc_k_n;
         }
         else
         {
             // B : N/8, K, N8
-            constexpr auto b_block_desc_n0_k_n1 = make_naive_tensor_descriptor_packed(
-                make_tuple(math::integer_divide_ceil(NPerBlock,
-                                                     ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
-                           KPerBlock,
-                           ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
+            auto b_block_desc_n0_k_n1 = make_naive_tensor_descriptor_packed(
+                make_tuple(math::integer_divide_ceil(n_per_blk,
+                                                     ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
+                           k_per_blk,
+                           ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
             return b_block_desc_n0_k_n1;
         }
     }

-    static constexpr auto GetABlockSliceLength()
-    {
-        if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
-                                   ck::tensor_layout::gemm::RowMajor>::value)
-        {
-            // A : M, K
-            return ck::Sequence<MPerBlock, KPerBlock>{};
-        }
-        else
-        {
-            // A : K, M
-            return ck::Sequence<KPerBlock, MPerBlock>{};
-        }
-    }
-
-    static constexpr auto GetBBlockSliceLength()
-    {
-        if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
-                                   ck::tensor_layout::gemm::RowMajor>::value)
-        {
-            // B : K, N
-            return ck::Sequence<KPerBlock, NPerBlock>{};
-        }
-        else
-        {
-            // B : N/8, K, N88;
-            return ck::Sequence<NPerBlock / ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
-                                KPerBlock,
-                                ThreadwiseGemm_Dispatch::MatrixBMinVectorSize>{};
-        }
-    }
-
-    static constexpr auto GetABlockDimAccessOrder() { return ck::Sequence<0, 1>{}; }
-
-    static constexpr auto GetBBlockDimAccessOrder()
-    {
-        if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
-                                   ck::tensor_layout::gemm::RowMajor>::value)
-        {
-            // B : K, N
-            return ck::Sequence<0, 1>{};
-        }
-        else
-        {
-            // B : N/8, K, N88;
-            return ck::Sequence<0, 1, 2>{};
-        }
-    }
-
-    static constexpr auto GetABlockMoveFwdStep()
-    {
-        if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
-                                   ck::tensor_layout::gemm::RowMajor>::value)
-        {
-            // A : M, K
-            return ck::make_multi_index(0, KPerBlock);
-        }
-        else
-        {
-            // A : K, M
-            return ck::make_multi_index(KPerBlock, 0);
-        }
-    }
-
-    static constexpr auto GetBBlockMoveFwdStep()
-    {
-        if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
-                                   ck::tensor_layout::gemm::RowMajor>::value)
-        {
-            // B : K, N
-            return ck::make_multi_index(KPerBlock, 0);
-        }
-        else
-        {
-            // B : N/8, K, N88;
-            return ck::make_multi_index(0, KPerBlock, 0);
-        }
-    }
-
-#if 0
-    static constexpr auto GetAThreadDiscriptor()
-    {
-        if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout, ck::tensor_layout::gemm::RowMajor>::value){
-            // A : M, K
-            constexpr auto a_thread_desc_m_k = make_naive_tensor_descriptor_packed(make_tuple(ThreadwiseGemm_Dispatch::ThreadMaxMr, KPerBlock));
-            return a_thread_desc_m_k;
-        } else {
-            // A : K, M
-            constexpr auto a_thread_desc_k_m = make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, ThreadwiseGemm_Dispatch::ThreadMaxMr));
-            return a_thread_desc_k_m;
-        }
-    }
-
-    static constexpr auto GetBThreadDescriptor()
-    {
-        if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout, ck::tensor_layout::gemm::RowMajor>::value){
-            // B : K, N
-            constexpr auto b_thread_desc_k_n = make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, ThreadwiseGemm_Dispatch::ThreadMaxNr));
-            return b_thread_desc_k_n;
-        } else {
-            // B : N/8, K, N8
-            constexpr auto b_thread_desc_n_k_n8 = make_naive_tensor_descriptor_packed(make_tuple(math::integer_divide_ceil(ThreadwiseGemm_Dispatch::ThreadMaxNr, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize), KPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
-            return b_thread_desc_n_k_n8;
-        }
-    }
-#endif
-
-    static constexpr auto GetAThreadSliceLength()
-    {
-        if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
-                                   ck::tensor_layout::gemm::RowMajor>::value)
-        {
-            // A : M, K
-            return ck::Sequence<ThreadwiseGemm_Dispatch::ThreadMaxMr, KPerBlock>{};
-        }
-        else
-        {
-            // A : K, M
-            return ck::Sequence<KPerBlock, ThreadwiseGemm_Dispatch::ThreadMaxMr>{};
-        }
-    }
-
-    static constexpr auto GetBThreadSliceLength()
-    {
-        if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
-                                   ck::tensor_layout::gemm::RowMajor>::value)
-        {
-            // B : K, N
-            return ck::Sequence<KPerBlock, ThreadwiseGemm_Dispatch::ThreadMaxNr>{};
-        }
-        else
-        {
-            // B : N/8, K, N88;
-            return ck::Sequence<ThreadwiseGemm_Dispatch::ThreadMaxNr /
-                                    ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
-                                KPerBlock,
-                                ThreadwiseGemm_Dispatch::MatrixBMinVectorSize>{};
-        }
-    }
-
-    static constexpr auto GetAThreadMoveFwdStep()
-    {
-        if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
-                                   ck::tensor_layout::gemm::RowMajor>::value)
-        {
-            // A : M, K
-            return ck::make_multi_index(ThreadwiseGemm_Dispatch::ThreadMaxMr, 0);
-        }
-        else
-        {
-            // A : K, M
-            return ck::make_multi_index(0, ThreadwiseGemm_Dispatch::ThreadMaxMr);
-        }
-    }
-
-    static constexpr auto GetBThreadMoveFwdStep()
-    {
-        if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
-                                   ck::tensor_layout::gemm::RowMajor>::value)
-        {
-            // B : K, N
-            return ck::make_multi_index(0, ThreadwiseGemm_Dispatch::ThreadMaxNr);
-        }
-        else
-        {
-            // B : N/8, K, N88;
-            return ck::Sequence<ThreadwiseGemm_Dispatch::ThreadMaxNr /
-                                    ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
-                                0,
-                                0>{};
-        }
-    }
-
-    static constexpr ck::index_t GetAThreadLoopOverDim()
-    {
-        if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
-                                   ck::tensor_layout::gemm::RowMajor>::value)
-        {
-            // A : M, K
-            return 0;
-        }
-        else
-        {
-            // A : K, M
-            return 1;
-        }
-    }
-
-    static constexpr ck::index_t GetBThreadLoopOverDim()
-    {
-        if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
-                                   ck::tensor_layout::gemm::RowMajor>::value)
-        {
-            // B : K, N
-            return 1;
-        }
-        else
-        {
-            // B : N/8, K, N88;
-            return 0;
-        }
-    }
-
-    static constexpr auto GetCBlockDescriptor()
-    {
-        if constexpr (UseCLocalBuffer)
-        {
-            return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
-        }
-        else
-        {
-            return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock)); // TODO:
-        }
-    }
-
-    static constexpr auto GetCBlockSliceLength() { return ck::Sequence<MPerBlock, NPerBlock>{}; }
+    static auto GetCBlockDescriptor(const ck::index_t m_per_blk, const ck::index_t n_per_blk)
+    {
+        return make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, n_per_blk));
+    }

     static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
                                         const BGridDesc& b_grid_desc,
                                         const CGridDesc& c_grid_desc)
     {
-#if 0
-        const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);
-        const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);
-        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
-
-        if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
-             K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
-             K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
-            return false;
-
-        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
-            return false;
-
-        // check NumPrefetch
-        if constexpr(NumPrefetch == 1)
-        {
-            // 1-stage prefetch always supported
-        }
-        else if constexpr(NumPrefetch == 2)
-        {
-            // 2-stage prefetch currently only support even number of K0 loop
-            // TODO: add support for odd number of K0 loop
-            if(!((K0 / K0PerBlock) % 2 == 0))
-            {
-                return false;
-            }
-        }
-        else
-        {
-            return false;
-        }
-
-        // check M01, N01
-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
-
-        const auto M0 = M / M1;
-        const auto N0 = N / N1;
-
-        if(!(M0 % M01 == 0 && N0 % N01 == 0))
-            return false;
-#endif
-        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
-        return true;
+        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
+        bool is_valid = true;
+
+        const auto GemmN = c_grid_desc.GetLength(I1);
+
+        if constexpr (UseCLocalBuffer)
+        {
+            if(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value && NPerBlock < GemmN)
+                is_valid &= false;
+        }
+        else
+        {
+            // TODO: need check c grid is simple transform?
+            if(GemmN % 8 != 0)
+                is_valid &= false;
+        }
+
+        return is_valid;
     }

     static void Run(const FloatA* __restrict__ p_a_grid,
...
...
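A hedged sketch of the validity rule introduced in CheckValidity above (helper name is hypothetical): writing C straight to the grid (UseCLocalBuffer == false) requires the GEMM N extent to be a multiple of the 8-float AVX2 vector, while the MKN block order with a local C buffer only fits when one N block covers the whole GemmN.

bool c_path_is_valid(bool use_c_local_buffer, bool loop_over_mkn,
                     int gemm_n, int n_per_block)
{
    if(use_c_local_buffer)
        return !(loop_over_mkn && n_per_block < gemm_n); // local C tile cannot be revisited per N block
    return gemm_n % 8 == 0;                              // direct stores need vector-aligned N
}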
@@ -397,178 +156,149 @@ struct GridwiseGemmAvx2_MxN
                     const BElementwiseOperation& b_element_op,
                     const CElementwiseOperation& c_element_op)
     {
-        ck::index_t m_per_block;
-        ck::index_t n_per_block;
-        ck::index_t k_per_block;
-
-        if constexpr (MPerBlock == 0 && NPerBlock == 0 && KPerBlock == 0)
-        {}
-        else
-        {
-            m_per_block = MPerBlock;
-            n_per_block = NPerBlock;
-            k_per_block = KPerBlock;
-        }
-
-        const auto M = a_grid_desc.GetLength(I0);
-        const auto N = b_grid_desc.GetLength(I1);
-        const auto K = b_grid_desc.GetLength(I0);
-
-        const ck::index_t grid_m    = math::integer_divide_ceil(M, m_per_block);
-        const ck::index_t grid_n    = math::integer_divide_ceil(N, n_per_block);
-        const ck::index_t grid_size = grid_m * grid_n;
-
-        constexpr auto a_block_desc           = GetABlockDescriptor();
-        constexpr auto a_block_slice_length   = GetABlockSliceLength();
-        constexpr auto a_block_copy_dim       = decltype(a_block_slice_length)::Size();
-        constexpr auto a_dim_access_order     = GetABlockDimAccessOrder();
-        constexpr auto a_block_move_step      = GetABlockMoveFwdStep();
-        constexpr auto a_thread_slice_length  = GetAThreadSliceLength();
-        constexpr auto a_thread_loop_over_dim = GetAThreadLoopOverDim();
-
-        constexpr auto b_block_desc           = GetBBlockDescriptor();
-        constexpr auto b_block_slice_length   = GetBBlockSliceLength();
-        constexpr auto b_block_copy_dim       = decltype(b_block_slice_length)::Size();
-        constexpr auto b_dim_access_order     = GetBBlockDimAccessOrder();
-        constexpr auto b_block_move_step      = GetBBlockMoveFwdStep();
-        constexpr auto b_thread_slice_length  = GetBThreadSliceLength();
-        constexpr auto b_thread_loop_over_dim = GetBThreadLoopOverDim();
-
-        constexpr auto c_block_desc         = GetCBlockDescriptor();
-        constexpr auto c_block_slice_length = GetCBlockSliceLength();
-        constexpr auto c_block_move_step    = ck::make_multi_index(0, NPerBlock);
-
-        auto a_threadwise_copy =
-            ck::cpu::ThreadwiseTensorSliceTransferAvx2<FloatA,                 // SrcData
-                                                       FloatA,                 // DstData
-                                                       decltype(a_grid_desc),  // SrcDesc
-                                                       decltype(a_block_desc), // DstDesc
-                                                       AElementwiseOperation,  // ElementwiseOperation
-                                                       decltype(a_block_slice_length), // SliceLengths
-                                                       decltype(a_dim_access_order),   // DimAccessOrder
-                                                       1,                      // VectorDim
-                                                       1,                      // ScalarPerVector
-                                                       ck::InMemoryDataOperationEnum_t::Set, // InMemoryDataOperationEnum_t
-                                                       false,                  // SrcResetCoordinateAfterRun
-                                                       true                    // DstResetCoordinateAfterRun
-                                                       >(a_grid_desc,
-                                                         ck::make_zero_multi_index<a_block_copy_dim>(),
-                                                         a_block_desc,
-                                                         ck::make_zero_multi_index<a_block_copy_dim>(),
-                                                         AElementwiseOperation{});
-
-        auto b_threadwise_copy =
-            ck::cpu::ThreadwiseTensorSliceTransferAvx2<FloatB,                 // SrcData
-                                                       FloatB,                 // DstData
-                                                       decltype(b_grid_desc),  // SrcDesc
-                                                       decltype(b_block_desc), // DstDesc
-                                                       BElementwiseOperation,  // ElementwiseOperation
-                                                       decltype(b_block_slice_length), // SliceLengths
-                                                       decltype(b_dim_access_order),   // DimAccessOrder
-                                                       1,                      // VectorDim
-                                                       1,                      // ScalarPerVector
-                                                       ck::InMemoryDataOperationEnum_t::Set, // InMemoryDataOperationEnum_t
-                                                       false,                  // SrcResetCoordinateAfterRun
-                                                       true                    // DstResetCoordinateAfterRun
-                                                       >(b_grid_desc,
-                                                         ck::make_zero_multi_index<b_block_copy_dim>(),
-                                                         b_block_desc,
-                                                         ck::make_zero_multi_index<b_block_copy_dim>(),
-                                                         BElementwiseOperation{});
-
-        auto c_threadwise_copy =
-            ck::cpu::ThreadwiseTensorSliceTransferAvx2<FloatC,                 // SrcData
-                                                       FloatC,                 // DstData
-                                                       decltype(c_block_desc), // SrcDesc
-                                                       decltype(c_grid_desc),  // DstDesc
-                                                       BElementwiseOperation,  // ElementwiseOperation
-                                                       ck::Sequence<MPerBlock, NPerBlock>, // SliceLengths
-                                                       ck::Sequence<0, 1>,     // DimAccessOrder
-                                                       1,                      // VectorDim
-                                                       1,                      // ScalarPerVector
-                                                       ck::InMemoryDataOperationEnum_t::Set, // InMemoryDataOperationEnum_t
-                                                       true,                   // SrcResetCoordinateAfterRun
-                                                       false                   // DstResetCoordinateAfterRun
-                                                       >(c_block_desc,
-                                                         ck::make_zero_multi_index<2>(),
-                                                         c_grid_desc,
-                                                         ck::make_zero_multi_index<2>(),
-                                                         CElementwiseOperation{});
-
-        DeviceAlignedMemCPU a_block_mem(MPerBlock * KPerBlock * sizeof(FloatA), MemAlignmentByte);
-        DeviceAlignedMemCPU b_block_mem(KPerBlock * NPerBlock * sizeof(FloatB), MemAlignmentByte);
-        DeviceAlignedMemCPU c_block_mem(MPerBlock * NPerBlock * sizeof(FloatC), MemAlignmentByte);
-
-        auto a_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum_t::Global>(
+        ck::index_t m_per_block = MPerBlock;
+        ck::index_t n_per_block = NPerBlock;
+        ck::index_t k_per_block = KPerBlock;
+
+        const auto GemmM = c_grid_desc.GetLength(I0);
+        const auto GemmN = c_grid_desc.GetLength(I1);
+        const auto GemmK = a_grid_desc.GetLength(I1);
+
+        constexpr auto a_block_copy_dim = AGridDesc::GetNumOfDimension();
+        constexpr auto b_block_copy_dim = BGridDesc::GetNumOfDimension();
+
+        auto a_threadwise_copy = AThreadwiseCopy(a_grid_desc,
+                                                 ck::make_zero_multi_index<a_block_copy_dim>(),
+                                                 GetABlockDescriptor(m_per_block, k_per_block),
+                                                 ck::make_zero_multi_index<a_block_copy_dim>(),
+                                                 AElementwiseOperation{});
+
+        auto b_threadwise_copy = BThreadwiseCopy(b_grid_desc,
+                                                 ck::make_zero_multi_index<b_block_copy_dim>(),
+                                                 GetBBlockDescriptor(k_per_block, n_per_block),
+                                                 ck::make_zero_multi_index<b_block_copy_dim>(),
+                                                 BElementwiseOperation{});
+
+        auto c_threadwise_copy = CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block),
+                                                 ck::make_zero_multi_index<2>(),
+                                                 c_grid_desc,
+                                                 ck::make_zero_multi_index<2>(),
+                                                 CElementwiseOperation{});
+
+        DeviceAlignedMemCPU a_block_mem(m_per_block * k_per_block * sizeof(FloatA), MemAlignmentByte);
+        DeviceAlignedMemCPU b_block_mem(k_per_block * n_per_block * sizeof(FloatB), MemAlignmentByte);
+        DeviceAlignedMemCPU c_block_mem(m_per_block * n_per_block * sizeof(FloatC), MemAlignmentByte);
+
+        auto a_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
             reinterpret_cast<const FloatA*>(p_a_grid), a_grid_desc.GetElementSpaceSize());

-        auto b_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum_t::Global>(
+        auto b_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
             reinterpret_cast<const FloatB*>(p_b_grid), b_grid_desc.GetElementSpaceSize());

-        auto c_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum_t::Global>(
+        auto c_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
             reinterpret_cast<FloatC*>(p_c_grid), c_grid_desc.GetElementSpaceSize());

-        auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum_t::Global>(
+        auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
             reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf), a_block_mem.mMemSize / sizeof(FloatA));

-        auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum_t::Global>(
+        auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
             reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf), b_block_mem.mMemSize / sizeof(FloatB));

-        auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum_t::Global>(
-            reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf), c_block_mem.mMemSize / sizeof(FloatC));
-
-        auto blockwise_gemm =
-            BlockwiseGemmAvx2_MxN<FloatA,                          // FloatA,
-                                  FloatB,                          // FloatB,
-                                  FloatC,                          // FloatC,
-                                  AccDataType,                     // AccDataType,
-                                  decltype(a_block_desc),          // ABlockDesc,
-                                  decltype(b_block_desc),          // BBlockDesc,
-                                  decltype(c_block_desc),          // CBlockDesc,
-                                  decltype(a_block_slice_length),  // ABlockSliceLengths,
-                                  decltype(b_block_slice_length),  // BBlockSliceLengths,
-                                  decltype(c_block_slice_length),  // CBlockSliceLengths,
-                                  decltype(a_thread_slice_length), // AThreadSliceLength,
-                                  decltype(b_thread_slice_length), // BThreadSliceLength,
-                                  a_thread_loop_over_dim,          // AThreadLoopOverDim, // thread slice
-                                                                   // loop over on block slice. 1d is enough
-                                                                   // for now
-                                  b_thread_loop_over_dim,          // BThreadLoopOverDim,
-                                  KPerBlock,                       // KPerBlock,
-                                  ThreadwiseGemm_Dispatch,         // ThreadwiseGemm_Dispatch,
-                                  ThreadMNAccessOrder>{};          // ThreadMNAccessOrder // how we acces
-                                                                   // gemm MN to utilize micro kernel>{};
+        auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
+            UseCLocalBuffer ? reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf)
+                            : reinterpret_cast<FloatC*>(p_c_grid),
+            UseCLocalBuffer ? c_block_mem.mMemSize / sizeof(FloatC)
+                            : c_grid_desc.GetElementSpaceSize());
+
+        auto blockwise_gemm =
+            BlockwiseGemmAvx2_MxN<FloatA,                   // FloatA,
+                                  FloatB,                   // FloatB,
+                                  FloatC,                   // FloatC,
+                                  decltype(GetABlockDescriptor(m_per_block, k_per_block)), // ABlockDesc,
+                                  decltype(GetBBlockDescriptor(k_per_block, n_per_block)), // BBlockDesc,
+                                  decltype(GetCBlockDescriptor(m_per_block, n_per_block)), // CBlockDesc,
+                                  KPerBlock,                // KPerBlock,
+                                  ThreadwiseGemm_Dispatch,  // ThreadwiseGemm_Dispatch,
+                                  ThreadMNAccessOrder>{};   // ThreadMNAccessOrder // how we acces
+                                                            // gemm MN to utilize micro kernel>{};

         // TODO: openmp aware ordering
         //
         if constexpr (std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 1, 2>>::value)
         {
+            auto a_move_k_step = ck::make_multi_index(0, k_per_block);
+            auto b_move_k_step = ck::make_multi_index(0, k_per_block, 0);
+
+            const ck::index_t grid_m    = math::integer_divide_ceil(GemmM, m_per_block);
+            const ck::index_t grid_n    = math::integer_divide_ceil(GemmN, n_per_block);
+            const ck::index_t grid_size = grid_m * grid_n;
+
             // This version does not consider K panel re-usage. simple for openmp
 #pragma omp parallel for
             for(ck::index_t gid = 0; gid < grid_size; gid++)
             {
                 ck::index_t i_mc = (gid / grid_n) * m_per_block;
                 ck::index_t i_nc = (gid % grid_n) * n_per_block;

-                ck::index_t mc_size = ck::math::min(M - i_mc, m_per_block);
-                ck::index_t nc_size = ck::math::min(N - i_nc, n_per_block);
-
-                // pack_b
-                b_threadwise_copy.RunGeneric(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
-                b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_move_step);
-
-                if(i_nc == 0)
-                {
-                    // pack_a
-                    a_threadwise_copy.RunGeneric(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
-                    a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_move_step);
-                }
+                ck::index_t mc_size = ck::math::min(GemmM - i_mc, m_per_block);
+                ck::index_t nc_size = ck::math::min(GemmN - i_nc, n_per_block);
+                // TODO: nc need be 8x
+                nc_size = math::integer_least_multiple(nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
+
+                a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, ck::make_multi_index(i_mc, 0));
+                b_threadwise_copy.SetSrcSliceOrigin(
+                    b_grid_desc,
+                    ck::make_multi_index(
+                        math::integer_divide_ceil(i_nc, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize), 0, 0));
+
+                auto c_block_desc = UseCLocalBuffer ? GetCBlockDescriptor(mc_size, nc_size) : c_grid_desc;
+
+                if constexpr (UseCLocalBuffer)
+                {
+                    c_threadwise_copy.SetDstSliceOrigin(c_grid_desc, ck::make_multi_index(i_mc, i_nc));
+                }
+                else
+                {
+                    c_threadwise_copy.SetSrcSliceOrigin(c_block_desc, ck::make_multi_index(i_mc, i_nc));
+                    c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
+                }

-                for(ck::index_t i_kc = 0; i_kc < K; i_kc += k_per_block)
+                for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
                 {
-                    ck::index_t kc_size = ck::math::min(K - i_kc, k_per_block);
+                    ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);
+
+                    auto a_block_desc = GetABlockDescriptor(mc_size, kc_size);
+                    auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);
+
+                    // printf("==> i_m:%d, i_n:%d, i_k:%d, mc:%d, nc:%d, kc:%d(%d, %d)\n", i_mc,
+                    //     i_nc, i_kc, mc_size, nc_size, kc_size, KPerBlock, GemmK); fflush(stdout);
+
+                    a_threadwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
+                    b_threadwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);

                     // for(auto i_elem = 0; i_elem < (mc_size * kc_size) ; i_elem++){
                     //     printf("A ==> %3d : %f(0x%08x)\n", i_elem,
                     //     (reinterpret_cast<float*>(a_block_buf.p_data_))[i_elem],
                     //     (reinterpret_cast<uint32_t*>(a_block_buf.p_data_))[i_elem]);
                     //}
                     // for(auto i_elem = 0; i_elem < (kc_size * nc_size) ; i_elem++){
                     //     printf("B ==> %3d : %f(0x%08x)\n", i_elem,
                     //     (reinterpret_cast<float*>(b_block_buf.p_data_))[i_elem],
                     //     (reinterpret_cast<uint32_t*>(b_block_buf.p_data_))[i_elem]);
                     // }
                     // printf("[%d] 2222 \n",__LINE__);

                     blockwise_gemm.Run(a_block_desc,
                                        a_block_buf,
                                        make_zero_multi_index<a_block_copy_dim>(),
...
...
@@ -577,14 +307,108 @@ struct GridwiseGemmAvx2_MxN
                                        make_zero_multi_index<b_block_copy_dim>(),
                                        c_block_desc,
                                        c_block_buf,
-                                       make_zero_multi_index<2>());
+                                       make_zero_multi_index<2>(),
+                                       i_kc != 0);

                     // printf("[%d] 2222 \n",__LINE__);
+
+                    if((i_kc + k_per_block) < GemmK)
+                    {
+                        a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc, a_move_k_step);
+                        b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_k_step);
+                    }

                     // printf("[%d] 2222 \n",__LINE__);
                     // for(auto i_elem = 0; i_elem < (10) ; i_elem++){
                     //     printf("C ==> %3d : %f(0x%08x)\n", i_elem,
                     //     (reinterpret_cast<float*>(c_block_buf.p_data_))[i_elem],
                     //     (reinterpret_cast<uint32_t*>(c_block_buf.p_data_))[i_elem]);
                     // }
                 }

                 // for(auto i_elem = 0; i_elem < (c_block_mem.mMemSize / sizeof(FloatC)) ;
                 // i_elem++){
                 //     printf("C ==> %3d : %f(0x%08x)\n", i_elem,
                 //     (reinterpret_cast<float*>(c_block_buf.p_data_))[i_elem],
                 //     (reinterpret_cast<uint32_t*>(c_block_buf.p_data_))[i_elem]);
                 // }

-                c_threadwise_copy.RunGeneric(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
-                c_threadwise_copy.MoveDstSliceWindow(c_grid_desc, c_block_move_step);
+                if constexpr (UseCLocalBuffer)
+                    c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
             }
         }
+        else if constexpr (std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value)
+        {
+            auto a_move_k_step = ck::make_multi_index(0, k_per_block);
+            auto b_move_k_step = ck::make_multi_index(
+                math::integer_divide_ceil(n_per_block, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize), 0, 0);
+
+            // only parallel in gemm m dim
+#pragma omp parallel for
+            for(ck::index_t i_mc = 0; i_mc < GemmM; i_mc += m_per_block)
+            {
+                ck::index_t mc_size = ck::math::min(GemmM - i_mc, m_per_block);
+
+                a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, ck::make_multi_index(i_mc, 0));
+
+                for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
+                {
+                    ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);
+                    auto a_block_desc   = GetABlockDescriptor(mc_size, kc_size);
+
+                    a_threadwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
+                    b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc, ck::make_multi_index(0, i_kc, 0));
+
+                    // TODO: if use local C buffer, then this nc loop need to loop only once
+                    for(ck::index_t i_nc = 0; i_nc < GemmN; i_nc += n_per_block)
+                    {
+                        ck::index_t nc_size = ck::math::min(GemmN - i_nc, n_per_block);
+                        // TODO: nc need be 8x
+                        nc_size = math::integer_least_multiple(nc_size,
+                                                               ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
+
+                        auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);
+
+                        b_threadwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
+
+                        auto c_block_desc =
+                            UseCLocalBuffer ? GetCBlockDescriptor(mc_size, nc_size) : c_grid_desc;
+
+                        if constexpr (!UseCLocalBuffer)
+                        {
+                            c_threadwise_copy.SetSrcSliceOrigin(c_block_desc, ck::make_multi_index(i_mc, i_nc));
+                            c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
+                        }
+
+                        blockwise_gemm.Run(a_block_desc,
+                                           a_block_buf,
+                                           make_zero_multi_index<a_block_copy_dim>(),
+                                           b_block_desc,
+                                           b_block_buf,
+                                           make_zero_multi_index<b_block_copy_dim>(),
+                                           c_block_desc,
+                                           c_block_buf,
+                                           make_zero_multi_index<2>(),
+                                           i_kc != 0);
+
+                        if((i_nc + n_per_block) < GemmN)
+                            b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_k_step);
+
+                        if constexpr (UseCLocalBuffer)
+                        {
+                            c_threadwise_copy.SetDstSliceOrigin(c_grid_desc, ck::make_multi_index(i_mc, i_nc));
+                            c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
+                        }
+                    }
+
+                    if((i_kc + k_per_block) < GemmK)
+                        a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc, a_move_k_step);
+                }
+            }
+        }
...
...
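The MNK path in gridwise_gemm_avx2.hpp above parallelizes over the flat M×N block grid with OpenMP, recovering the block coordinate from the flat index. A short, hedged sketch of that mapping (plain C++, hypothetical names):

#include <utility>

// The row block advances once every grid_n iterations; the column block cycles fastest,
// matching "i_mc = (gid / grid_n) * m_per_block; i_nc = (gid % grid_n) * n_per_block" above.
std::pair<int, int> block_coord_from_gid(int gid, int grid_n, int m_per_block, int n_per_block)
{
    const int i_mc = (gid / grid_n) * m_per_block;
    const int i_nc = (gid % grid_n) * n_per_block;
    return {i_mc, i_nc};
}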
include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp  (view file @ afc7d431)
...
@@ -7,7 +7,7 @@
 #include "common_header.hpp"
 #include "../../gpu/device/tensor_layout.hpp"
 #include "math.hpp"
-#include "threadwise_param.hpp"
+#include "threadwise_gemm_param.hpp"

 namespace ck {
 namespace cpu {
...
@@ -294,6 +294,9 @@ struct ThreadwiseGemmAvx2_MxN_6x16
         ".if (m_Mr > 4)\n   lea (%%rdx, %%rdi, 1), %%r8\n .endif\n"
         ".if (m_Mr > 5)\n   lea (%%r8, %%rdi, 1), %%r9\n .endif\n"

+        "mov 60(%[m_param]), %%edi\n" // accmulate_c
+        "test %%edi, %%edi\n"
+        "je L_GemmAvx2_MxN_6x16_Store_C%=\n"
         "   vaddps (%%rax), %%ymm0, %%ymm0\n"
         ".if (m_Nr > 8)\n   vaddps 32(%%rax), %%ymm1, %%ymm1\n .endif\n"
         ".if (m_Mr > 1)\n   vaddps (%%rbx), %%ymm2, %%ymm2\n .endif\n"
...
@@ -307,6 +310,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
         ".if (m_Mr > 5)\n   vaddps (%%r9), %%ymm10, %%ymm10\n .endif\n"
         ".if (m_Mr > 5) && (m_Nr > 8)\n   vaddps 32(%%r9), %%ymm11, %%ymm11\n .endif\n"

+        "L_GemmAvx2_MxN_6x16_Store_C%=:\n"
         ".if m_NTStore == 0\n"
         "   vmovups %%ymm0, (%%rax)\n"
         ".if (m_Nr > 8)\n   vmovups %%ymm1, 32(%%rax)\n .endif\n"
...
@@ -424,18 +428,33 @@ struct ThreadwiseGemmAvx2_MxN_6x16
         };
         // clang-format off
-        ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
-        if constexpr (Nr > 8)            ymm1  = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
-        if constexpr (Mr > 1)            ymm2  = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8);
-        if constexpr (Mr > 1 && Nr > 8)  ymm3  = _mm256_loadu_ps(p_c + 1 * ldc + 1 * 8);
-        if constexpr (Mr > 2)            ymm4  = _mm256_loadu_ps(p_c + 2 * ldc + 0 * 8);
-        if constexpr (Mr > 2 && Nr > 8)  ymm5  = _mm256_loadu_ps(p_c + 2 * ldc + 1 * 8);
-        if constexpr (Mr > 3)            ymm6  = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8);
-        if constexpr (Mr > 3 && Nr > 8)  ymm7  = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8);
-        if constexpr (Mr > 4)            ymm8  = _mm256_loadu_ps(p_c + 4 * ldc + 0 * 8);
-        if constexpr (Mr > 4 && Nr > 8)  ymm9  = _mm256_loadu_ps(p_c + 4 * ldc + 1 * 8);
-        if constexpr (Mr > 5)            ymm10 = _mm256_loadu_ps(p_c + 5 * ldc + 0 * 8);
-        if constexpr (Mr > 5 && Nr > 8)  ymm11 = _mm256_loadu_ps(p_c + 5 * ldc + 1 * 8);
+        if(param->accmulate_c){
+            ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
+            if constexpr (Nr > 8)            ymm1  = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
+            if constexpr (Mr > 1)            ymm2  = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8);
+            if constexpr (Mr > 1 && Nr > 8)  ymm3  = _mm256_loadu_ps(p_c + 1 * ldc + 1 * 8);
+            if constexpr (Mr > 2)            ymm4  = _mm256_loadu_ps(p_c + 2 * ldc + 0 * 8);
+            if constexpr (Mr > 2 && Nr > 8)  ymm5  = _mm256_loadu_ps(p_c + 2 * ldc + 1 * 8);
+            if constexpr (Mr > 3)            ymm6  = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8);
+            if constexpr (Mr > 3 && Nr > 8)  ymm7  = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8);
+            if constexpr (Mr > 4)            ymm8  = _mm256_loadu_ps(p_c + 4 * ldc + 0 * 8);
+            if constexpr (Mr > 4 && Nr > 8)  ymm9  = _mm256_loadu_ps(p_c + 4 * ldc + 1 * 8);
+            if constexpr (Mr > 5)            ymm10 = _mm256_loadu_ps(p_c + 5 * ldc + 0 * 8);
+            if constexpr (Mr > 5 && Nr > 8)  ymm11 = _mm256_loadu_ps(p_c + 5 * ldc + 1 * 8);
+        }
+        else {
+            ymm0 = _mm256_xor_ps(ymm0, ymm0);
+            if constexpr (Nr > 8)            ymm1  = _mm256_xor_ps(ymm1, ymm1);
+            if constexpr (Mr > 1)            ymm2  = _mm256_xor_ps(ymm2, ymm2);
+            if constexpr (Mr > 1 && Nr > 8)  ymm3  = _mm256_xor_ps(ymm3, ymm3);
+            if constexpr (Mr > 2)            ymm4  = _mm256_xor_ps(ymm4, ymm4);
+            if constexpr (Mr > 2 && Nr > 8)  ymm5  = _mm256_xor_ps(ymm5, ymm5);
+            if constexpr (Mr > 3)            ymm6  = _mm256_xor_ps(ymm6, ymm6);
+            if constexpr (Mr > 3 && Nr > 8)  ymm7  = _mm256_xor_ps(ymm7, ymm7);
+            if constexpr (Mr > 4)            ymm8  = _mm256_xor_ps(ymm8, ymm8);
+            if constexpr (Mr > 4 && Nr > 8)  ymm9  = _mm256_xor_ps(ymm9, ymm9);
+            if constexpr (Mr > 5)            ymm10 = _mm256_xor_ps(ymm10, ymm10);
+            if constexpr (Mr > 5 && Nr > 8)  ymm11 = _mm256_xor_ps(ymm11, ymm11);
+        }

         while(Kr > 4){
 #pragma unroll
...
@@ -532,6 +551,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
         }

         if constexpr (NonTemporalStore)
         {
+            _mm256_stream_ps(p_c + 0 * ldc + 0 * 8, ymm1);
             if constexpr (Nr > 8)            _mm256_stream_ps(p_c + 0 * ldc + 1 * 8, ymm1);
             if constexpr (Mr > 1)            _mm256_stream_ps(p_c + 1 * ldc + 0 * 8, ymm2);
             if constexpr (Mr > 1 && Nr > 8)  _mm256_stream_ps(p_c + 1 * ldc + 1 * 8, ymm3);
...
...
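Both the 6x16 kernel above and the 4x24 kernel below now branch on the accmulate_c flag: load the existing C tile and accumulate, or start the accumulators from zero on the first K block. A hedged one-register sketch of that choice in intrinsics form (illustrative shape and names, not the exact kernels):

#include <immintrin.h>

// When accumulate_c is set the existing C tile is loaded and added to later;
// otherwise the accumulator starts at zero, like the _mm256_xor_ps branch above.
inline __m256 init_c_tile(const float* p_c, bool accumulate_c)
{
    return accumulate_c ? _mm256_loadu_ps(p_c) : _mm256_setzero_ps();
}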
@@ -830,19 +850,23 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".if (m_Mr > 2)
\n lea (%%rbx, %%rdi, 1), %%rcx\n .endif\n"
        ".if (m_Mr > 3)\n lea (%%rcx, %%rdi, 1), %%rdx\n .endif\n"

        // " vaddps (%%rax), %%ymm0, %%ymm0 \n"
        // ".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1 \n .endif\n"
        // ".if (m_Nr >16)\n vaddps 64(%%rax), %%ymm2, %%ymm2 \n .endif\n"
        // ".if (m_Mr > 1) \n vaddps (%%rbx), %%ymm3, %%ymm3 \n .endif\n"
        // ".if (m_Mr > 1) && (m_Nr > 8)\n vaddps 32(%%rbx), %%ymm4, %%ymm4 \n .endif\n"
        // ".if (m_Mr > 1) && (m_Nr >16)\n vaddps 64(%%rbx), %%ymm5, %%ymm5 \n .endif\n"
        // ".if (m_Mr > 2) \n vaddps (%%rcx), %%ymm6, %%ymm6 \n .endif\n"
        // ".if (m_Mr > 2) && (m_Nr > 8)\n vaddps 32(%%rcx), %%ymm7, %%ymm7 \n .endif\n"
        // ".if (m_Mr > 2) && (m_Nr >16)\n vaddps 64(%%rcx), %%ymm8, %%ymm8 \n .endif\n"
        // ".if (m_Mr > 3) \n vaddps (%%rdx), %%ymm9, %%ymm9 \n .endif\n"
        // ".if (m_Mr > 3) && (m_Nr > 8)\n vaddps 32(%%rdx), %%ymm10, %%ymm10\n .endif\n"
        // ".if (m_Mr > 3) && (m_Nr >16)\n vaddps 64(%%rdx), %%ymm11, %%ymm11\n .endif\n"

        "mov 60(%[m_param]), %%edi\n" // accmulate_c
        "test %%edi, %%edi\n"
        "je L_GemmAvx2_MxN_4x24_Store_C%=\n"
        " vaddps (%%rax), %%ymm0, %%ymm0\n"
        ".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1\n .endif\n"
        ".if (m_Nr >16)\n vaddps 64(%%rax), %%ymm2, %%ymm2\n .endif\n"
        ".if (m_Mr > 1)\n vaddps (%%rbx), %%ymm3, %%ymm3\n .endif\n"
        ".if (m_Mr > 1) && (m_Nr > 8)\n vaddps 32(%%rbx), %%ymm4, %%ymm4\n .endif\n"
        ".if (m_Mr > 1) && (m_Nr >16)\n vaddps 64(%%rbx), %%ymm5, %%ymm5\n .endif\n"
        ".if (m_Mr > 2)\n vaddps (%%rcx), %%ymm6, %%ymm6\n .endif\n"
        ".if (m_Mr > 2) && (m_Nr > 8)\n vaddps 32(%%rcx), %%ymm7, %%ymm7\n .endif\n"
        ".if (m_Mr > 2) && (m_Nr >16)\n vaddps 64(%%rcx), %%ymm8, %%ymm8\n .endif\n"
        ".if (m_Mr > 3)\n vaddps (%%rdx), %%ymm9, %%ymm9\n .endif\n"
        ".if (m_Mr > 3) && (m_Nr > 8)\n vaddps 32(%%rdx), %%ymm10, %%ymm10\n .endif\n"
        ".if (m_Mr > 3) && (m_Nr >16)\n vaddps 64(%%rdx), %%ymm11, %%ymm11\n .endif\n"

        "L_GemmAvx2_MxN_4x24_Store_C%=:\n"
        ".if m_NTStore == 0\n"
        " vmovups %%ymm0, (%%rax)\n"
        ".if (m_Nr > 8)\n vmovups %%ymm1, 32(%%rax)\n .endif\n"
...
...
@@ -960,18 +984,33 @@ struct ThreadwiseGemmAvx2_MxN_4x24
};
// clang-format off
    ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
    if constexpr(Nr > 8)            ymm1  = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
    if constexpr(Nr > 16)           ymm2  = _mm256_loadu_ps(p_c + 0 * ldc + 2 * 8);
    if constexpr(Mr > 1)            ymm3  = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8);
    if constexpr(Mr > 1 && Nr > 8)  ymm4  = _mm256_loadu_ps(p_c + 1 * ldc + 1 * 8);
    if constexpr(Mr > 1 && Nr > 16) ymm5  = _mm256_loadu_ps(p_c + 1 * ldc + 2 * 8);
    if constexpr(Mr > 2)            ymm6  = _mm256_loadu_ps(p_c + 2 * ldc + 0 * 8);
    if constexpr(Mr > 2 && Nr > 8)  ymm7  = _mm256_loadu_ps(p_c + 2 * ldc + 1 * 8);
    if constexpr(Mr > 2 && Nr > 16) ymm8  = _mm256_loadu_ps(p_c + 2 * ldc + 2 * 8);
    if constexpr(Mr > 3)            ymm9  = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8);
    if constexpr(Mr > 3 && Nr > 8)  ymm10 = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8);
    if constexpr(Mr > 3 && Nr > 16) ymm11 = _mm256_loadu_ps(p_c + 3 * ldc + 2 * 8);
    if(param->accmulate_c)
    {
        ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
        if constexpr(Nr > 8)            ymm1  = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
        if constexpr(Nr > 16)           ymm2  = _mm256_loadu_ps(p_c + 0 * ldc + 2 * 8);
        if constexpr(Mr > 1)            ymm3  = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8);
        if constexpr(Mr > 1 && Nr > 8)  ymm4  = _mm256_loadu_ps(p_c + 1 * ldc + 1 * 8);
        if constexpr(Mr > 1 && Nr > 16) ymm5  = _mm256_loadu_ps(p_c + 1 * ldc + 2 * 8);
        if constexpr(Mr > 2)            ymm6  = _mm256_loadu_ps(p_c + 2 * ldc + 0 * 8);
        if constexpr(Mr > 2 && Nr > 8)  ymm7  = _mm256_loadu_ps(p_c + 2 * ldc + 1 * 8);
        if constexpr(Mr > 2 && Nr > 16) ymm8  = _mm256_loadu_ps(p_c + 2 * ldc + 2 * 8);
        if constexpr(Mr > 3)            ymm9  = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8);
        if constexpr(Mr > 3 && Nr > 8)  ymm10 = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8);
        if constexpr(Mr > 3 && Nr > 16) ymm11 = _mm256_loadu_ps(p_c + 3 * ldc + 2 * 8);
    }
    else
    {
        ymm0 = _mm256_xor_ps(ymm0, ymm0);
        if constexpr(Nr > 8)            ymm1  = _mm256_xor_ps(ymm1, ymm1);
        if constexpr(Nr > 16)           ymm2  = _mm256_xor_ps(ymm2, ymm2);
        if constexpr(Mr > 1)            ymm3  = _mm256_xor_ps(ymm3, ymm3);
        if constexpr(Mr > 1 && Nr > 8)  ymm4  = _mm256_xor_ps(ymm4, ymm4);
        if constexpr(Mr > 1 && Nr > 16) ymm5  = _mm256_xor_ps(ymm5, ymm5);
        if constexpr(Mr > 2)            ymm6  = _mm256_xor_ps(ymm6, ymm6);
        if constexpr(Mr > 2 && Nr > 8)  ymm7  = _mm256_xor_ps(ymm7, ymm7);
        if constexpr(Mr > 2 && Nr > 16) ymm8  = _mm256_xor_ps(ymm8, ymm8);
        if constexpr(Mr > 3)            ymm9  = _mm256_xor_ps(ymm9, ymm9);
        if constexpr(Mr > 3 && Nr > 8)  ymm10 = _mm256_xor_ps(ymm10, ymm10);
        if constexpr(Mr > 3 && Nr > 16) ymm11 = _mm256_xor_ps(ymm11, ymm11);
    }

    while(Kr > 4)
    {
#pragma unroll
...
...
@@ -1221,33 +1260,36 @@ struct ThreadwiseGemmAvx2_MxN_6x16_Dispatch
    static constexpr pThreadwiseGemmAvx2Run dispatch_table[6][2] = {
        {ThreadwiseGemm_6x16_t::Run, ThreadwiseGemm_6x8_t::Run,},
        {ThreadwiseGemm_1x8_t::Run,  ThreadwiseGemm_1x16_t::Run,},
        {ThreadwiseGemm_5x16_t::Run, ThreadwiseGemm_5x8_t::Run,},
        {ThreadwiseGemm_2x8_t::Run,  ThreadwiseGemm_2x16_t::Run,},
        {ThreadwiseGemm_4x16_t::Run, ThreadwiseGemm_4x8_t::Run,},
        {ThreadwiseGemm_3x8_t::Run,  ThreadwiseGemm_3x16_t::Run,},
        {ThreadwiseGemm_3x16_t::Run, ThreadwiseGemm_3x8_t::Run,},
        {ThreadwiseGemm_4x8_t::Run,  ThreadwiseGemm_4x16_t::Run,},
        {ThreadwiseGemm_2x16_t::Run, ThreadwiseGemm_2x8_t::Run,},
        {ThreadwiseGemm_5x8_t::Run,  ThreadwiseGemm_5x16_t::Run,},
        {ThreadwiseGemm_1x16_t::Run, ThreadwiseGemm_1x8_t::Run,},
        {ThreadwiseGemm_6x8_t::Run,  ThreadwiseGemm_6x16_t::Run,},
    };

    static void Run(ThreadwiseGemmParam* param, index_t mr, index_t nr)
    {
        index_t im = mr - 1;
        index_t in = (nr >> 3) - 1;
        assert(im >= 0 && im <= 5 && in >= 0 && in <= 1);
        return dispatch_table[im][in](param);
    }
};
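The table above maps a residual tile size (mr rows, nr columns of C) onto a fixed-size micro-kernel; with the new indexing, row `mr - 1` and column `(nr >> 3) - 1` pick, for example, the 3x16 kernel for a 3-row, 16-column tail. A hypothetical caller, where `Dispatch6x16` stands for some concrete instantiation of the dispatcher template and is not a name from this commit:

```cpp
// Sketch: tail handling in a blocked GEMM loop; mr in [1, 6], nr in {8, 16}.
template <typename Dispatch6x16>
void run_tail_tile(ck::cpu::ThreadwiseGemmParam* param, ck::index_t mr, ck::index_t nr)
{
    // Run() maps (mr, nr) onto dispatch_table[mr - 1][(nr >> 3) - 1]
    Dispatch6x16::Run(param, mr, nr);
}
```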
...
...
@@ -1371,30 +1413,33 @@ struct ThreadwiseGemmAvx2_MxN_4x24_Dispatch
    static constexpr pThreadwiseGemmAvx2Run dispatch_table[4][3] = {
        {ThreadwiseGemm_4x24_t::Run, ThreadwiseGemm_4x16_t::Run, ThreadwiseGemm_4x8_t::Run,},
        {ThreadwiseGemm_1x8_t::Run,  ThreadwiseGemm_1x16_t::Run, ThreadwiseGemm_1x24_t::Run,},
        {ThreadwiseGemm_3x24_t::Run, ThreadwiseGemm_3x16_t::Run, ThreadwiseGemm_3x8_t::Run,},
        {ThreadwiseGemm_2x8_t::Run,  ThreadwiseGemm_2x16_t::Run, ThreadwiseGemm_2x24_t::Run,},
        {ThreadwiseGemm_2x24_t::Run, ThreadwiseGemm_2x16_t::Run, ThreadwiseGemm_2x8_t::Run,},
        {ThreadwiseGemm_3x8_t::Run,  ThreadwiseGemm_3x16_t::Run, ThreadwiseGemm_3x24_t::Run,},
        {ThreadwiseGemm_1x24_t::Run, ThreadwiseGemm_1x16_t::Run, ThreadwiseGemm_1x8_t::Run,},
        {ThreadwiseGemm_4x8_t::Run,  ThreadwiseGemm_4x16_t::Run, ThreadwiseGemm_4x24_t::Run,},
    };

    static void Run(ThreadwiseGemmParam* param, index_t mr, index_t nr)
    {
        return dispatch_table[mr][nr](param);
        index_t im = mr - 1;
        index_t in = (nr >> 3) - 1;
        assert(im >= 0 && im <= 3 && in >= 0 && in <= 2);
        return dispatch_table[im][in](param);
    }
};
...
...
include/ck/tensor_operation/cpu/thread/threadwise_param.hpp → include/ck/tensor_operation/cpu/thread/threadwise_gemm_param.hpp
View file @ afc7d431
#ifndef CK_THREADWISE_PARAM_HPP
#define CK_THREADWISE_PARAM_HPP
#ifndef CK_THREADWISE_GEMM_PARAM_HPP
#define CK_THREADWISE_GEMM_PARAM_HPP
#include "common_header.hpp"
#include "math.hpp"
...
...
@@ -17,7 +17,7 @@ struct ThreadwiseGemmParam
    uint64_t ldb; // in unit of byte
    uint64_t ldc; // in unit of byte
    float alpha;
    uint32_t _pack0;
    int accmulate_c; // if 1, load C and accumulate the FMA result into it; if 0, store the C result out directly
} __attribute__((packed));

} // namespace cpu
...
...
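The 4x24 assembly above reads the new flag with `mov 60(%[m_param]), %%edi`, so the byte offset of `accmulate_c` inside the packed `ThreadwiseGemmParam` has to stay at 60. A hedged compile-time check one could add; the leading members listed here are an assumption for illustration only, since this hunk shows nothing before `ldb`:

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical mirror of the packed parameter block, used only to illustrate the check.
struct ParamMirror
{
    const void* p_a; // assumed leading members; not shown in this diff
    const void* p_b;
    void* p_c;
    uint64_t Kr;
    uint64_t lda;    // in unit of byte
    uint64_t ldb;    // in unit of byte
    uint64_t ldc;    // in unit of byte
    float alpha;
    int accmulate_c; // read by the asm kernels at byte offset 60
} __attribute__((packed));

static_assert(offsetof(ParamMirror, accmulate_c) == 60,
              "asm kernels load accmulate_c from offset 60; keep the layout in sync");
```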
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2.hpp
View file @
afc7d431
...
...
@@ -53,6 +53,42 @@ struct ThreadwiseTensorSliceTransferAvx2
{
        static_assert(SliceLengths::At(Number<VectorDim>{}) % ScalarPerVector == 0,
                      "wrong! cannot evenly divide");

        int N  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
        int Hi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
        int Wi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
        int C  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<3>{}];

        int Ho = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<1>{}];
        int Wo = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<2>{}];

        int Fy = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<0>{}];
        int Fx = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<1>{}];

        int Dy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<0>{}];
        int Sy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<1>{}];
        int Dx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<0>{}];
        int Sx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<1>{}];

        int Py = src_desc.GetTransforms()[Number<2>{}].left_pad_length_;
        int Px = src_desc.GetTransforms()[Number<3>{}].left_pad_length_;

        printf("N:%d, Hi:%d, Wi:%d, C:%d, Ho:%d, Wo:%d, Fy:%d, Fx:%d, Dy:%d, Sy:%d, Dx:%d, Sx:%d, "
               "Py:%d, Px:%d\n",
               N, Hi, Wi, C, Ho, Wo, Fy, Fx, Dy, Sy, Dx, Sx, Py, Px);
    }

    void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
...
...
@@ -87,6 +123,10 @@ struct ThreadwiseTensorSliceTransferAvx2
// std::cout<<"num_access:"<<num_access<<std::endl;
        std::cout << "src hidden:" << SrcDesc::GetNumOfHiddenDimension() << std::endl;
        std::cout << "dst hidden:" << DstDesc::GetNumOfHiddenDimension() << std::endl;
#if 0
static_for<0, num_access, 1>{}([&](auto idx_1d) {
using src_vector_type = ck::cpu::vector_type_maker_t<SrcData, ScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
...
...
@@ -148,6 +188,75 @@ struct ThreadwiseTensorSliceTransferAvx2
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
}
});
#endif
        const auto src_slice_idx_zeros = typename uniform_sequence_gen<nDim, 0>::type{};
        const auto src_slice_step      = make_tensor_coordinate_step(
            src_desc, to_multi_index(src_slice_idx_zeros.Modify(Number<nDim - 1>{}, Number<1>{})));

        const auto dst_slice_idx_zeros = typename uniform_sequence_gen<nDim, 0>::type{};
        const auto dst_slice_step      = make_tensor_coordinate_step(
            dst_desc, to_multi_index(dst_slice_idx_zeros.Modify(Number<nDim - 1>{}, Number<1>{})));

        for(auto idx_id = 0; idx_id < num_access; idx_id++)
        {
            using src_vector_type = ck::cpu::vector_type_maker_t<SrcData, ScalarPerVector>;
            using src_vector_t    = typename src_vector_type::type;
            using dst_vector_type = ck::cpu::vector_type_maker_t<DstData, ScalarPerVector>;
            using dst_vector_t    = typename dst_vector_type::type;

            const bool is_src_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);

            printf("[%s] ", is_src_valid ? "y" : "n");
            print_multi_index(src_coord_.GetIndex());
            printf("----");
            // print_multi_index(src_coord_.GetHiddenIndex());
            // printf(":%d", src_coord_.GetOffset());
            // printf("\n");

            // copy data from src_buf into src_vector_container
            auto src_vector_container = src_vector_type{
                src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};

            auto dst_vector_container = dst_vector_type{};

            // apply pointwise operation
            // static_for<0, ScalarPerVector, 1>{}([&](auto i) {
            //     element_op_(dst_vector_container.template AsType<DstData>()(i),
            //                 src_vector_container.template AsType<SrcData>()[i]);
            // });
            element_op_(dst_vector_container.template AsType<dst_vector_t>(),
                        src_vector_container.template AsType<src_vector_t>());

            const bool is_dst_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);

            printf(" -> ");
            print_multi_index(dst_coord_.GetIndex());
            // printf(":%d", dst_coord_.GetOffset());
            // printf(", src:0x%x, dst:0x%x",
            // *reinterpret_cast<uint32_t*>(&src_vector_container.template AsType<src_vector_t>()),
            // *reinterpret_cast<uint32_t*>(&dst_vector_container.template
            // AsType<dst_vector_t>()));
            printf("\n");

            // copy data from dst_vector into dst_buf
            dst_buf.template Update<DstInMemOp, dst_vector_t>(
                dst_coord_.GetOffset(),
                is_dst_valid,
                dst_vector_container.template AsType<dst_vector_t>());

            // move coordinate
            if(idx_id != num_access - 1)
            {
                // constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
                move_tensor_coordinate(src_desc, src_coord_, src_slice_step);
                move_tensor_coordinate(dst_desc, dst_coord_, dst_slice_step);
            }
        }

        // move coordinate back to slice origin (or not)
        if constexpr(SrcResetCoordinateAfterRun)
...
...
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp
0 → 100644
View file @
afc7d431
#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_AVX2_SPECIALIZED_HPP
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_AVX2_SPECIALIZED_HPP
#include "common_header.hpp"
#include "data_type_cpu.hpp"
#include "../../gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_space_filling_curve.hpp"
#include "dynamic_buffer_cpu.hpp"
#include <immintrin.h>
#include "convolution_forward_specialization_cpu.hpp"
#include <immintrin.h>
namespace ck {
namespace cpu {
namespace avx2_util {

inline void memcpy32_avx2(void* dst, const void* src, const ck::index_t n)
{
    // 16-8-4-2-1 pattern
    ck::index_t i_n    = n;
    float* p_dst       = reinterpret_cast<float*>(dst);
    const float* p_src = reinterpret_cast<const float*>(src);

    while(i_n >= 16)
    {
        _mm256_storeu_ps(p_dst + 0, _mm256_loadu_ps(p_src + 0));
        _mm256_storeu_ps(p_dst + 8, _mm256_loadu_ps(p_src + 8));
        p_dst += 16;
        p_src += 16;
        i_n -= 16;
    }
    if(i_n & 8)
    {
        _mm256_storeu_ps(p_dst, _mm256_loadu_ps(p_src));
        p_dst += 8;
        p_src += 8;
    }
    if(i_n & 4)
    {
        _mm_storeu_ps(p_dst, _mm_loadu_ps(p_src));
        p_dst += 4;
        p_src += 4;
    }
    if(i_n & 2)
    {
        _mm_storeu_si64(p_dst, _mm_loadu_si64(p_src));
        p_dst += 2;
        p_src += 2;
    }
    if(i_n & 1)
    {
        *p_dst = *p_src;
    }
}
inline void memset32_avx2(void* dst, const int32_t value, const ck::index_t n)
{
    // 16-8-4-2-1 pattern
    ck::index_t i_n = n;
    float* p_dst    = reinterpret_cast<float*>(dst);
    __m256 ymm      = _mm256_set1_ps(*reinterpret_cast<const float*>(&value));
    __m128 xmm      = _mm_set1_ps(*reinterpret_cast<const float*>(&value));

    while(i_n >= 16)
    {
        _mm256_storeu_ps(p_dst + 0, ymm);
        _mm256_storeu_ps(p_dst + 8, ymm);
        p_dst += 16;
        i_n -= 16;
    }
    if(i_n & 8)
    {
        _mm256_storeu_ps(p_dst, ymm);
        p_dst += 8;
    }
    if(i_n & 4)
    {
        _mm_storeu_ps(p_dst, xmm);
        p_dst += 4;
    }
    if(i_n & 2)
    {
        _mm_storeu_si64(p_dst, _mm_castps_si128(xmm));
        p_dst += 2;
    }
    if(i_n & 1)
    {
        *p_dst = *reinterpret_cast<const float*>(&value);
    }
}
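Both helpers above copy or fill `n` floats with a 16-8-4-2-1 tail, so no per-element loop is needed until the very last element. A hedged sketch of how the transfer classes below use them when packing one im2col row; the function and parameter names here are illustrative:

```cpp
// Sketch: pack one row of k_per_block floats, zero-filling when the source pixel
// falls into padding, the same way the NHWC input transfer below drives the helpers.
inline void pack_row(float* dst, const float* src, ck::index_t k_per_block, bool inside_image)
{
    if(inside_image)
        ck::cpu::avx2_util::memcpy32_avx2(dst, src, k_per_block);
    else
        ck::cpu::avx2_util::memset32_avx2(dst, 0, k_per_block);
}
```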
inline void transpose8x8_avx2(void* dst, ck::index_t stride_dst, const void* src, ck::index_t stride_src)
{
    // TODO: use vinsertf128 for better port usage. vpermf128 is slow
    __m256 r0, r1, r2, r3, r4, r5, r6, r7;
    __m256 t0, t1, t2, t3, t4, t5, t6, t7;

    float* p_dst       = reinterpret_cast<float*>(dst);
    const float* p_src = reinterpret_cast<const float*>(src);

    r0 = _mm256_loadu_ps(p_src + 0 * stride_src);
    r1 = _mm256_loadu_ps(p_src + 1 * stride_src);
    r2 = _mm256_loadu_ps(p_src + 2 * stride_src);
    r3 = _mm256_loadu_ps(p_src + 3 * stride_src);
    r4 = _mm256_loadu_ps(p_src + 4 * stride_src);
    r5 = _mm256_loadu_ps(p_src + 5 * stride_src);
    r6 = _mm256_loadu_ps(p_src + 6 * stride_src);
    r7 = _mm256_loadu_ps(p_src + 7 * stride_src);

    t0 = _mm256_unpacklo_ps(r0, r1);
    t1 = _mm256_unpackhi_ps(r0, r1);
    t2 = _mm256_unpacklo_ps(r2, r3);
    t3 = _mm256_unpackhi_ps(r2, r3);
    t4 = _mm256_unpacklo_ps(r4, r5);
    t5 = _mm256_unpackhi_ps(r4, r5);
    t6 = _mm256_unpacklo_ps(r6, r7);
    t7 = _mm256_unpackhi_ps(r6, r7);

    r0 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(1, 0, 1, 0));
    r1 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(3, 2, 3, 2));
    r2 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(1, 0, 1, 0));
    r3 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(3, 2, 3, 2));
    r4 = _mm256_shuffle_ps(t4, t6, _MM_SHUFFLE(1, 0, 1, 0));
    r5 = _mm256_shuffle_ps(t4, t6, _MM_SHUFFLE(3, 2, 3, 2));
    r6 = _mm256_shuffle_ps(t5, t7, _MM_SHUFFLE(1, 0, 1, 0));
    r7 = _mm256_shuffle_ps(t5, t7, _MM_SHUFFLE(3, 2, 3, 2));

    t0 = _mm256_permute2f128_ps(r0, r4, 0x20);
    t1 = _mm256_permute2f128_ps(r1, r5, 0x20);
    t2 = _mm256_permute2f128_ps(r2, r6, 0x20);
    t3 = _mm256_permute2f128_ps(r3, r7, 0x20);
    t4 = _mm256_permute2f128_ps(r0, r4, 0x31);
    t5 = _mm256_permute2f128_ps(r1, r5, 0x31);
    t6 = _mm256_permute2f128_ps(r2, r6, 0x31);
    t7 = _mm256_permute2f128_ps(r3, r7, 0x31);

    _mm256_storeu_ps(p_dst + 0 * stride_dst, t0);
    _mm256_storeu_ps(p_dst + 1 * stride_dst, t1);
    _mm256_storeu_ps(p_dst + 2 * stride_dst, t2);
    _mm256_storeu_ps(p_dst + 3 * stride_dst, t3);
    _mm256_storeu_ps(p_dst + 4 * stride_dst, t4);
    _mm256_storeu_ps(p_dst + 5 * stride_dst, t5);
    _mm256_storeu_ps(p_dst + 6 * stride_dst, t6);
    _mm256_storeu_ps(p_dst + 7 * stride_dst, t7);
}

} // namespace avx2_util
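`transpose8x8_avx2` is the usual unpack / shuffle / permute2f128 ladder. A hedged scalar cross-check that could go into a unit test; every name here is local to the sketch:

```cpp
#include <cassert>

// Sketch: verify the AVX2 8x8 transpose against a scalar reference.
inline void check_transpose8x8()
{
    float src[8 * 8], dst[8 * 8];
    for(int i = 0; i < 64; ++i)
        src[i] = static_cast<float>(i);

    ck::cpu::avx2_util::transpose8x8_avx2(dst, 8, src, 8);

    for(int r = 0; r < 8; ++r)
        for(int c = 0; c < 8; ++c)
            assert(dst[c * 8 + r] == src[r * 8 + c]); // dst is the transpose of src
}
```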
using ConvolutionForwardSpecialization_t =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t;
using ConvolutionForwardGemmKSpecialization_t =
    ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t;
// assume input -> a matrix
// assume input -> MC * KC
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename ElementwiseOperation,
          bool BypassTransfer,
          ConvolutionForwardSpecialization_t ConvForwardSpecialization,
          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
{
    static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();

    using Index = MultiIndex<nDim>;

    constexpr ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC(
        const SrcDesc& src_desc,
        const Index&,
        const DstDesc&,
        const Index&,
        const ElementwiseOperation& element_op)
        : element_op_(element_op)
    {
        if constexpr(ConvForwardSpecialization ==
                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            N  = 1;
            Hi = 1;
            Wi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}]; // gemm_m
            C  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}]; // gemm_k
            Ho = 1;
            Wo = Wi;
            Fy = 1;
            Fx = 1;
            Dy = 1;
            Sy = 1;
            Dx = 1;
            Sx = 1;
            Py = 0;
            Px = 0;
        }
        else if constexpr(ConvForwardSpecialization ==
                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
        {
            N  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
            Hi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
            Wi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
            C  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<3>{}];
            Ho = src_desc.GetTransforms()[Number<2>{}].GetUpperLengths()[Number<0>{}];
            Wo = src_desc.GetTransforms()[Number<3>{}].GetUpperLengths()[Number<0>{}];
            Fy = 1;
            Fx = 1;
            Dy = 1;
            Sy = src_desc.GetTransforms()[Number<2>{}].coefficients_[Number<0>{}];
            Dx = 1;
            Sx = src_desc.GetTransforms()[Number<3>{}].coefficients_[Number<0>{}];
            Py = 0;
            Px = 0;
        }
        else
        {
            N  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
            Hi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
            Wi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
            C  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<3>{}];
            Ho = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<1>{}];
            Wo = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<2>{}];
            Fy = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<0>{}];
            Fx = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<1>{}];
            Dy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<0>{}];
            Sy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<1>{}];
            Dx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<0>{}];
            Sx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<1>{}];
            Py = src_desc.GetTransforms()[Number<2>{}].left_pad_length_;
            Px = src_desc.GetTransforms()[Number<3>{}].left_pad_length_;
        }

        // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
        // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
        input_offset_acc_wi        = Sx * C;
        input_offset_ovf_wi_acc_hi = Sy * Wi * C - Wo * Sx * C;
        input_offset_ovf_hi_acc_n  = Hi * Wi * C - Ho * Sy * Wi * C;
        // input_offset_acc_c = 1;
        input_offset_ovf_c_acc_x = Dx * C - C;
        input_offset_ovf_x_acc_y = Dy * Wi * C - Fx * Dx * C;

        src_offset = -Py * Wi * C - Px * C;

        i_n      = 0;
        i_c      = 0;
        i_hi     = -Py;
        i_wi     = -Px;
        i_ho     = 0;
        i_wo     = 0;
        i_y      = 0;
        i_x      = 0;
        i_gemm_k = 0;

#if 0
        printf("N:%d, Hi:%d, Wi:%d, C:%d, Ho:%d, Wo:%d, Fy:%d, Fx:%d, Dy:%d, Sy:%d, Dx:%d, Sx:%d, "
               "Py:%d, Px:%d\n",
               N, Hi, Wi, C, Ho, Wo, Fy, Fx, Dy, Sy, Dx, Sx, Py, Px);
#endif
    }
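The constructor above folds the convolution geometry into a handful of pointer increments, so the copy loops in Run() never re-derive `ihi`/`iwi` from scratch. For one concrete, made-up NHWC case (Hi=Wi=7, C=8, 3x3 filter, stride 2, dilation 1, the values are not taken from any descriptor in this commit), the same formulas work out as follows:

```cpp
// Sketch: the increment formulas above with concrete numbers, Ho = Wo = 4 for this shape.
constexpr int C = 8, Hi = 7, Wi = 7, Ho = 4, Wo = 4, Sy = 2, Sx = 2, Dy = 1, Dx = 1, Fx = 3;

constexpr int input_offset_acc_wi        = Sx * C;                         // 16: step to the next output column
constexpr int input_offset_ovf_wi_acc_hi = Sy * Wi * C - Wo * Sx * C;      // 112 - 64 = 48: wrap to the next output row
constexpr int input_offset_ovf_hi_acc_n  = Hi * Wi * C - Ho * Sy * Wi * C; // 392 - 448 = -56: wrap to the next image
constexpr int input_offset_ovf_c_acc_x   = Dx * C - C;                     // 0: step filter x after finishing C
constexpr int input_offset_ovf_x_acc_y   = Dy * Wi * C - Fx * Dx * C;      // 56 - 24 = 32: step filter y after finishing x

static_assert(input_offset_acc_wi == 16 && input_offset_ovf_wi_acc_hi == 48 &&
                  input_offset_ovf_hi_acc_n == -56 && input_offset_ovf_c_acc_x == 0 &&
                  input_offset_ovf_x_acc_y == 32,
              "worked example of the pointer increments");
```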
    void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
    {
        ck::index_t idx_m = src_slice_origin_idx[Number<0>{}];
        ck::index_t idx_k = src_slice_origin_idx[Number<1>{}];

        if constexpr(ConvForwardSpecialization ==
                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            i_wi       = idx_m;
            i_c        = idx_k;
            src_offset = i_wi * C + i_c;
            // printf("src_offset:%d, i_wi:%d, i_c:%d\n", src_offset, i_wi, i_c);
        }
        else if constexpr(ConvForwardSpecialization ==
                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
        {
            i_wo = idx_m % Wo;
            i_ho = (idx_m / Wo) % Ho;
            i_n  = (idx_m / Wo) / Ho;

            // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
            // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
            i_c  = idx_k;
            i_x  = 0;
            i_y  = 0;
            i_hi = i_ho * Sy;
            i_wi = i_wo * Sx;

            src_offset = i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C + i_c;
            i_gemm_k   = idx_k;
        }
        else
        {
            i_wo = idx_m % Wo;
            i_ho = (idx_m / Wo) % Ho;
            i_n  = (idx_m / Wo) / Ho;

            // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
            // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
            if(idx_k == 0)
            {
                i_c  = 0;
                i_x  = 0;
                i_y  = 0;
                i_hi = i_ho * Sy - Py;
                i_wi = i_wo * Sx - Px;
            }
            else
            {
                i_c  = idx_k % C;
                i_x  = (idx_k / C) % Fx;
                i_y  = (idx_k / C) / Fx;
                i_hi = i_ho * Sy + i_y * Dy - Py;
                i_wi = i_wo * Sx + i_x * Dx - Px;
            }

            src_offset = i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C + i_c;
            i_gemm_k   = idx_k;
            // printf("[%d] i_wo:%d, i_ho:%d, i_wi:%d, i_hi:%d, src_offset:%d\n",
            //        __LINE__, i_wo, i_ho, i_wi, i_hi, src_offset);
        }
    }
    void SetDstSliceOrigin(const DstDesc&, const Index&) {}

    template <typename SrcBuffer, typename DstBuffer>
    void Run(const SrcDesc& src_desc, const SrcBuffer& src_buf, const DstDesc& dst_desc, DstBuffer& dst_buf)
    {
        if constexpr(BypassTransfer)
        {
            float* p_src    = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
            dst_buf.p_data_ = p_src;
        }
        else
        {
            const ck::index_t m_per_block =
                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
            const ck::index_t k_per_block =
                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];

            const float* p_src = reinterpret_cast<const float*>(src_buf.p_data_) + src_offset;
            float* p_dst       = reinterpret_cast<float*>(dst_buf.p_data_);

            // printf("src offset:%d, k_per_block:%d, m_per_block:%d\n", src_offset, k_per_block,
            //        m_per_block);

            if constexpr(ConvForwardSpecialization ==
                         ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
            {
                ck::index_t i_m_itr = m_per_block;

                // standard 8-4-2-1 pattern
                while(i_m_itr >= 8)
                {
                    avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 2 * k_per_block, p_src + 2 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 3 * k_per_block, p_src + 3 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 4 * k_per_block, p_src + 4 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 5 * k_per_block, p_src + 5 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 6 * k_per_block, p_src + 6 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 7 * k_per_block, p_src + 7 * C, k_per_block);
                    i_m_itr -= 8;
                    p_dst += 8 * k_per_block;
                    p_src += 8 * C;
                }
                if(i_m_itr & 4)
                {
                    avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 2 * k_per_block, p_src + 2 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 3 * k_per_block, p_src + 3 * C, k_per_block);
                    p_dst += 4 * k_per_block;
                    p_src += 4 * C;
                }
                if(i_m_itr & 2)
                {
                    avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block);
                    p_dst += 2 * k_per_block;
                    p_src += 2 * C;
                }
                if(i_m_itr & 1)
                {
                    avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
                }
            }
            else if constexpr(ConvForwardSpecialization ==
                              ConvolutionForwardSpecialization_t::Filter1x1Pad0)
            {
                ck::index_t i_m_itr  = m_per_block;
                ck::index_t i_wo_itr = i_wo;
                ck::index_t i_ho_itr = i_ho;

                while(i_m_itr > 0)
                {
                    avx2_util::memcpy32_avx2(p_dst, p_src, k_per_block);
                    p_dst += k_per_block;
                    i_wo_itr++;
                    p_src += input_offset_acc_wi;
                    if(i_wo_itr >= Wo)
                    {
                        i_wo_itr = 0;
                        i_ho_itr++;
                        p_src += input_offset_ovf_wi_acc_hi;
                    }
                    if(i_ho_itr >= Ho)
                    {
                        i_ho_itr = 0;
                        // i_n++;
                        p_src += input_offset_ovf_hi_acc_n;
                    }
                    i_m_itr--;
                }
            }
            else
            {
                // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
                // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
                if constexpr(GemmKSpecialization ==
                             ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
                {
                    // c % k_per_block == 0, so every time k_per_block here is the same
                    ck::index_t i_m_itr  = m_per_block;
                    ck::index_t i_wo_itr = i_wo;
                    ck::index_t i_ho_itr = i_ho;
                    ck::index_t i_wi_itr = i_wi;
                    ck::index_t i_hi_itr = i_hi;

                    // printf("[%d] i_m_itr:%d, i_wo_itr:%d, i_ho_itr:%d, i_wi_itr:%d, i_hi_itr:%d,
                    // src_offset:%d, input_offset_acc_wi:%d,
                    // input_offset_ovf_wi_acc_hi:%d,input_offset_ovf_hi_acc_n:%d, %p(%p)\n",
                    // __LINE__, i_m_itr, i_wo_itr, i_ho_itr, i_wi_itr, i_hi_itr,
                    // src_offset, input_offset_acc_wi, input_offset_ovf_wi_acc_hi,
                    // input_offset_ovf_hi_acc_n, src_buf.p_data_, p_src);

                    // printf("%p %p %p, %d, %x, %p\n",src_buf.p_data_, reinterpret_cast<const
                    // float*>(src_buf.p_data_) + 1, reinterpret_cast<const float*>(src_buf.p_data_)
                    // + ck::index_t(-1),
                    // sizeof(src_offset), *reinterpret_cast<uint32_t*>(&src_offset),
                    // reinterpret_cast<const float*>(src_buf.p_data_) + (-1088));

                    while(i_m_itr > 0)
                    {
                        // printf("[%d] i_m_itr:%d, i_wo_itr:%d, i_ho_itr:%d, i_wi_itr:%d,
                        // i_hi_itr:%d, src_offset:%d -> %p\n",
                        // __LINE__, i_m_itr, i_wo_itr, i_ho_itr, i_wi_itr, i_hi_itr, src_offset,
                        // p_src);
                        if((*reinterpret_cast<uint32_t*>(&i_hi_itr) < Hi) &&
                           (*reinterpret_cast<uint32_t*>(&i_wi_itr) < Wi))
                            avx2_util::memcpy32_avx2(p_dst, p_src, k_per_block);
                        else
                            avx2_util::memset32_avx2(p_dst, 0, k_per_block);

                        p_dst += k_per_block;
                        i_wo_itr++;
                        i_wi_itr += Sx;
                        p_src += input_offset_acc_wi;
                        if(i_wo_itr >= Wo)
                        {
                            i_wo_itr = 0;
                            i_wi_itr -= Wo * Sx;
                            i_ho_itr++;
                            i_hi_itr += Sy;
                            p_src += input_offset_ovf_wi_acc_hi;
                        }
                        if(i_ho_itr >= Ho)
                        {
                            i_ho_itr = 0;
                            i_hi_itr -= Ho * Sy;
                            // i_n++;
                            p_src += input_offset_ovf_hi_acc_n;
                        }
                        i_m_itr--;
                    }
                    // printf("[%d] \n", __LINE__);
                }
                else
                {
                    ck::index_t i_m_itr  = m_per_block;
                    ck::index_t i_wo_itr = i_wo;
                    ck::index_t i_ho_itr = i_ho;
                    ck::index_t i_wi_itr = i_wi;
                    ck::index_t i_hi_itr = i_hi;

                    // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
                    // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
                    while(i_m_itr > 0)
                    {
                        /*** go along Gemm K ***/
                        const float* p_src_k = p_src;
                        float* p_dst_k       = p_dst;

                        ck::index_t i_wi_itr_k = i_wi_itr;
                        ck::index_t i_hi_itr_k = i_hi_itr;
                        ck::index_t i_c_itr_k  = i_c;
                        ck::index_t i_y_itr_k  = i_y;
                        ck::index_t i_x_itr_k  = i_x;
                        ck::index_t i_k_itr    = k_per_block;

                        while(i_k_itr > 0)
                        {
                            ck::index_t current_k_block = ck::math::min(C - i_c_itr_k, k_per_block);

                            if((*reinterpret_cast<uint32_t*>(&i_hi_itr_k) < Hi) &&
                               (*reinterpret_cast<uint32_t*>(&i_wi_itr_k) < Wi))
                                avx2_util::memcpy32_avx2(p_dst_k, p_src_k, current_k_block);
                            else
                                avx2_util::memset32_avx2(p_dst_k, 0, current_k_block);

                            p_dst_k += current_k_block;
                            p_src_k += current_k_block;
                            i_c_itr_k += current_k_block;
                            if(i_c_itr_k >= C)
                            {
                                i_c_itr_k = 0;
                                i_x_itr_k++;
                                i_wi_itr_k += Dx;
                                p_src_k += input_offset_ovf_c_acc_x;
                            }
                            if(i_x_itr_k >= Fx)
                            {
                                i_x_itr_k = 0;
                                i_y_itr_k++;
                                i_hi_itr_k += Dy;
                                p_src_k += input_offset_ovf_x_acc_y;
                            }
                            i_k_itr -= current_k_block;
                        }
                        /*** go along Gemm K ***/

                        p_dst += k_per_block;
                        i_wo_itr++;
                        i_wi_itr += Sx;
                        p_src += input_offset_acc_wi;
                        if(i_wo_itr >= Wo)
                        {
                            i_wo_itr = 0;
                            i_wi_itr -= Wo * Sx;
                            i_ho_itr++;
                            i_hi_itr += Sy;
                            p_src += input_offset_ovf_wi_acc_hi;
                        }
                        if(i_ho_itr >= Ho)
                        {
                            i_ho_itr = 0;
                            i_hi_itr -= Ho * Sy;
                            // i_n++;
                            p_src += input_offset_ovf_hi_acc_n;
                        }
                        i_m_itr--;
                    }
                }
            }
        }
    }
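The padding test inside Run() relies on a single unsigned comparison: reinterpreting a possibly negative `i_hi_itr` as `uint32_t` lets `i < Hi` reject both "below zero" and "past the end" in one branch. A standalone illustration of the trick, with names local to the sketch:

```cpp
#include <cstdint>

// Sketch: one unsigned compare covers "index below zero" (top/left padding)
// and "index past the extent" (bottom/right padding) at the same time.
inline bool inside_extent(int i, int extent)
{
    return static_cast<uint32_t>(i) < static_cast<uint32_t>(extent);
}
// inside_extent(-1, 7) == false, inside_extent(0, 7) == true, inside_extent(7, 7) == false
```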
    void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx)
    {
        ck::index_t move_k = src_slice_origin_step_idx[Number<1>{}];

        if constexpr(ConvForwardSpecialization ==
                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            // printf(" => move_k:%d, src offset:%d\n", move_k, src_offset);
            i_c += move_k;
            src_offset += move_k;
        }
        else if constexpr(ConvForwardSpecialization ==
                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
        {
            i_c += move_k;
            src_offset += move_k;
        }
        else
        {
            if constexpr(GemmKSpecialization ==
                         ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
            {
                // c % k_per_block == 0, so every time k_per_block here is the same
                // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
                // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
                // printf("222222 C:%d, src_offset:%d, i_c:%d, i_x:%d\n", C, src_offset, i_c, i_x);
                // fflush(stdout);

                // TODO: branch seems weird
                i_c += move_k;
                src_offset += move_k;
                // printf("3333[%d] src_offset:%d\n", __LINE__, src_offset);
                if(i_c >= C)
                {
                    i_c = 0;
                    i_x++;
                    i_wi += Dx;
                    src_offset += Dx * C - C;
                    // printf("3333[%d] src_offset:%d\n", __LINE__, src_offset);
                }
                if(i_x >= Fx)
                {
                    i_x = 0;
                    i_y++;
                    i_wi = i_wi - Fx * Dx;
                    i_hi += Dy;
                    src_offset += Dy * Wi * C - Fx * Dx * C;
                    // printf("3333[%d] src_offset:%d\n", __LINE__, src_offset);
                }
                // printf("inp move:%d, i_c:%d, i_hi:%d, i_wi:%d src_offset:%d\n", move_k, i_c,
                //        i_hi, i_wi, src_offset); fflush(stdout);
            }
            else
            {
                i_gemm_k += move_k;

                i_c = i_gemm_k % C;
                i_x = (i_gemm_k / C) % Fx;
                i_y = (i_gemm_k / C) / Fx;

                i_hi = i_ho * Sy + i_y * Dy - Py;
                i_wi = i_wo * Sx + i_x * Dx - Px;

                src_offset = i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C + i_c;
            }
        }
    }

    void MoveDstSliceWindow(const DstDesc&, const Index&) {}
    private:
    const ElementwiseOperation element_op_;

    ck::index_t i_n;
    ck::index_t i_c;
    ck::index_t i_hi;
    ck::index_t i_wi;
    ck::index_t i_ho;
    ck::index_t i_wo;
    ck::index_t i_y;
    ck::index_t i_x;
    ck::index_t i_gemm_k;

    ck::index_t N;
    // ck::index_t K;
    ck::index_t C;
    ck::index_t Hi;
    ck::index_t Wi;
    ck::index_t Ho;
    ck::index_t Wo;
    ck::index_t Sy;
    ck::index_t Sx;
    ck::index_t Dy;
    ck::index_t Dx;
    ck::index_t Py;
    ck::index_t Px;
    ck::index_t Fy;
    ck::index_t Fx;

    intptr_t input_offset_acc_wi;
    intptr_t input_offset_ovf_wi_acc_hi;
    intptr_t input_offset_ovf_hi_acc_n;
    // intptr_t input_offset_acc_c;
    intptr_t input_offset_ovf_c_acc_x;
    intptr_t input_offset_ovf_x_acc_y;

    intptr_t src_offset; // keep this as a signed pointer-sized type since the offset can be negative
};
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename ElementwiseOperation,
          bool BypassTransfer,
          ConvolutionForwardSpecialization_t ConvForwardSpecialization,
          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
{
    static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();

    using Index = MultiIndex<nDim>;
    // using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
    // using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));

    constexpr ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC(
        const SrcDesc& src_desc,
        const Index& src_slice_origin,
        const DstDesc& dst_desc,
        const Index& dst_slice_origin,
        const ElementwiseOperation& element_op)
        : element_op_(element_op)
    {
        GemmN1 = src_desc.GetTransforms()[Number<3>{}].GetUpperLengths()[Number<1>{}];
        GemmN  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
        GemmK  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
    }
    void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
    {
        ck::index_t idx_n0 = src_slice_origin_idx[Number<0>{}];
        ck::index_t idx_k  = src_slice_origin_idx[Number<1>{}];
        ck::index_t idx_n1 = src_slice_origin_idx[Number<2>{}];

        i_gemm_n = idx_n0 * GemmN1 + idx_n1;
        // i_gemm_k = idx_k;

        src_offset = idx_n0 * GemmK * GemmN1 + idx_k + idx_n1 * GemmN1; // Note we transpose here
        // printf("xxxx i_gemm_n:%d, i_gemm_k:%d, src_offset:%d\n", i_gemm_n, i_gemm_k,
        //        src_offset);
    }

    void SetDstSliceOrigin(const DstDesc&, const Index&) {}
    template <typename SrcBuffer, typename DstBuffer>
    void Run(const SrcDesc&, const SrcBuffer& src_buf, const DstDesc& dst_desc, DstBuffer& dst_buf)
    {
        if constexpr(BypassTransfer)
        {
            // TODO: weight NHWC not support this
        }
        else
        {
            const ck::index_t n_per_block =
                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}] *
                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
            const ck::index_t k_per_block =
                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];

            // printf(" >>>> %d, %d, %d -> %d(%dx%d), %d\n", GemmN, GemmK, GemmN1, n_per_block,
            //        dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}],
            //        dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}],
            //        k_per_block);

            const float* p_src = reinterpret_cast<const float*>(src_buf.p_data_) + src_offset;
            float* p_dst       = reinterpret_cast<float*>(dst_buf.p_data_);

            // n * k -> n0 * k * n1, n1 = 8, n0 = n/8
            for(index_t i_n_itr = 0; i_n_itr < n_per_block; i_n_itr += 8)
            {
                ck::index_t current_n_8 = ck::math::min(GemmN - (i_n_itr + i_gemm_n), 8);
                ck::index_t i_k_itr     = k_per_block;

                if(current_n_8 == 8)
                {
                    const float* p_src_k = p_src;
                    float* p_dst_k       = p_dst;
                    while(i_k_itr >= 8)
                    {
                        avx2_util::transpose8x8_avx2(p_dst_k, 8, p_src_k, GemmK);
                        p_dst_k += 8 * 8;
                        p_src_k += 8;
                        i_k_itr -= 8;
                    }
                    if(i_k_itr & 4)
                    {
                        p_dst_k[0 * 8 + 0] = p_src_k[0 * GemmK + 0];
                        p_dst_k[0 * 8 + 1] = p_src_k[1 * GemmK + 0];
                        p_dst_k[0 * 8 + 2] = p_src_k[2 * GemmK + 0];
                        p_dst_k[0 * 8 + 3] = p_src_k[3 * GemmK + 0];
                        p_dst_k[0 * 8 + 4] = p_src_k[4 * GemmK + 0];
                        p_dst_k[0 * 8 + 5] = p_src_k[5 * GemmK + 0];
                        p_dst_k[0 * 8 + 6] = p_src_k[6 * GemmK + 0];
                        p_dst_k[0 * 8 + 7] = p_src_k[7 * GemmK + 0];

                        p_dst_k[1 * 8 + 0] = p_src_k[0 * GemmK + 1];
                        p_dst_k[1 * 8 + 1] = p_src_k[1 * GemmK + 1];
                        p_dst_k[1 * 8 + 2] = p_src_k[2 * GemmK + 1];
                        p_dst_k[1 * 8 + 3] = p_src_k[3 * GemmK + 1];
                        p_dst_k[1 * 8 + 4] = p_src_k[4 * GemmK + 1];
                        p_dst_k[1 * 8 + 5] = p_src_k[5 * GemmK + 1];
                        p_dst_k[1 * 8 + 6] = p_src_k[6 * GemmK + 1];
                        p_dst_k[1 * 8 + 7] = p_src_k[7 * GemmK + 1];

                        p_dst_k[2 * 8 + 0] = p_src_k[0 * GemmK + 2];
                        p_dst_k[2 * 8 + 1] = p_src_k[1 * GemmK + 2];
                        p_dst_k[2 * 8 + 2] = p_src_k[2 * GemmK + 2];
                        p_dst_k[2 * 8 + 3] = p_src_k[3 * GemmK + 2];
                        p_dst_k[2 * 8 + 4] = p_src_k[4 * GemmK + 2];
                        p_dst_k[2 * 8 + 5] = p_src_k[5 * GemmK + 2];
                        p_dst_k[2 * 8 + 6] = p_src_k[6 * GemmK + 2];
                        p_dst_k[2 * 8 + 7] = p_src_k[7 * GemmK + 2];

                        p_dst_k[3 * 8 + 0] = p_src_k[0 * GemmK + 3];
                        p_dst_k[3 * 8 + 1] = p_src_k[1 * GemmK + 3];
                        p_dst_k[3 * 8 + 2] = p_src_k[2 * GemmK + 3];
                        p_dst_k[3 * 8 + 3] = p_src_k[3 * GemmK + 3];
                        p_dst_k[3 * 8 + 4] = p_src_k[4 * GemmK + 3];
                        p_dst_k[3 * 8 + 5] = p_src_k[5 * GemmK + 3];
                        p_dst_k[3 * 8 + 6] = p_src_k[6 * GemmK + 3];
                        p_dst_k[3 * 8 + 7] = p_src_k[7 * GemmK + 3];

                        p_dst_k += 4 * 8;
                        p_src_k += 4;
                    }
                    if(i_k_itr & 2)
                    {
                        p_dst_k[0 * 8 + 0] = p_src_k[0 * GemmK + 0];
                        p_dst_k[0 * 8 + 1] = p_src_k[1 * GemmK + 0];
                        p_dst_k[0 * 8 + 2] = p_src_k[2 * GemmK + 0];
                        p_dst_k[0 * 8 + 3] = p_src_k[3 * GemmK + 0];
                        p_dst_k[0 * 8 + 4] = p_src_k[4 * GemmK + 0];
                        p_dst_k[0 * 8 + 5] = p_src_k[5 * GemmK + 0];
                        p_dst_k[0 * 8 + 6] = p_src_k[6 * GemmK + 0];
                        p_dst_k[0 * 8 + 7] = p_src_k[7 * GemmK + 0];

                        p_dst_k[1 * 8 + 0] = p_src_k[0 * GemmK + 1];
                        p_dst_k[1 * 8 + 1] = p_src_k[1 * GemmK + 1];
                        p_dst_k[1 * 8 + 2] = p_src_k[2 * GemmK + 1];
                        p_dst_k[1 * 8 + 3] = p_src_k[3 * GemmK + 1];
                        p_dst_k[1 * 8 + 4] = p_src_k[4 * GemmK + 1];
                        p_dst_k[1 * 8 + 5] = p_src_k[5 * GemmK + 1];
                        p_dst_k[1 * 8 + 6] = p_src_k[6 * GemmK + 1];
                        p_dst_k[1 * 8 + 7] = p_src_k[7 * GemmK + 1];

                        p_dst_k += 2 * 8;
                        p_src_k += 2;
                    }
                    if(i_k_itr & 1)
                    {
                        p_dst_k[0 * 8 + 0] = p_src_k[0 * GemmK + 0];
                        p_dst_k[0 * 8 + 1] = p_src_k[1 * GemmK + 0];
                        p_dst_k[0 * 8 + 2] = p_src_k[2 * GemmK + 0];
                        p_dst_k[0 * 8 + 3] = p_src_k[3 * GemmK + 0];
                        p_dst_k[0 * 8 + 4] = p_src_k[4 * GemmK + 0];
                        p_dst_k[0 * 8 + 5] = p_src_k[5 * GemmK + 0];
                        p_dst_k[0 * 8 + 6] = p_src_k[6 * GemmK + 0];
                        p_dst_k[0 * 8 + 7] = p_src_k[7 * GemmK + 0];
                    }
                }
                else
                {
                    const float* p_src_k = p_src;
                    float* p_dst_k       = p_dst;
                    for(index_t i_sub_n = 0; i_sub_n < 8; i_sub_n++)
                    {
                        for(index_t i_sub_k = 0; i_sub_k < k_per_block; i_sub_k++)
                        {
                            ck::index_t i_current_n_itr = i_n_itr + i_sub_n + i_gemm_n;
                            float v = i_current_n_itr < GemmN ? p_src_k[i_sub_n * GemmK + i_sub_k] : .0f;
                            p_dst_k[i_sub_k * 8 + i_sub_n] = v;
                        }
                    }
                }
                p_dst += 8 * k_per_block;
                p_src += 8 * GemmK;
            }
        }
    }
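Run() above repacks the weight tile from its GemmK-contiguous layout into an N0 x K x N1 block with N1 = 8, so the micro-kernel can load eight output channels at a time. A hedged scalar model of the index mapping the transpose path produces; the function is illustrative, not part of the commit:

```cpp
// Sketch: scalar model of the n*k -> n0*k*n1 (n1 = 8) packing done above.
// src is n_per_block x k_per_block with leading dimension GemmK, as in the code.
void pack_weight_reference(float* dst, const float* src, int n_per_block, int k_per_block, int GemmK)
{
    for(int n = 0; n < n_per_block; ++n)
        for(int k = 0; k < k_per_block; ++k)
            dst[(n / 8) * k_per_block * 8 + k * 8 + (n % 8)] = src[n * GemmK + k];
}
```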
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
    void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx)
    {
        ck::index_t move_k  = src_slice_origin_step_idx[Number<1>{}];
        ck::index_t move_n0 = src_slice_origin_step_idx[Number<0>{}];
        // i_gemm_k += move_k;
        // printf("wei move:%d\n", move_k); fflush(stdout);
        src_offset += move_k + move_n0 * GemmK * GemmN1;
    }

    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
    void MoveDstSliceWindow(const DstDesc&, const Index&) {}

    private:
    const ElementwiseOperation element_op_;

    ck::index_t i_gemm_n;
    // ck::index_t i_gemm_k;
    // ck::index_t GemmN0;
    ck::index_t GemmN1;
    ck::index_t GemmN;
    ck::index_t GemmK;

    intptr_t src_offset;
};
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename ElementwiseOperation,
          bool BypassTransfer,
          ConvolutionForwardSpecialization_t ConvForwardSpecialization,
          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
{
    static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();

    using Index = MultiIndex<nDim>;

    constexpr ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN(
        const SrcDesc& src_desc,
        const Index&,
        const DstDesc& dst_desc,
        const Index&,
        const ElementwiseOperation& element_op)
        : element_op_(element_op)
    {
        DstGemmM = dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
        DstGemmN = dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];

        src_offset = 0;
        dst_offset = 0;
    }
    void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
    {
        if constexpr(BypassTransfer)
        {
            auto i_src_gemm_m = src_slice_origin_idx[Number<0>{}];
            auto i_src_gemm_n = src_slice_origin_idx[Number<1>{}];

            src_offset = i_src_gemm_m * DstGemmN + i_src_gemm_n;
        }
    }

    void SetDstSliceOrigin(const DstDesc&, const Index& dst_slice_origin_idx)
    {
        i_dst_gemm_m = dst_slice_origin_idx[Number<0>{}];
        i_dst_gemm_n = dst_slice_origin_idx[Number<1>{}];

        dst_offset = i_dst_gemm_m * DstGemmN + i_dst_gemm_n;
    }

    template <typename SrcBuffer, typename DstBuffer>
    void Run(const SrcDesc& src_desc, SrcBuffer& src_buf, const DstDesc& dst_desc, DstBuffer& dst_buf)
    {
        if constexpr(BypassTransfer)
        {
            src_buf.p_data_ = reinterpret_cast<float*>(dst_buf.p_data_) + src_offset;
        }
        else
        {
            const ck::index_t m_per_block =
                src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
            // must be multiple of 8
            const ck::index_t n_per_block =
                src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];

            const ck::index_t current_n = ck::math::min(DstGemmN - i_dst_gemm_n, n_per_block);

            const float* p_src = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
            float* p_dst       = reinterpret_cast<float*>(dst_buf.p_data_) + dst_offset;

            ck::index_t i_m_itr = m_per_block;

            // printf("xxxx %d, current_n:%d, DstGemmN:%d, n_per_block:%d\n",__LINE__, current_n,
            //        DstGemmN, n_per_block);fflush(stdout);

            // standard 8-4-2-1 pattern
            while(i_m_itr >= 8)
            {
                avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 2 * DstGemmN, p_src + 2 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 3 * DstGemmN, p_src + 3 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 4 * DstGemmN, p_src + 4 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 5 * DstGemmN, p_src + 5 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 6 * DstGemmN, p_src + 6 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 7 * DstGemmN, p_src + 7 * n_per_block, current_n);
                i_m_itr -= 8;
                p_dst += 8 * DstGemmN;
                p_src += 8 * n_per_block;
            }
            if(i_m_itr & 4)
            {
                avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 2 * DstGemmN, p_src + 2 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 3 * DstGemmN, p_src + 3 * n_per_block, current_n);
                p_dst += 4 * DstGemmN;
                p_src += 4 * n_per_block;
            }
            if(i_m_itr & 2)
            {
                avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n);
                p_dst += 2 * DstGemmN;
                p_src += 2 * n_per_block;
            }
            if(i_m_itr & 1)
            {
                avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
            }
            // printf("xxxx %d\n",__LINE__);fflush(stdout);
        }
    }
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
    void MoveSrcSliceWindow(const SrcDesc&, const Index&) {}

    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
    void MoveDstSliceWindow(const DstDesc&, const Index&) {}

    private:
    const ElementwiseOperation element_op_;

    ck::index_t i_dst_gemm_m;
    ck::index_t i_dst_gemm_n;

    ck::index_t DstGemmM;
    ck::index_t DstGemmN;

    intptr_t src_offset;
    intptr_t dst_offset;
};

} // namespace cpu
} // namespace ck
#endif
library/include/ck/library/host_tensor/device.hpp
View file @
afc7d431
...
...
@@ -121,7 +121,11 @@ template <typename... Args, typename F>
float launch_and_time_cpu_kernel(F kernel, int nrepeat, Args... args)
{
    WallTimer timer;

    kernel(args...);

    int nwarmup = 3;
    for(int i = 0; i < nwarmup; i++)
        kernel(args...);

    timer.Start();

    for(int i = 0; i < nrepeat; i++)
...
...
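launch_and_time_cpu_kernel now does one untimed call plus three warm-up iterations before starting the wall timer. A hedged usage sketch; the lambda and the interpretation of the returned float (total vs. averaged milliseconds is not visible in this hunk) are assumptions:

```cpp
#include <cstdio>

// Sketch: timing an arbitrary CPU kernel with the helper above.
void time_example()
{
    auto kernel = [](int m, int n, int k) { (void)m; (void)n; (void)k; /* run the AVX2 GEMM here */ };
    float t     = launch_and_time_cpu_kernel(kernel, /*nrepeat=*/10, 1024, 1024, 1024);
    std::printf("kernel time: %f\n", t);
}
```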
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
View file @
afc7d431
...
...
@@ -19,7 +19,7 @@ using InLayout = ck::tensor_layout::gemm::RowMajor; /
using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXC

static constexpr bool NonTemporalStore = false;

using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using PT          = ck::tensor_operation::cpu::element_wise::PassThrough;

using ThreadwiseGemmAvx2_MxN_4x24_Dispatch = ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InType, WeiType,
...
...
@@ -37,53 +37,37 @@ static constexpr auto ConvFwd1x1P0 =
static constexpr auto ConvFwd1x1S1P0 =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;

static constexpr auto DefaultGemmKLoop =
    ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;
static constexpr auto GemmKLoopOverC =
    ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;

static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;

// clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf) \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC,   LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0,   GemmKLoopOverC,   LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0,   DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC,   LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0,   GemmKLoopOverC,   LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0,   DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>
// clang-format on

using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple<
    //#################################################################|InDataType|WeiDataType|OutDataType|AccDataType|InElementwiseOp|WeiElementwiseOp|OutElementwiseOp|ConvForwardSp|NumDimSpatial|MPerBlock|NPerBlock|KPerBlock|ThreadwiseGemm_Dispatch
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, float, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2,  256, 128,  64, ThreadwiseGemmAvx2_MxN_4x24_Dispatch>,
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, float, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2,  512, 256, 128, ThreadwiseGemmAvx2_MxN_4x24_Dispatch>,
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, float, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 1024, 144, 128, ThreadwiseGemmAvx2_MxN_4x24_Dispatch>>;
    // clang-format off
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 120,  64, 4, 24, true, true, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 144, 128, 4, 24, true, true, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, false),
    // DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 192, 128, 4, 24, true, true, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 288, 128, 4, 24, true, true, false)
    >;
// clang-format on

void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances)
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
    ck::tensor_operation::device::add_device_operation_instances(
        instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances{});
...
...
test/convnd_fwd_cpu/conv2d_fwd_cpu.cpp
View file @
afc7d431
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "tensor_layout.hpp"
#include "device_tensor.hpp"
#include "device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "reference_conv_fwd.hpp"
#include "element_wise_operation_cpu.hpp"
#include "dynamic_buffer_cpu.hpp"
#define AVX2_DATA_ALIGNMENT 32
namespace
ck
{
namespace
tensor_operation
{
namespace
cpu
{
namespace
device
{
namespace
device_conv2d_fwd_avx2_instance
{
void
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
);
}
// namespace device_conv2d_fwd_avx2_instance
}
// namespace device
}
// namespace cpu
}
// namespace tensor_operation
}
// namespace ck
using
InElementOp
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
using
WeiElementOp
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
using
OutElementOp
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
template
<
typename
T
>
static
bool
check_out
(
const
Tensor
<
T
>&
ref
,
const
Tensor
<
T
>&
result
)
{
float
max_diff
=
1e-6
;
for
(
int
i
=
0
;
i
<
ref
.
mData
.
size
();
++
i
)
{
float
diff
=
std
::
abs
(
double
(
ref
.
mData
[
i
])
-
double
(
result
.
mData
[
i
]));
if
(
max_diff
<
diff
)
{
return
false
;
}
}
return
true
;
}
int
main
(
int
argc
,
char
*
argv
[])
{
int
data_type
=
0
;
int
init_method
=
0
;
// Conv shape
ck
::
index_t
N
=
128
;
ck
::
index_t
K
=
256
;
ck
::
index_t
C
=
192
;
ck
::
index_t
Y
=
3
;
ck
::
index_t
X
=
3
;
ck
::
index_t
Hi
=
71
;
ck
::
index_t
Wi
=
71
;
ck
::
index_t
conv_stride_h
=
2
;
ck
::
index_t
conv_stride_w
=
2
;
ck
::
index_t
conv_dilation_h
=
1
;
ck
::
index_t
conv_dilation_w
=
1
;
ck
::
index_t
in_left_pad_h
=
1
;
ck
::
index_t
in_left_pad_w
=
1
;
ck
::
index_t
in_right_pad_h
=
1
;
ck
::
index_t
in_right_pad_w
=
1
;
if
(
argc
==
1
)
{
data_type
=
1
;
init_method
=
1
;
}
else
if
(
argc
==
3
)
{
data_type
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
}
else
if
(
argc
==
18
)
{
data_type
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
N
=
std
::
stoi
(
argv
[
3
]);
K
=
std
::
stoi
(
argv
[
4
]);
C
=
std
::
stoi
(
argv
[
5
]);
Y
=
std
::
stoi
(
argv
[
6
]);
X
=
std
::
stoi
(
argv
[
7
]);
Hi
=
std
::
stoi
(
argv
[
8
]);
Wi
=
std
::
stoi
(
argv
[
9
]);
conv_stride_h
=
std
::
stoi
(
argv
[
10
]);
conv_stride_w
=
std
::
stoi
(
argv
[
11
]);
conv_dilation_h
=
std
::
stoi
(
argv
[
12
]);
conv_dilation_w
=
std
::
stoi
(
argv
[
13
]);
in_left_pad_h
=
std
::
stoi
(
argv
[
14
]);
in_left_pad_w
=
std
::
stoi
(
argv
[
15
]);
in_right_pad_h
=
std
::
stoi
(
argv
[
16
]);
in_right_pad_w
=
std
::
stoi
(
argv
[
17
]);
}
else
{
printf
(
"arg1: data type (0=fp32, 1=fp16)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3 to 17: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx
\n
"
);
exit
(
1
);
}
auto
Run
=
[
&
](
auto
input_type
,
auto
wei_type
,
auto
out_type
,
auto
acc_type
)
{
using
InDataType
=
decltype
(
input_type
);
using
WeiDataType
=
decltype
(
wei_type
);
using
OutDataType
=
decltype
(
out_type
);
using
AccDataType
=
decltype
(
acc_type
);
using
ReferenceConvBwdInstance
=
ck
::
tensor_operation
::
host
::
ReferenceConvBwdData
<
InDataType
,
WeiDataType
,
OutDataType
,
AccDataType
,
InElementOp
,
WeiElementOp
                                                                      , OutElementOp>;

        const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1;
        const ck::index_t XEff = (X - 1) * conv_dilation_w + 1;

        const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
        const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;

        const std::vector<ck::index_t> input_spatial_lengths{{Hi, Wi}};
        const std::vector<ck::index_t> filter_spatial_lengths{{Y, X}};
        const std::vector<ck::index_t> output_spatial_lengths{{Ho, Wo}};
        const std::vector<ck::index_t> conv_filter_strides{{conv_stride_h, conv_stride_w}};
        const std::vector<ck::index_t> conv_filter_dilations{{conv_dilation_h, conv_dilation_w}};
        const std::vector<ck::index_t> input_left_pads{{in_left_pad_h, in_left_pad_w}};
        const std::vector<ck::index_t> input_right_pads{{in_right_pad_h, in_right_pad_w}};

        auto f_host_tensor_descriptor =
            [](std::size_t N_, std::size_t C_, std::size_t H_, std::size_t W_) {
                return HostTensorDescriptor(
                    std::vector<std::size_t>({N_, C_, H_, W_}),
                    std::vector<std::size_t>({C_ * H_ * W_, 1, W_ * C_, C_}));
            };

        Tensor<OutDataType> out_n_ho_wo_k(f_host_tensor_descriptor(N, K, Ho, Wo));
        Tensor<WeiDataType> wei_k_y_x_c(f_host_tensor_descriptor(K, C, Y, X));
        Tensor<InDataType> in_n_hi_wi_c_host_result(f_host_tensor_descriptor(N, C, Hi, Wi));
        Tensor<InDataType> in_n_hi_wi_c_device_result(f_host_tensor_descriptor(N, C, Hi, Wi));

        std::cout << "in (N, C, Hi, Wi): " << in_n_hi_wi_c_host_result.mDesc << std::endl;
        std::cout << "wei(K, C, Y, X): " << wei_k_y_x_c.mDesc << std::endl;
        std::cout << "out(N, K, Ho, Wo): " << out_n_ho_wo_k.mDesc << std::endl;

        switch(init_method)
        {
        case 0: break;
        case 1:
            out_n_ho_wo_k.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
            wei_k_y_x_c.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
            break;
        case 2:
            out_n_ho_wo_k.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
            wei_k_y_x_c.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
            break;
        default:
            out_n_ho_wo_k.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
            wei_k_y_x_c.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
        }

        DeviceAlignedMemCPU in_device_buf(sizeof(InDataType) *
                                              in_n_hi_wi_c_device_result.mDesc.GetElementSpace(),
                                          AVX2_DATA_ALIGNMENT);
        DeviceAlignedMemCPU wei_device_buf(sizeof(WeiDataType) * wei_k_y_x_c.mDesc.GetElementSpace(),
                                           AVX2_DATA_ALIGNMENT);
        DeviceAlignedMemCPU out_device_buf(sizeof(OutDataType) * out_n_ho_wo_k.mDesc.GetElementSpace(),
                                           AVX2_DATA_ALIGNMENT);

        out_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
        wei_device_buf.ToDevice(wei_k_y_x_c.mData.data());

        // reset input to zero
        in_n_hi_wi_c_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0});
        in_device_buf.ToDevice(in_n_hi_wi_c_device_result.mData.data());

        // get host result
        {
            auto ref_conv    = ReferenceConvFwdInstance{};
            auto ref_invoker = ref_conv.MakeInvoker();

            auto ref_argument = ref_conv.MakeArgument(in_n_hi_wi_c_host_result,
                                                      wei_k_y_x_c,
                                                      out_n_ho_wo_k,
                                                      conv_filter_strides,
                                                      conv_filter_dilations,
                                                      input_left_pads,
                                                      input_right_pads,
                                                      InElementOp{},
                                                      WeiElementOp{},
                                                      OutElementOp{});

            ref_invoker.Run(ref_argument);
        }

        using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
        using DeviceConvFwdNoOpPtr =
            ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>;

        // add device Conv instances
        std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;

        if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, float> &&
                     ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
                     ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
        {
            ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(conv_ptrs);
        }

        if(conv_ptrs.size() <= 0)
        {
            throw std::runtime_error("wrong! no device Conv instance found");
        }

        // profile device Conv instances
        bool success = true;

        for(auto& conv_ptr : conv_ptrs)
        {
            auto argument_ptr = conv_ptr->MakeArgumentPointer(
                static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
                static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
                static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
                N,
                K,
                C,
                input_spatial_lengths,
                filter_spatial_lengths,
                output_spatial_lengths,
                conv_filter_strides,
                conv_filter_dilations,
                input_left_pads,
                input_right_pads,
                InElementOp{},
                WeiElementOp{},
                OutElementOp{});

            if(conv_ptr->IsSupportedArgument(argument_ptr.get()))
            {
                auto invoker_ptr = conv_ptr->MakeInvokerPointer();
                invoker_ptr->Run(argument_ptr.get(), 1);

                in_device_buf.FromDevice(in_n_hi_wi_c_device_result.mData.data());

                if(!check_out(in_n_hi_wi_c_host_result, in_n_hi_wi_c_device_result))
                {
                    std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl;
                    success = false;
                }
                else
                {
                    std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl;
                }
            }
            else
            {
                std::cout << "Not support Info: " << conv_ptr->GetTypeString() << std::endl;
            }
        }

        if(success)
        {
            std::cout << "test conv2d fwd cpu : Pass" << std::endl;
            return 0;
        }
        else
        {
            std::cout << "test conv2d fwd cpu: Fail " << std::endl;
            return -1;
        }
    };

    if(data_type == 0)
    {
        return Run(F32(), F32(), F32(), F32());
    }
    else if(data_type == 1)
    {
        return Run(F16(), F16(), F16(), F32());
    }
    else
    {
        return 1;
    }
}
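Both versions of this test driver (the old one above and the rewritten one below) derive the output spatial size from the same padded-convolution arithmetic before allocating the host tensors. The following standalone sketch is not part of this commit; it only mirrors the driver's variable names to show that the default shape (71x71 input, 3x3 filter, pad 1 on every side, unit stride and dilation) keeps the spatial size at 71x71:

#include <cstdio>

int main()
{
    const int Hi = 71, Wi = 71, Y = 3, X = 3;
    const int in_left_pad_h = 1, in_right_pad_h = 1, in_left_pad_w = 1, in_right_pad_w = 1;
    const int conv_stride_h = 1, conv_stride_w = 1, conv_dilation_h = 1, conv_dilation_w = 1;

    const int YEff = (Y - 1) * conv_dilation_h + 1; // effective filter height: 3
    const int XEff = (X - 1) * conv_dilation_w + 1; // effective filter width:  3

    const int Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; // (71+1+1-3)/1+1 = 71
    const int Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; // (71+1+1-3)/1+1 = 71

    std::printf("Ho=%d Wo=%d\n", Ho, Wo); // prints Ho=71 Wo=71
    return 0;
}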
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "tensor_layout.hpp"
#include "device_tensor.hpp"
#include "device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "reference_conv_fwd.hpp"
#include "element_wise_operation_cpu.hpp"
#include "dynamic_buffer_cpu.hpp"
#include <omp.h>
#define AVX2_DATA_ALIGNMENT 32
using F32 = float;
using F16 = ck::half_t;

namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
namespace device_conv2d_fwd_avx2_instance {

using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;

void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);

} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck

using InElementOp  = ck::tensor_operation::cpu::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
template <typename T>
static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
{
    int error_count = 0;
    float max_diff  = 1e-6;

    for(int i = 0; i < ref.mData.size(); ++i)
    {
        float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
        if(max_diff < diff)
        {
            error_count++;
            printf("idx:%3d, ref:%f, res:%f (diff:%f)\n",
                   i,
                   double(ref.mData[i]),
                   double(result.mData[i]),
                   diff);
        }
    }
    return error_count == 0;
}

// placeholder, currently unused and not implemented
float calculate_gflops() {}
int main(int argc, char* argv[])
{
    int data_type   = 0;
    int init_method = 0;

    // Conv shape
    ck::index_t N               = 2;
    ck::index_t K               = 256;
    ck::index_t C               = 192;
    ck::index_t Y               = 3;
    ck::index_t X               = 3;
    ck::index_t Hi              = 71;
    ck::index_t Wi              = 71;
    ck::index_t conv_stride_h   = 1;
    ck::index_t conv_stride_w   = 1;
    ck::index_t conv_dilation_h = 1;
    ck::index_t conv_dilation_w = 1;
    ck::index_t in_left_pad_h   = 1;
    ck::index_t in_left_pad_w   = 1;
    ck::index_t in_right_pad_h  = 1;
    ck::index_t in_right_pad_w  = 1;

    if(argc == 1)
    {
        data_type   = 0;
        init_method = 1;
    }
    else if(argc == 3)
    {
        data_type   = std::stoi(argv[1]);
        init_method = std::stoi(argv[2]);
    }
    else if(argc == 18)
    {
        data_type       = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        N               = std::stoi(argv[3]);
        K               = std::stoi(argv[4]);
        C               = std::stoi(argv[5]);
        Y               = std::stoi(argv[6]);
        X               = std::stoi(argv[7]);
        Hi              = std::stoi(argv[8]);
        Wi              = std::stoi(argv[9]);
        conv_stride_h   = std::stoi(argv[10]);
        conv_stride_w   = std::stoi(argv[11]);
        conv_dilation_h = std::stoi(argv[12]);
        conv_dilation_w = std::stoi(argv[13]);
        in_left_pad_h   = std::stoi(argv[14]);
        in_left_pad_w   = std::stoi(argv[15]);
        in_right_pad_h  = std::stoi(argv[16]);
        in_right_pad_w  = std::stoi(argv[17]);
    }
    else
    {
        printf("arg1: data type (0=fp32, 1=fp16)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3 to 17: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
               "RightPx\n");
        exit(1);
    }

    auto Run = [&](auto input_type, auto wei_type, auto out_type) {
        using InDataType  = decltype(input_type);
        using WeiDataType = decltype(wei_type);
        using OutDataType = decltype(out_type);

        using ReferenceConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd<InDataType,
                                                                                      WeiDataType,
                                                                                      OutDataType,
                                                                                      InElementOp,
                                                                                      WeiElementOp,
                                                                                      OutElementOp>;

        const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1;
        const ck::index_t XEff = (X - 1) * conv_dilation_w + 1;

        const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
        const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;

        const std::vector<ck::index_t> input_spatial_lengths{{Hi, Wi}};
        const std::vector<ck::index_t> filter_spatial_lengths{{Y, X}};
        const std::vector<ck::index_t> output_spatial_lengths{{Ho, Wo}};
        const std::vector<ck::index_t> conv_filter_strides{{conv_stride_h, conv_stride_w}};
        const std::vector<ck::index_t> conv_filter_dilations{{conv_dilation_h, conv_dilation_w}};
        const std::vector<ck::index_t> input_left_pads{{in_left_pad_h, in_left_pad_w}};
        const std::vector<ck::index_t> input_right_pads{{in_right_pad_h, in_right_pad_w}};

        auto f_host_tensor_descriptor =
            [](std::size_t N_, std::size_t C_, std::size_t H_, std::size_t W_) {
                return HostTensorDescriptor(
                    std::vector<std::size_t>({N_, C_, H_, W_}),
                    std::vector<std::size_t>({C_ * H_ * W_, 1, W_ * C_, C_}));
            };

        Tensor<InDataType> in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi));
        Tensor<WeiDataType> wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X));
        Tensor<OutDataType> out_n_k_ho_wo_host_result(f_host_tensor_descriptor(N, K, Ho, Wo));
        Tensor<OutDataType> out_n_k_ho_wo_device_result(f_host_tensor_descriptor(N, K, Ho, Wo));

        std::cout << "in (N, C, Hi, Wi): " << in_n_c_hi_wi.mDesc << std::endl;
        std::cout << "wei(K, C, Y, X): " << wei_k_c_y_x.mDesc << std::endl;
        std::cout << "out(N, K, Ho, Wo): " << out_n_k_ho_wo_host_result.mDesc << std::endl;
        std::cout << "LPad(H, W):" << in_left_pad_h << "," << in_left_pad_w
                  << ", RPad(H, W):" << in_right_pad_h << "," << in_right_pad_w
                  << ", Stride(H, W):" << conv_stride_h << ", " << conv_stride_w
                  << ", Dilation(H, W):" << conv_dilation_h << ", " << conv_dilation_w
                  << ", Threads:" << omp_get_max_threads() << std::endl;

        switch(init_method)
        {
        case 0: break;
        case 1:
            in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
            // in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
            wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
            // wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
            break;
        case 2:
            in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
            wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
            break;
        case 3:
#define PACK_32(v24, v16, v8, v0) \
    (((v24 & 0xff) << 24) | ((v16 & 0xff) << 16) | ((v8 & 0xff) << 8) | ((v0 & 0xff) << 0))
            for(auto i_n = 0; i_n < N; i_n++)
            {
                for(auto i_c = 0; i_c < C; i_c++)
                {
                    for(auto i_hi = 0; i_hi < Hi; i_hi++)
                    {
                        for(auto i_wi = 0; i_wi < Wi; i_wi++)
                        {
                            uint32_t v = PACK_32(i_n, i_c, i_hi, i_wi);
                            in_n_c_hi_wi(i_n, i_c, i_hi, i_wi) = *reinterpret_cast<float*>(&v);
                        }
                    }
                }
            }
            for(auto i_k = 0; i_k < K; i_k++)
            {
                for(auto i_c = 0; i_c < C; i_c++)
                {
                    for(auto i_y = 0; i_y < Y; i_y++)
                    {
                        for(auto i_x = 0; i_x < X; i_x++)
                        {
                            uint32_t v = PACK_32(i_k, i_c, i_y, i_x);
                            wei_k_c_y_x(i_k, i_c, i_y, i_x) = *reinterpret_cast<float*>(&v);
                        }
                    }
                }
            }
            break;
        default:
            in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0, 1});
            wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-1, 1});
        }

        DeviceAlignedMemCPU in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace(),
                                          AVX2_DATA_ALIGNMENT);
        DeviceAlignedMemCPU wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace(),
                                           AVX2_DATA_ALIGNMENT);
        DeviceAlignedMemCPU out_device_buf(
            sizeof(OutDataType) * out_n_k_ho_wo_host_result.mDesc.GetElementSpace(),
            AVX2_DATA_ALIGNMENT);

        in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
        wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());

        // get host result
        {
            auto ref_conv    = ReferenceConvFwdInstance{};
            auto ref_invoker = ref_conv.MakeInvoker();

            auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi,
                                                      wei_k_c_y_x,
                                                      out_n_k_ho_wo_host_result,
                                                      conv_filter_strides,
                                                      conv_filter_dilations,
                                                      input_left_pads,
                                                      input_right_pads,
                                                      InElementOp{},
                                                      WeiElementOp{},
                                                      OutElementOp{});

            ref_invoker.Run(ref_argument);
        }

        using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
        using DeviceConvFwdNoOpPtr =
            ck::tensor_operation::cpu::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>;

        // add device Conv instances
        std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;

        if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, float> &&
                     ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
                     ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
        {
            ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(conv_ptrs);
        }

        if(conv_ptrs.size() <= 0)
        {
            throw std::runtime_error("wrong! no device Conv instance found");
        }

        // profile device Conv instances
        bool success                    = true;
        double fastest_kernel_time      = std::numeric_limits<double>::max();
        std::string fastest_kernel_name = "";
        double fastest_kernel_gflops    = 0;

        for(auto& conv_ptr : conv_ptrs)
        {
            auto argument_ptr = conv_ptr->MakeArgumentPointer(
                static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
                static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
                static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
                N,
                K,
                C,
                input_spatial_lengths,
                filter_spatial_lengths,
                output_spatial_lengths,
                conv_filter_strides,
                conv_filter_dilations,
                input_left_pads,
                input_right_pads,
                InElementOp{},
                WeiElementOp{},
                OutElementOp{});

            if(conv_ptr->IsSupportedArgument(argument_ptr.get()))
            {
                auto invoker_ptr = conv_ptr->MakeInvokerPointer();
                double time      = invoker_ptr->Run(argument_ptr.get(), 10);

                double total_flop = static_cast<double>(2) * N * C * Ho * Wo * K * Y * X;
                double gflops     = (total_flop * 1e-6) / time;

                out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data());

                if(!check_out(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result))
                {
                    std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl;
                    success = false;
                }
                else
                {
                    std::cout << "Pass Info: " << conv_ptr->GetTypeString() << ", Time:" << time
                              << "ms, Gflops:" << gflops << std::endl;
                    if(time < fastest_kernel_time)
                    {
                        fastest_kernel_time   = time;
                        fastest_kernel_name   = conv_ptr->GetTypeString();
                        fastest_kernel_gflops = gflops;
                    }
                }
            }
            else
            {
                std::cout << "Not support Info: " << conv_ptr->GetTypeString() << std::endl;
            }
        }

        if(fastest_kernel_time != std::numeric_limits<double>::max())
        {
            std::cout << " fastest:" << fastest_kernel_name << ", time:" << fastest_kernel_time
                      << "ms, Gflops:" << fastest_kernel_gflops << std::endl;
        }

        return 0;

        // if(success)
        // {
        //     std::cout << "test conv2d fwd cpu : Pass" << std::endl;
        //     return 0;
        // }
        // else
        // {
        //     std::cout << "test conv2d fwd cpu: Fail " << std::endl;
        //     return -1;
        // }
    };

    if(data_type == 0)
    {
        return Run(F32(), F32(), F32());
    }
    else
    {
        return 1;
    }
}
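For reference, the driver above measures time in milliseconds and converts it to GFLOPS as total_flop * 1e-6 / time. The standalone check below is not part of the commit; it uses the driver's default shape and a made-up 100 ms timing purely to show that the unit bookkeeping works out:

#include <cstdio>

int main()
{
    const double N = 2, C = 192, K = 256, Ho = 71, Wo = 71, Y = 3, X = 3;

    const double total_flop = 2.0 * N * C * Ho * Wo * K * Y * X; // ~8.92e9 FLOP
    const double time_ms    = 100.0;                             // hypothetical measurement

    // flop / (ms * 1e6) == flop / (s * 1e9) == GFLOP/s
    const double gflops = (total_flop * 1e-6) / time_ms;         // ~89.2

    std::printf("total_flop=%.3e gflops=%.1f\n", total_flop, gflops);
    return 0;
}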
test/cpu_threadwise_transfer/cpu_threadwise_transfer.cpp View file @ afc7d431
...
...
@@ -226,6 +226,8 @@ int main(int argc, char** argv)
    static constexpr ck::index_t nDim =
        ck::remove_reference_t<decltype(input_desc)>::GetNumOfDimension();

    input_desc.Print();

    auto threadwise_transfer = threadwise_transfer_t{input_desc,
                                                     ck::make_zero_multi_index<nDim>(),
                                                     input_cblock_desc,
...
...
test/cpu_ukernel/cpu_gemm_uk.cpp View file @ afc7d431
...
...
@@ -313,14 +313,15 @@ void test_ukernel(ukenrel_t uk,
    float* private_c = mat_c + tid * m * n;
    ck::cpu::ThreadwiseGemmParam param;
    param.p_a   = mat_a;
    param.p_b   = mat_b;
    param.p_c   = private_c;
    param.Kr    = k;
    param.lda   = (std::is_same<Row, ALayout>::value ? k : m) * sizeof(FloatA);
    param.ldb   = (std::is_same<Row, BLayout>::value ? n : k * 8) * sizeof(FloatB);
    param.ldc   = n * sizeof(float);
    param.alpha = alpha;
    param.accmulate_c = 0;

    memset(private_c, 0, m * n * sizeof(float));
...
...
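One detail the hunk above makes easy to miss: the lda/ldb/ldc fields of ck::cpu::ThreadwiseGemmParam are byte strides, not element counts, and the k * 8 factor on the non-row-major B path suggests B is consumed in 8-float-wide packed panels (one AVX2 register of fp32). Treat that panel interpretation as an assumption drawn from this diff, not a documented contract. A small standalone sketch of the same arithmetic with made-up sizes:

#include <cstdio>

int main()
{
    const int n = 64, k = 128; // made-up GEMM sizes

    const int lda = k * static_cast<int>(sizeof(float));     // row-major A: bytes per row      = 512
    const int ldb = k * 8 * static_cast<int>(sizeof(float)); // packed B (assumed): bytes/panel = 4096
    const int ldc = n * static_cast<int>(sizeof(float));     // row-major C: bytes per row      = 256

    std::printf("lda=%d ldb=%d ldc=%d\n", lda, ldb, ldc);
    return 0;
}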