Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
afc7d431
"src/turbomind/utils/word_list.h" did not exist on "720fc533da804ac3f46ee938864403e51fcd9fa7"
Commit
afc7d431
authored
Apr 24, 2022
by
carlushuang
Browse files
avx2 gemm now works for single thread
parent
07af8343
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
2308 additions
and
962 deletions
+2308
-962
include/ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp
...ude/ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp
+160
-85
include/ck/tensor_operation/cpu/device/convolution_forward_specialization_cpu.hpp
...ion/cpu/device/convolution_forward_specialization_cpu.hpp
+13
-0
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp
...tion/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp
+174
-36
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp
+248
-424
include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp
...e/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp
+108
-63
include/ck/tensor_operation/cpu/thread/threadwise_gemm_param.hpp
.../ck/tensor_operation/cpu/thread/threadwise_gemm_param.hpp
+3
-3
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2.hpp
...tion/cpu/thread/threadwise_tensor_slice_transfer_avx2.hpp
+109
-0
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp
.../threadwise_tensor_slice_transfer_avx2_specialization.hpp
+1084
-0
library/include/ck/library/host_tensor/device.hpp
library/include/ck/library/host_tensor/device.hpp
+5
-1
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
...2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
+30
-46
test/convnd_fwd_cpu/conv2d_fwd_cpu.cpp
test/convnd_fwd_cpu/conv2d_fwd_cpu.cpp
+363
-296
test/cpu_threadwise_transfer/cpu_threadwise_transfer.cpp
test/cpu_threadwise_transfer/cpu_threadwise_transfer.cpp
+2
-0
test/cpu_ukernel/cpu_gemm_uk.cpp
test/cpu_ukernel/cpu_gemm_uk.cpp
+9
-8
No files found.
include/ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp
View file @
afc7d431
...
...
@@ -13,21 +13,10 @@ namespace cpu {
template
<
typename
FloatA
,
typename
FloatB
,
typename
FloatC
,
typename
AccDataType
,
typename
ABlockDesc
,
typename
BBlockDesc
,
typename
CBlockDesc
,
typename
ABlockSliceLengths
,
typename
BBlockSliceLengths
,
typename
CBlockSliceLengths
,
typename
AThreadSliceLength
,
typename
BThreadSliceLength
,
ck
::
index_t
AThreadLoopOverDim
,
// thread slice loop over on block slice. 1d is enough for
// now
ck
::
index_t
BThreadLoopOverDim
,
typename
CDesc
,
ck
::
index_t
KPerBlock
,
...
...
@@ -47,24 +36,14 @@ struct BlockwiseGemmAvx2_MxN
static
constexpr
index_t
nDimA
=
ABlockDesc
::
GetNumOfDimension
();
static
constexpr
index_t
nDimB
=
BBlockDesc
::
GetNumOfDimension
();
static
constexpr
index_t
nDimC
=
C
Block
Desc
::
GetNumOfDimension
();
static
constexpr
index_t
nDimC
=
CDesc
::
GetNumOfDimension
();
using
IndexA
=
MultiIndex
<
nDimA
>
;
using
IndexB
=
MultiIndex
<
nDimB
>
;
using
IndexC
=
MultiIndex
<
nDimC
>
;
using
ACoord
=
decltype
(
make_tensor_coordinate
(
ABlockDesc
{},
IndexA
{}));
using
BCoord
=
decltype
(
make_tensor_coordinate
(
BBlockDesc
{},
IndexB
{}));
using
CCoord
=
decltype
(
make_tensor_coordinate
(
CBlockDesc
{},
IndexC
{}));
#if 0
constexpr BlockwiseGemmAvx2_MxN(const ABlockDesc & a_block_desc, const IndexA& a_thread_origin,
const BBlockDesc & b_block_desc, const IndexB& b_thread_origin)
: a_thread_coord_(make_tensor_coordinate(a_block_desc, a_thread_origin)),
b_thread_coord_(make_tensor_coordinate(b_block_desc, b_thread_origin)),
{
}
#endif
using
CCoord
=
decltype
(
make_tensor_coordinate
(
CDesc
{},
IndexC
{}));
template
<
typename
TensorDesc
>
constexpr
auto
GetLeadingElement
(
const
TensorDesc
&
desc
)
...
...
@@ -84,79 +63,175 @@ struct BlockwiseGemmAvx2_MxN
}
}
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CBlockBuffer
>
ck
::
index_t
GetALeadingElement
(
const
ABlockDesc
&
a_block_desc
)
const
{
return
a_block_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
}
ck
::
index_t
GetBLeadingElement
(
const
BBlockDesc
&
b_block_desc
)
const
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// K * N
return
b_block_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
}
else
{
// N/8 * K * 8
return
b_block_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}]
*
b_block_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
2
>
{}];
}
}
ck
::
index_t
GetCLeadingElement
(
const
CDesc
&
c_desc
)
const
{
return
c_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
}
ck
::
index_t
GetMPerBlock
(
const
ABlockDesc
&
a_block_desc
)
const
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// M * K
return
a_block_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
0
>
{}];
}
else
{
// K * M
return
a_block_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
}
}
ck
::
index_t
GetKPerBlock
(
const
ABlockDesc
&
a_block_desc
)
const
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// M * K
return
a_block_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
}
else
{
// K * M
return
a_block_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
0
>
{}];
}
}
ck
::
index_t
GetNPerBlock
(
const
BBlockDesc
&
b_block_desc
)
const
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// K * N
return
b_block_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
}
else
{
// N/8 * K * 8
return
b_block_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
0
>
{}]
*
b_block_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
2
>
{}];
}
}
ck
::
index_t
GetABlockStartOffset
(
const
ABlockDesc
&
a_block_desc
,
const
index_t
i_m
,
const
index_t
)
const
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
i_m
*
a_block_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
}
else
{
return
i_m
;
}
}
ck
::
index_t
GetBBlockStartOffset
(
const
BBlockDesc
&
b_block_desc
,
const
index_t
,
const
index_t
i_n
)
const
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// K * N
return
i_n
;
}
else
{
// N/8 * K * 8
return
i_n
*
b_block_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
}
}
ck
::
index_t
GetCBlockStartOffset
(
const
CDesc
&
c_desc
,
const
index_t
i_m
,
const
index_t
i_n
)
const
{
return
i_m
*
c_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}]
+
i_n
;
}
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CBuffer
>
void
Run
(
const
ABlockDesc
&
a_block_desc
,
const
ABlockBuffer
&
a_block_buf
,
const
IndexA
&
a_origin
,
const
IndexA
&
/*
a_origin
*/
,
const
BBlockDesc
&
b_block_desc
,
const
BBlockBuffer
&
b_block_buf
,
const
IndexB
&
b_origin
,
const
CBlockDesc
&
c_block_desc
,
CBlockBuffer
&
c_block_buf
,
const
IndexC
&
c_origin
)
const
{
constexpr
auto
m_n_block_length
=
ck
::
Sequence
<
ABlockSliceLengths
::
At
(
AThreadLoopOverDim
),
BBlockSliceLengths
::
At
(
BThreadLoopOverDim
)
>
{};
constexpr
auto
m_n_thread_length
=
ck
::
Sequence
<
AThreadSliceLength
::
At
(
AThreadLoopOverDim
),
BThreadSliceLength
::
At
(
BThreadLoopOverDim
)
>
{};
const
IndexB
&
/* b_origin */
,
constexpr
auto
m_n_access_length
=
m_n_block_length
/
m_n_thread_length
;
const
CDesc
&
c_desc
,
CBuffer
&
c_buf
,
const
IndexC
&
/* c_origin */
,
constexpr
auto
ordered_m_n_access_length
=
container_reorder_given_new2old
(
m_n_access_length
,
ThreadMNAccessOrder
{});
bool
is_accumulate_c
=
true
)
const
{
auto
lda
=
GetALeadingElement
(
a_block_desc
)
*
sizeof
(
FloatA
);
auto
ldb
=
GetBLeadingElement
(
b_block_desc
)
*
sizeof
(
FloatB
);
auto
ldc
=
GetCLeadingElement
(
c_desc
)
*
sizeof
(
FloatC
);
constexpr
auto
a_block_idx_zeros
=
typename
uniform_sequence_gen
<
nDimA
,
0
>::
type
{};
// starting point of the block
constexpr
auto
b_block_idx_zeros
=
typename
uniform_sequence_gen
<
nDimB
,
0
>::
type
{};
// printf("lda:%d, ldb:%d, ldc:%d\n", lda, ldb, ldc);
constexpr
auto
lda
=
GetLeadingElement
(
a_block_desc
)
*
sizeof
(
FloatA
);
constexpr
auto
ldb
=
GetLeadingElement
(
b_block_desc
)
*
sizeof
(
FloatB
);
constexpr
auto
ldc
=
GetLeadingElement
(
c_block_desc
)
*
sizeof
(
FloatC
);
const
auto
k_per_block
=
GetKPerBlock
(
a_block_desc
);
const
auto
m_per_block
=
GetMPerBlock
(
a_block_desc
);
const
auto
n_per_block
=
GetNPerBlock
(
b_block_desc
);
const
auto
m_per_thread
=
ThreadwiseGemm_Dispatch
::
ThreadMaxMr
;
const
auto
n_per_thread
=
ThreadwiseGemm_Dispatch
::
ThreadMaxNr
;
ck
::
cpu
::
ThreadwiseGemmParam
param
;
param
.
Kr
=
KPerB
lock
;
param
.
Kr
=
k_per_b
lock
;
param
.
lda
=
lda
;
param
.
ldb
=
ldb
;
param
.
ldc
=
ldc
;
param
.
alpha
=
1.0
f
;
// TODO
param
.
accmulate_c
=
is_accumulate_c
?
1
:
0
;
static_ford
<
decltype
(
ordered_m_n_access_length
)
>
{}([
&
](
auto
ordered_idx
)
{
constexpr
auto
origin_m_n_idx
=
ordered_idx
.
ReorderGivenOld2New
(
ThreadMNAccessOrder
{});
constexpr
auto
current_m_idx
=
origin_m_n_idx
.
At
(
0
)
*
AThreadSliceLength
::
At
(
AThreadLoopOverDim
);
constexpr
auto
current_n_idx
=
origin_m_n_idx
.
At
(
1
)
*
BThreadSliceLength
::
At
(
BThreadLoopOverDim
);
constexpr
auto
current_mr
=
ck
::
math
::
min
(
m_n_block_length
.
At
(
0
)
-
current_m_idx
,
m_n_thread_length
.
At
(
0
));
constexpr
auto
current_nr
=
ck
::
math
::
min
(
m_n_block_length
.
At
(
1
)
-
current_n_idx
,
m_n_thread_length
.
At
(
1
));
if
constexpr
(
std
::
is_same
<
ThreadMNAccessOrder
,
ck
::
Sequence
<
0
,
1
>>::
value
)
{
for
(
ck
::
index_t
i_m
=
0
;
i_m
<
m_per_block
;
i_m
+=
m_per_thread
)
{
auto
current_mr
=
ck
::
math
::
min
(
m_per_block
-
i_m
,
m_per_thread
);
param
.
p_a
=
&
a_block_buf
.
p_data_
[
GetABlockStartOffset
(
a_block_desc
,
i_m
,
0
)];
constexpr
auto
a_block_idx
=
a_block_idx_zeros
.
Modify
(
AThreadLoopOverDim
,
current_m_idx
);
constexpr
auto
a_block_coord
=
make_tensor_coordinate
(
a_block_desc
,
to_multi_index
(
a_origin
+
a_block_idx
));
// printf("YYYY: %d, i_m:%d, current_mr:%d, %d, %p\n",__LINE__, i_m, current_mr,
// GetABlockStartOffset(a_block_desc, i_m, 0), param.p_a);fflush(stdout);
constexpr
auto
b_block_idx
=
b_block_idx_zeros
.
Modify
(
BThreadLoopOverDim
,
current_n_idx
);
constexpr
auto
b_block_coord
=
make_tensor_coordinate
(
b_block_desc
,
to_multi_index
(
b_origin
+
b_block_idx
));
for
(
ck
::
index_t
i_n
=
0
;
i_n
<
n_per_block
;
i_n
+=
n_per_thread
)
{
auto
current_nr
=
ck
::
math
::
min
(
n_per_block
-
i_n
,
n_per_thread
);
constexpr
auto
c_block_coord
=
make_tensor_coordinate
(
c_block_desc
,
to_multi_index
(
c_origin
+
origin_m_n_idx
))
;
param
.
p_b
=
&
b_block_buf
.
p_data_
[
GetBBlockStartOffset
(
b_block_desc
,
0
,
i_n
)];
param
.
p_c
=
&
c_buf
.
p_data_
[
GetCBlockStartOffset
(
c_desc
,
i_m
,
i_n
)]
;
param
.
p_a
=
&
a_block_buf
.
p_data_
[
a_block_coord
.
GetOffset
()];
param
.
p_b
=
&
b_block_buf
.
p_data_
[
b_block_coord
.
GetOffset
()];
param
.
p_c
=
&
c_block_buf
.
p_data_
[
c_block_coord
.
GetOffset
()];
// printf("YYYY: %d, i_n:%d, current_nr:%d, %d, %p, C:%d, %p\n",__LINE__, i_n,
// current_nr, GetBBlockStartOffset(b_block_desc, 0, i_n), param.p_b,
// GetCBlockStartOffset(c_desc, i_m, i_n),
// param.p_c);fflush(stdout);
ThreadwiseGemm_Dispatch
::
Run
(
&
param
,
current_mr
,
current_nr
);
});
}
}
}
}
};
...
...
include/ck/tensor_operation/cpu/device/convolution_forward_specialization_cpu.hpp
View file @
afc7d431
...
...
@@ -14,6 +14,19 @@ enum ConvolutionForwardSpecialization_t
OddC
,
};
enum
ConvolutionForwardGemmKSpecialization_t
{
DefaultGemmKLoop
,
NHWC_GemmKLoopOverC
,
// not merge c*y*x, and c % k_per_block == 0
};
enum
ConvolutionForwardBlockLoopOverSpecialization_t
{
DefaultBlockLoopOver
,
LoopOver_MNK
,
LoopOver_MKN
,
};
}
// namespace device
}
// namespace cpu
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp
View file @
afc7d431
...
...
@@ -13,6 +13,8 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_avx2.hpp"
#include "threadwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -23,20 +25,21 @@ namespace device {
template
<
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
,
ConvolutionForwardBlockLoopOverSpecialization_t
BlockLoopOverSpecialization
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
MPerBlock
,
// block means data are designed to fit in cache (L1/L2/L3)
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
typename
ThreadwiseGemm_Dispatch
>
// bool IsGemmMPadded
,
// bool IsGemmNPadded
,
// bool IsGemmKPadded
>
ck
::
index_t
MPerThread
,
ck
::
index_t
NPerThread
,
bool
UseALocalBuffer
,
bool
UseBLocalBuffer
,
bool
UseCLocalBuffer
>
struct
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
:
public
DeviceConvFwd
<
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
{
...
...
@@ -60,18 +63,89 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
bool
NonTemporalStore
=
false
;
static
constexpr
auto
GetBlockMNKAccessOrder
()
{
if
constexpr
(
BlockLoopOverSpecialization
==
DefaultBlockLoopOver
||
BlockLoopOverSpecialization
==
LoopOver_MNK
)
return
ck
::
Sequence
<
0
,
1
,
2
>
{};
else
if
constexpr
(
BlockLoopOverSpecialization
==
LoopOver_MKN
)
return
ck
::
Sequence
<
0
,
2
,
1
>
{};
}
using
BlockMNKAccessOrder
=
decltype
(
GetBlockMNKAccessOrder
());
static
constexpr
auto
GetThreadwiseGemm_Dispatch
()
{
if
constexpr
(
MPerThread
==
4
&&
NPerThread
==
24
)
{
return
ck
::
cpu
::
ThreadwiseGemmAvx2_MxN_4x24_Dispatch
<
InDataType
,
WeiDataType
,
OutDataType
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
NonTemporalStore
>
{};
}
else
if
constexpr
(
MPerThread
==
6
&&
NPerThread
==
16
)
{
return
ck
::
cpu
::
ThreadwiseGemmAvx2_MxN_6x16_Dispatch
<
InDataType
,
WeiDataType
,
OutDataType
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
NonTemporalStore
>
{};
}
else
{
// static_assert(false, "invalid Mr/Nr");
}
}
using
ThreadwiseGemm_Dispatch
=
decltype
(
GetThreadwiseGemm_Dispatch
());
static
constexpr
auto
GetInputBlockDescriptor
()
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
KPerBlock
));
}
static
constexpr
auto
GetWeightBlockDescriptor
()
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
math
::
integer_divide_ceil
(
NPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
KPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
));
}
static
constexpr
auto
GetOutputBlockDescriptor
()
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
NPerBlock
));
}
static
auto
GetWeightTensorDescriptor
(
ck
::
index_t
gemm_k
,
ck
::
index_t
gemm_n
)
{
ck
::
index_t
gemm_n_padded
=
math
::
integer_least_multiple
(
gemm_n
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
);
const
auto
wei_gemm_n_k_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_n
,
gemm_k
));
const
auto
wei_gemm_
n0_k_n1
_grid_desc
=
transform_tensor_descriptor
(
const
auto
wei_gemm_
padn_k
_grid_desc
=
transform_tensor_descriptor
(
wei_gemm_n_k_grid_desc
,
ck
::
make_tuple
(
ck
::
make_unmerge_transform
(
ck
::
make_tuple
(
wei_gemm_n_k_grid_desc
.
GetLength
(
I0
)
/
make_tuple
(
make_right_pad_transform
(
gemm_n
,
gemm_n_padded
-
gemm_n
),
make_pass_through_transform
(
gemm_k
)),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
const
auto
wei_gemm_n0_k_n1_grid_desc
=
transform_tensor_descriptor
(
wei_gemm_padn_k_grid_desc
,
ck
::
make_tuple
(
ck
::
make_unmerge_transform
(
ck
::
make_tuple
(
wei_gemm_padn_k_grid_desc
.
GetLength
(
I0
)
/
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
)),
ck
::
make_pass_through_transform
(
wei_gemm_n_k_grid_desc
.
GetLength
(
I1
))),
ck
::
make_pass_through_transform
(
wei_gemm_
pad
n_k_grid_desc
.
GetLength
(
I1
))),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
2
>
{},
ck
::
Sequence
<
1
>
{}));
...
...
@@ -409,6 +483,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std
::
multiplies
<
ck
::
index_t
>
());
}
static
index_t
GetGemmN
(
ck
::
index_t
K
)
{
// return ck::math::integer_least_multiple(K,
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
return
K
;
}
static
auto
MakeABCGridDescriptor
(
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
...
...
@@ -423,7 +504,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
using
namespace
ck
;
const
index_t
GemmM
=
GetGemmM
(
N
,
output_spatial_lengths
);
const
index_t
GemmN
=
K
;
const
index_t
GemmN
=
GetGemmN
(
K
)
;
const
index_t
GemmK
=
GetGemmK
(
C
,
filter_spatial_lengths
);
// A:
...
...
@@ -474,13 +555,44 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
using
BGridDesc
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I1
])
>
;
using
CGridDesc
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I2
])
>
;
static
constexpr
bool
UseCLocalBuffer
=
true
;
// static constexpr bool UseCLocalBuffer = false;
using
AThreadwiseCopy
=
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
<
InDataType
,
InDataType
,
AGridDesc
,
decltype
(
GetInputBlockDescriptor
()),
InElementwiseOperation
,
false
,
ConvForwardSpecialization
,
GemmKSpecialization
>
;
using
BThreadwiseCopy
=
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
<
WeiDataType
,
WeiDataType
,
BGridDesc
,
decltype
(
GetWeightBlockDescriptor
()),
WeiElementwiseOperation
,
false
,
ConvForwardSpecialization
,
GemmKSpecialization
>
;
using
CThreadwiseCopy
=
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
<
OutDataType
,
OutDataType
,
CGridDesc
,
decltype
(
GetOutputBlockDescriptor
()),
OutElementwiseOperation
,
!
UseCLocalBuffer
,
ConvForwardSpecialization
,
GemmKSpecialization
>
;
using
GridwiseGemm
=
ck
::
cpu
::
GridwiseGemmAvx2_MxN
<
InDataType
,
// InDataType,
WeiDataType
,
// WeiDataType,
OutDataType
,
// OutDataType,
AccDataType
,
// AccDataType,
AGridDesc
,
// AGridDesc,
BGridDesc
,
// BGridDesc,
CGridDesc
,
// CGridDesc,
...
...
@@ -491,8 +603,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
NPerBlock
,
// NPerBlock,
KPerBlock
,
// KPerBlock,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
ck
::
Sequence
<
0
,
1
,
2
>
,
// BlockMNKAccessOrder,
AThreadwiseCopy
,
// AThreadwiseCopy
BThreadwiseCopy
,
// BThreadwiseCopy
CThreadwiseCopy
,
// CThreadwiseCopy
BlockMNKAccessOrder
,
// BlockMNKAccessOrder,
ck
::
Sequence
<
0
,
1
>
,
// ThreadMNAccessOrder
UseALocalBuffer
,
// UseALocalBuffer
UseBLocalBuffer
,
// UseBLocalBuffer
UseCLocalBuffer
// UseCLocalBuffer
>
;
...
...
@@ -580,6 +697,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
throw
std
::
runtime_error
(
"wrong! GridwiseGemmAvx2_MxN has invalid setting"
);
}
memset
(
arg
.
p_c_grid_
,
0
,
arg
.
c_grid_desc_
.
GetElementSpaceSize
());
const
auto
kernel
=
ck
::
cpu
::
kernel_gemm_avx_mxn
<
GridwiseGemm
,
InDataType
,
WeiDataType
,
...
...
@@ -591,7 +710,10 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
BElementwiseOperation
,
CElementwiseOperation
>
;
float
ave_time
=
launch_and_time_cpu_kernel
(
kernel
,
float
ave_time
=
0
;
if
(
nrepeat
!=
1
)
ave_time
=
launch_and_time_cpu_kernel
(
kernel
,
nrepeat
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
...
...
@@ -605,7 +727,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
// TODO: this is for benchmark purpose, so last time we clear c buffer and calculate the
// result
memset
(
arg
.
p_c_grid_
,
0
,
arg
.
a
_grid_desc_
.
GetElementSpaceSize
());
memset
(
arg
.
p_c_grid_
,
0
,
arg
.
c
_grid_desc_
.
GetElementSpaceSize
());
launch_cpu_kernel
(
kernel
,
arg
.
p_a_grid_
,
...
...
@@ -659,6 +781,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
}
}
if
constexpr
(
GemmKSpecialization
==
ConvolutionForwardGemmKSpecialization_t
::
NHWC_GemmKLoopOverC
)
{
if
(
!
(
arg
.
Conv_C_
%
KPerBlock
==
0
))
return
false
;
}
// Gridwise GEMM size
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
}
...
...
@@ -749,15 +878,24 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
auto
string_local_buffer
=
[](
bool
is_local_buffer
)
{
if
(
is_local_buffer
)
return
"L"
;
else
return
"G"
;
};
// clang-format off
str
<<
"DeviceConv"
<<
std
::
to_string
(
NumDimSpatial
)
<<
"DFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
<<
"<"
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
">"
;
<<
"DFwdAvx2_NHWC_KYXC"
<<
"_FS"
<<
static_cast
<
int
>
(
ConvForwardSpecialization
)
<<
"_KS"
<<
static_cast
<
int
>
(
GemmKSpecialization
)
<<
"_BS"
<<
static_cast
<
int
>
(
BlockLoopOverSpecialization
)
<<
"_BT"
<<
MPerBlock
<<
"x"
<<
NPerBlock
<<
"x"
<<
KPerBlock
<<
"_TT"
<<
MPerThread
<<
"x"
<<
NPerThread
<<
"_A"
<<
string_local_buffer
(
UseALocalBuffer
)
<<
"_B"
<<
string_local_buffer
(
UseBLocalBuffer
)
<<
"_C"
<<
string_local_buffer
(
UseCLocalBuffer
)
;
// clang-format on
return
str
.
str
();
...
...
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp
View file @
afc7d431
...
...
@@ -7,7 +7,9 @@
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
#include "dynamic_buffer_cpu.hpp"
#include <unistd.h>
namespace
ck
{
namespace
cpu
{
...
...
@@ -46,7 +48,6 @@ void kernel_gemm_avx_mxn(const FloatA* __restrict__ p_a_grid,
template
<
typename
FloatA
,
typename
FloatB
,
typename
FloatC
,
typename
AccDataType
,
typename
AGridDesc
,
typename
BGridDesc
,
typename
CGridDesc
,
...
...
@@ -57,334 +58,92 @@ template <typename FloatA,
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
typename
ThreadwiseGemm_Dispatch
,
typename
AThreadwiseCopy
,
typename
BThreadwiseCopy
,
typename
CThreadwiseCopy
,
typename
BlockMNKAccessOrder
,
// how we accss gemm MNK to better fit in cache
typename
ThreadMNAccessOrder
,
// how we acces gemm MN to utilize micro kernel
bool
UseALocalBuffer
,
bool
UseBLocalBuffer
,
bool
UseCLocalBuffer
// if true, will allocate a buffer and write to it in kernel, then
// copy back to block buffer. if false, will write to C directly
// copy back to block buffer (need CThreadwiseCopy).
// if false, will write to C directly (no need CThreadwiseCopy)
>
struct
GridwiseGemmAvx2_MxN
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
// static constexpr auto Avx2RegisterVector = 8; // 8 floats
static
constexpr
index_t
MemAlignmentByte
=
32
;
// 256bit
static
constexpr
auto
GetABlockDescriptor
()
static
auto
GetABlockDescriptor
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
k_per_blk
)
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// A : M, K
constexpr
auto
a_block_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
KPerBloc
k
));
auto
a_block_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
m_per_blk
,
k_per_bl
k
));
return
a_block_desc_m_k
;
}
else
{
// A : K, M
constexpr
auto
a_block_desc_k_m
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
KPerBloc
k
,
auto
a_block_desc_k_m
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
k_per_bl
k
,
math
::
integer_least_multiple
(
MPerBloc
k
,
ThreadwiseGemm_Dispatch
::
MatrixAMinVectorSize
)));
m_per_bl
k
,
ThreadwiseGemm_Dispatch
::
MatrixAMinVectorSize
)));
return
a_block_desc_k_m
;
}
}
static
constexpr
auto
GetBBlockDescriptor
()
static
auto
GetBBlockDescriptor
(
const
ck
::
index_t
k_per_blk
,
const
ck
::
index_t
n_per_blk
)
{
// n_per_blk should be 8x
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// B : K, N
constexpr
auto
b_block_desc_k_n
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
KPerBlock
,
math
::
integer_least_multiple
(
NPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
)));
auto
b_block_desc_k_n
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
k_per_blk
,
n_per_blk
));
return
b_block_desc_k_n
;
}
else
{
// B : N/8, K, N8
constexpr
auto
b_block_desc_n0_k_n1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
math
::
integer_divide_ceil
(
NPerBloc
k
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
KPerBloc
k
,
auto
b_block_desc_n0_k_n1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
math
::
integer_divide_ceil
(
n_per_bl
k
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
k_per_bl
k
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
));
return
b_block_desc_n0_k_n1
;
}
}
static
constexpr
auto
GetABlockSliceLength
()
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// A : M, K
return
ck
::
Sequence
<
MPerBlock
,
KPerBlock
>
{};
}
else
{
// A : K, M
return
ck
::
Sequence
<
KPerBlock
,
MPerBlock
>
{};
}
}
static
constexpr
auto
GetBBlockSliceLength
()
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// B : K, N
return
ck
::
Sequence
<
KPerBlock
,
NPerBlock
>
{};
}
else
{
// B : N/8, K, N88;
return
ck
::
Sequence
<
NPerBlock
/
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
,
KPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
>
{};
}
}
static
constexpr
auto
GetABlockDimAccessOrder
()
{
return
ck
::
Sequence
<
0
,
1
>
{};
}
static
constexpr
auto
GetBBlockDimAccessOrder
()
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// B : K, N
return
ck
::
Sequence
<
0
,
1
>
{};
}
else
{
// B : N/8, K, N88;
return
ck
::
Sequence
<
0
,
1
,
2
>
{};
}
}
static
constexpr
auto
GetABlockMoveFwdStep
()
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// A : M, K
return
ck
::
make_multi_index
(
0
,
KPerBlock
);
}
else
{
// A : K, M
return
ck
::
make_multi_index
(
KPerBlock
,
0
);
}
}
static
constexpr
auto
GetBBlockMoveFwdStep
()
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// B : K, N
return
ck
::
make_multi_index
(
KPerBlock
,
0
);
}
else
{
// B : N/8, K, N88;
return
ck
::
make_multi_index
(
0
,
KPerBlock
,
0
);
}
}
#if 0
static constexpr auto GetAThreadDiscriptor()
{
if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout, ck::tensor_layout::gemm::RowMajor>::value){
// A : M, K
constexpr auto a_thread_desc_m_k = make_naive_tensor_descriptor_packed(make_tuple(ThreadwiseGemm_Dispatch::ThreadMaxMr, KPerBlock));
return a_thread_desc_m_k;
} else {
// A : K, M
constexpr auto a_thread_desc_k_m = make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, ThreadwiseGemm_Dispatch::ThreadMaxMr));
return a_thread_desc_k_m;
}
}
static constexpr auto GetBThreadDescriptor()
{
if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout, ck::tensor_layout::gemm::RowMajor>::value){
// B : K, N
constexpr auto b_thread_desc_k_n = make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, ThreadwiseGemm_Dispatch::ThreadMaxNr));
return b_thread_desc_k_n;
} else {
// B : N/8, K, N8
constexpr auto b_thread_desc_n_k_n8 = make_naive_tensor_descriptor_packed(make_tuple(math::integer_divide_ceil(ThreadwiseGemm_Dispatch::ThreadMaxNr, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize), KPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
return b_thread_desc_n_k_n8;
}
}
#endif
static
constexpr
auto
GetAThreadSliceLength
()
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// A : M, K
return
ck
::
Sequence
<
ThreadwiseGemm_Dispatch
::
ThreadMaxMr
,
KPerBlock
>
{};
}
else
{
// A : K, M
return
ck
::
Sequence
<
KPerBlock
,
ThreadwiseGemm_Dispatch
::
ThreadMaxMr
>
{};
}
}
static
constexpr
auto
GetBThreadSliceLength
()
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// B : K, N
return
ck
::
Sequence
<
KPerBlock
,
ThreadwiseGemm_Dispatch
::
ThreadMaxNr
>
{};
}
else
{
// B : N/8, K, N88;
return
ck
::
Sequence
<
ThreadwiseGemm_Dispatch
::
ThreadMaxNr
/
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
,
KPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
>
{};
}
}
static
constexpr
auto
GetAThreadMoveFwdStep
()
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// A : M, K
return
ck
::
make_multi_index
(
ThreadwiseGemm_Dispatch
::
ThreadMaxMr
,
0
);
}
else
{
// A : K, M
return
ck
::
make_multi_index
(
0
,
ThreadwiseGemm_Dispatch
::
ThreadMaxMr
);
}
}
static
constexpr
auto
GetBThreadMoveFwdStep
()
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// B : K, N
return
ck
::
make_multi_index
(
0
,
ThreadwiseGemm_Dispatch
::
ThreadMaxNr
);
}
else
{
// B : N/8, K, N88;
return
ck
::
Sequence
<
ThreadwiseGemm_Dispatch
::
ThreadMaxNr
/
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
,
0
,
0
>
{};
}
}
static
constexpr
ck
::
index_t
GetAThreadLoopOverDim
()
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// A : M, K
return
0
;
}
else
{
// A : K, M
return
1
;
}
}
static
constexpr
ck
::
index_t
GetBThreadLoopOverDim
()
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// B : K, N
return
1
;
}
else
{
// B : N/8, K, N88;
return
0
;
}
}
static
constexpr
auto
GetCBlockDescriptor
()
{
if
constexpr
(
UseCLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
NPerBlock
));
}
else
static
auto
GetCBlockDescriptor
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
n_per_blk
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
NPerBlock
));
// TODO:
}
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
m_per_blk
,
n_per_blk
));
}
static
constexpr
auto
GetCBlockSliceLength
()
{
return
ck
::
Sequence
<
MPerBlock
,
NPerBlock
>
{};
}
static
constexpr
bool
CheckValidity
(
const
AGridDesc
&
a_grid_desc
,
const
BGridDesc
&
b_grid_desc
,
const
CGridDesc
&
c_grid_desc
)
{
#if 0
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
return false;
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
return false;
// check NumPrefetch
if constexpr(NumPrefetch == 1)
{
// 1-stage prefetch always supported
}
else if constexpr(NumPrefetch == 2)
{
// 2-stage prefetch currently only support even number of K0 loop
// TODO: add support for odd number of K0 loop
if(!((K0 / K0PerBlock) % 2 == 0))
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
bool
is_valid
=
true
;
const
auto
GemmN
=
c_grid_desc
.
GetLength
(
I1
);
if
constexpr
(
UseCLocalBuffer
)
{
return false;
}
if
(
std
::
is_same
<
BlockMNKAccessOrder
,
ck
::
Sequence
<
0
,
2
,
1
>>::
value
&&
NPerBlock
<
GemmN
)
is_valid
&=
false
;
}
else
{
return false;
// TODO: need check c grid is simple transform?
if
(
GemmN
%
8
!=
0
)
is_valid
&=
false
;
}
// check M01, N01
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
if(!(M0 % M01 == 0 && N0 % N01 == 0))
return false;
#endif
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
return
is_valid
;
}
static
void
Run
(
const
FloatA
*
__restrict__
p_a_grid
,
...
...
@@ -397,178 +156,149 @@ struct GridwiseGemmAvx2_MxN
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
)
{
ck
::
index_t
m_per_block
;
ck
::
index_t
n_per_block
;
ck
::
index_t
k_per_block
;
ck
::
index_t
m_per_block
=
MPerBlock
;
ck
::
index_t
n_per_block
=
NPerBlock
;
ck
::
index_t
k_per_block
=
KPerBlock
;
if
constexpr
(
MPerBlock
==
0
&&
NPerBlock
==
0
&&
KPerBlock
==
0
)
{}
else
{
m_per_block
=
MPerBlock
;
n_per_block
=
NPerBlock
;
k_per_block
=
KPerBlock
;
}
const
auto
GemmM
=
c_grid_desc
.
GetLength
(
I0
);
const
auto
GemmN
=
c_grid_desc
.
GetLength
(
I1
);
const
auto
GemmK
=
a_grid_desc
.
GetLength
(
I1
);
const
auto
M
=
a_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
b_grid_desc
.
GetLength
(
I1
);
const
auto
K
=
b_grid_desc
.
GetLength
(
I0
);
constexpr
auto
a_block_copy_dim
=
AGridDesc
::
GetNumOfDimension
();
const
ck
::
index_t
grid_m
=
math
::
integer_divide_ceil
(
M
,
m_per_block
);
const
ck
::
index_t
grid_n
=
math
::
integer_divide_ceil
(
N
,
n_per_block
);
constexpr
auto
b_block_copy_dim
=
BGridDesc
::
GetNumOfDimension
();
const
ck
::
index_t
grid_size
=
grid_m
*
grid_n
;
constexpr
auto
a_block_desc
=
GetABlockDescriptor
();
constexpr
auto
a_block_slice_length
=
GetABlockSliceLength
();
constexpr
auto
a_block_copy_dim
=
decltype
(
a_block_slice_length
)
::
Size
();
constexpr
auto
a_dim_access_order
=
GetABlockDimAccessOrder
();
constexpr
auto
a_block_move_step
=
GetABlockMoveFwdStep
();
constexpr
auto
a_thread_slice_length
=
GetAThreadSliceLength
();
constexpr
auto
a_thread_loop_over_dim
=
GetAThreadLoopOverDim
();
constexpr
auto
b_block_desc
=
GetBBlockDescriptor
();
constexpr
auto
b_block_slice_length
=
GetBBlockSliceLength
();
constexpr
auto
b_block_copy_dim
=
decltype
(
b_block_slice_length
)
::
Size
();
constexpr
auto
b_dim_access_order
=
GetBBlockDimAccessOrder
();
constexpr
auto
b_block_move_step
=
GetBBlockMoveFwdStep
();
constexpr
auto
b_thread_slice_length
=
GetBThreadSliceLength
();
constexpr
auto
b_thread_loop_over_dim
=
GetBThreadLoopOverDim
();
constexpr
auto
c_block_desc
=
GetCBlockDescriptor
();
constexpr
auto
c_block_slice_length
=
GetCBlockSliceLength
();
constexpr
auto
c_block_move_step
=
ck
::
make_multi_index
(
0
,
NPerBlock
);
auto
a_threadwise_copy
=
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2
<
FloatA
,
// SrcData
FloatA
,
// DstData
decltype
(
a_grid_desc
),
// SrcDesc
decltype
(
a_block_desc
),
// DstDesc
AElementwiseOperation
,
// ElementwiseOperation
decltype
(
a_block_slice_length
),
// SliceLengths
decltype
(
a_dim_access_order
),
// DimAccessOrder
1
,
// VectorDim
1
,
// ScalarPerVector
ck
::
InMemoryDataOperationEnum_t
::
Set
,
// InMemoryDataOperationEnum_t
false
,
// SrcResetCoordinateAfterRun
true
// DstResetCoordinateAfterRun
>
(
a_grid_desc
,
auto
a_threadwise_copy
=
AThreadwiseCopy
(
a_grid_desc
,
ck
::
make_zero_multi_index
<
a_block_copy_dim
>
(),
a_block_desc
,
GetABlockDescriptor
(
m_per_block
,
k_per_block
)
,
ck
::
make_zero_multi_index
<
a_block_copy_dim
>
(),
AElementwiseOperation
{});
auto
b_threadwise_copy
=
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2
<
FloatB
,
// SrcData
FloatB
,
// DstData
decltype
(
b_grid_desc
),
// SrcDesc
decltype
(
b_block_desc
),
// DstDesc
BElementwiseOperation
,
// ElementwiseOperation
decltype
(
b_block_slice_length
),
// SliceLengths
decltype
(
b_dim_access_order
),
// DimAccessOrder
1
,
// VectorDim
1
,
// ScalarPerVector
ck
::
InMemoryDataOperationEnum_t
::
Set
,
// InMemoryDataOperationEnum_t
false
,
// SrcResetCoordinateAfterRun
true
// DstResetCoordinateAfterRun
>
(
b_grid_desc
,
auto
b_threadwise_copy
=
BThreadwiseCopy
(
b_grid_desc
,
ck
::
make_zero_multi_index
<
b_block_copy_dim
>
(),
b_block_desc
,
GetBBlockDescriptor
(
k_per_block
,
n_per_block
)
,
ck
::
make_zero_multi_index
<
b_block_copy_dim
>
(),
BElementwiseOperation
{});
auto
c_threadwise_copy
=
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2
<
FloatC
,
// SrcData
FloatC
,
// DstData
decltype
(
c_block_desc
),
// SrcDesc
decltype
(
c_grid_desc
),
// DstDesc
BElementwiseOperation
,
// ElementwiseOperation
ck
::
Sequence
<
MPerBlock
,
NPerBlock
>
,
// SliceLengths
ck
::
Sequence
<
0
,
1
>
,
// DimAccessOrder
1
,
// VectorDim
1
,
// ScalarPerVector
ck
::
InMemoryDataOperationEnum_t
::
Set
,
// InMemoryDataOperationEnum_t
true
,
// SrcResetCoordinateAfterRun
false
// DstResetCoordinateAfterRun
>
(
c_block_desc
,
auto
c_threadwise_copy
=
CThreadwiseCopy
(
GetCBlockDescriptor
(
m_per_block
,
n_per_block
),
ck
::
make_zero_multi_index
<
2
>
(),
c_grid_desc
,
ck
::
make_zero_multi_index
<
2
>
(),
CElementwiseOperation
{});
DeviceAlignedMemCPU
a_block_mem
(
MPerBlock
*
KPerBlock
*
sizeof
(
FloatA
),
MemAlignmentByte
);
DeviceAlignedMemCPU
b_block_mem
(
KPerBlock
*
NPerBlock
*
sizeof
(
FloatB
),
MemAlignmentByte
);
DeviceAlignedMemCPU
c_block_mem
(
MPerBlock
*
NPerBlock
*
sizeof
(
FloatC
),
MemAlignmentByte
);
DeviceAlignedMemCPU
a_block_mem
(
m_per_block
*
k_per_block
*
sizeof
(
FloatA
),
MemAlignmentByte
);
DeviceAlignedMemCPU
b_block_mem
(
k_per_block
*
n_per_block
*
sizeof
(
FloatB
),
MemAlignmentByte
);
DeviceAlignedMemCPU
c_block_mem
(
m_per_block
*
n_per_block
*
sizeof
(
FloatC
),
MemAlignmentByte
);
auto
a_grid_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
_t
::
Global
>
(
auto
a_grid_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpret_cast
<
const
FloatA
*>
(
p_a_grid
),
a_grid_desc
.
GetElementSpaceSize
());
auto
b_grid_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
_t
::
Global
>
(
auto
b_grid_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpret_cast
<
const
FloatB
*>
(
p_b_grid
),
b_grid_desc
.
GetElementSpaceSize
());
auto
c_grid_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
_t
::
Global
>
(
auto
c_grid_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpret_cast
<
FloatC
*>
(
p_c_grid
),
c_grid_desc
.
GetElementSpaceSize
());
auto
a_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
_t
::
Global
>
(
auto
a_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpret_cast
<
FloatA
*>
(
a_block_mem
.
mpDeviceBuf
),
a_block_mem
.
mMemSize
/
sizeof
(
FloatA
));
auto
b_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
_t
::
Global
>
(
auto
b_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpret_cast
<
FloatB
*>
(
b_block_mem
.
mpDeviceBuf
),
b_block_mem
.
mMemSize
/
sizeof
(
FloatB
));
auto
c_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum_t
::
Global
>
(
reinterpret_cast
<
FloatC
*>
(
c_block_mem
.
mpDeviceBuf
),
c_block_mem
.
mMemSize
/
sizeof
(
FloatC
));
auto
c_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
UseCLocalBuffer
?
reinterpret_cast
<
FloatC
*>
(
c_block_mem
.
mpDeviceBuf
)
:
reinterpret_cast
<
FloatC
*>
(
p_c_grid
),
UseCLocalBuffer
?
c_block_mem
.
mMemSize
/
sizeof
(
FloatC
)
:
c_grid_desc
.
GetElementSpaceSize
());
auto
blockwise_gemm
=
BlockwiseGemmAvx2_MxN
<
FloatA
,
// FloatA,
auto
blockwise_gemm
=
BlockwiseGemmAvx2_MxN
<
FloatA
,
// FloatA,
FloatB
,
// FloatB,
FloatC
,
// FloatC,
AccDataType
,
// AccDataType,
decltype
(
a_block_desc
),
// ABlockDesc,
decltype
(
b_block_desc
),
// BBlockDesc,
decltype
(
c_block_desc
),
// CBlockDesc,
decltype
(
a_block_slice_length
),
// ABlockSliceLengths,
decltype
(
b_block_slice_length
),
// BBlockSliceLengths,
decltype
(
c_block_slice_length
),
// CBlockSliceLengths,
decltype
(
a_thread_slice_length
),
// AThreadSliceLength,
decltype
(
b_thread_slice_length
),
// BThreadSliceLength,
a_thread_loop_over_dim
,
// AThreadLoopOverDim, // thread slice
// loop over on block slice. 1d is enough
// for now
b_thread_loop_over_dim
,
// BThreadLoopOverDim,
decltype
(
GetABlockDescriptor
(
m_per_block
,
k_per_block
)),
// ABlockDesc,
decltype
(
GetBBlockDescriptor
(
k_per_block
,
n_per_block
)),
// BBlockDesc,
decltype
(
GetCBlockDescriptor
(
m_per_block
,
n_per_block
)),
// CBlockDesc,
KPerBlock
,
// KPerBlock,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
ThreadMNAccessOrder
>
{};
// ThreadMNAccessOrder // how we acces
// gemm MN to utilize micro kernel>{};
// TODO: openmp aware ordering
//
if
constexpr
(
std
::
is_same
<
BlockMNKAccessOrder
,
ck
::
Sequence
<
0
,
1
,
2
>>::
value
)
{
auto
a_move_k_step
=
ck
::
make_multi_index
(
0
,
k_per_block
);
auto
b_move_k_step
=
ck
::
make_multi_index
(
0
,
k_per_block
,
0
);
const
ck
::
index_t
grid_m
=
math
::
integer_divide_ceil
(
GemmM
,
m_per_block
);
const
ck
::
index_t
grid_n
=
math
::
integer_divide_ceil
(
GemmN
,
n_per_block
);
const
ck
::
index_t
grid_size
=
grid_m
*
grid_n
;
// This version does not consider K panel re-usage. simple for openmp
#pragma omp parallel for
for
(
ck
::
index_t
gid
=
0
;
gid
<
grid_size
;
gid
++
)
{
ck
::
index_t
i_mc
=
(
gid
/
grid_n
)
*
m_per_block
;
ck
::
index_t
i_nc
=
(
gid
%
grid_n
)
*
n_per_block
;
ck
::
index_t
mc_size
=
ck
::
math
::
min
(
M
-
i_mc
,
m_per_block
);
ck
::
index_t
nc_size
=
ck
::
math
::
min
(
N
-
i_nc
,
n_per_block
);
ck
::
index_t
mc_size
=
ck
::
math
::
min
(
GemmM
-
i_mc
,
m_per_block
);
ck
::
index_t
nc_size
=
ck
::
math
::
min
(
GemmN
-
i_nc
,
n_per_block
);
// TODO: nc need be 8x
nc_size
=
math
::
integer_least_multiple
(
nc_size
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
);
a_threadwise_copy
.
SetSrcSliceOrigin
(
a_grid_desc
,
ck
::
make_multi_index
(
i_mc
,
0
));
b_threadwise_copy
.
SetSrcSliceOrigin
(
b_grid_desc
,
ck
::
make_multi_index
(
math
::
integer_divide_ceil
(
i_nc
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
0
,
0
));
// pack_b
b_threadwise_copy
.
RunGeneric
(
b_grid_desc
,
b_grid_buf
,
b_block_desc
,
b_block_buf
);
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_move_step
);
if
(
i_nc
==
0
)
auto
c_block_desc
=
UseCLocalBuffer
?
GetCBlockDescriptor
(
mc_size
,
nc_size
)
:
c_grid_desc
;
if
constexpr
(
UseCLocalBuffer
)
{
c_threadwise_copy
.
SetDstSliceOrigin
(
c_grid_desc
,
ck
::
make_multi_index
(
i_mc
,
i_nc
));
}
else
{
// pack_a
a_threadwise_copy
.
RunGeneric
(
a_grid_desc
,
a_grid_buf
,
a_block_desc
,
a_block_buf
);
a_threadwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_move_step
);
c_threadwise_copy
.
SetSrcSliceOrigin
(
c_block_desc
,
ck
::
make_multi_index
(
i_mc
,
i_nc
));
c_threadwise_copy
.
Run
(
c_block_desc
,
c_block_buf
,
c_grid_desc
,
c_grid_buf
);
}
for
(
ck
::
index_t
i_kc
=
0
;
i_kc
<
K
;
i_kc
+=
k_per_block
)
for
(
ck
::
index_t
i_kc
=
0
;
i_kc
<
Gemm
K
;
i_kc
+=
k_per_block
)
{
ck
::
index_t
kc_size
=
ck
::
math
::
min
(
K
-
i_kc
,
k_per_block
);
ck
::
index_t
kc_size
=
ck
::
math
::
min
(
GemmK
-
i_kc
,
k_per_block
);
auto
a_block_desc
=
GetABlockDescriptor
(
mc_size
,
kc_size
);
auto
b_block_desc
=
GetBBlockDescriptor
(
kc_size
,
nc_size
);
// printf("==> i_m:%d, i_n:%d, i_k:%d, mc:%d, nc:%d, kc:%d(%d, %d)\n", i_mc,
// i_nc, i_kc, mc_size, nc_size, kc_size, KPerBlock, GemmK); fflush(stdout);
a_threadwise_copy
.
Run
(
a_grid_desc
,
a_grid_buf
,
a_block_desc
,
a_block_buf
);
b_threadwise_copy
.
Run
(
b_grid_desc
,
b_grid_buf
,
b_block_desc
,
b_block_buf
);
// for(auto i_elem = 0; i_elem < (mc_size * kc_size) ; i_elem++){
// printf("A ==> %3d : %f(0x%08x)\n", i_elem,
// (reinterpret_cast<float*>(a_block_buf.p_data_))[i_elem],
// (reinterpret_cast<uint32_t*>(a_block_buf.p_data_))[i_elem]);
//}
// for(auto i_elem = 0; i_elem < (kc_size * nc_size) ; i_elem++){
// printf("B ==> %3d : %f(0x%08x)\n", i_elem,
// (reinterpret_cast<float*>(b_block_buf.p_data_))[i_elem],
// (reinterpret_cast<uint32_t*>(b_block_buf.p_data_))[i_elem]);
// }
// printf("[%d] 2222 \n",__LINE__);
blockwise_gemm
.
Run
(
a_block_desc
,
a_block_buf
,
make_zero_multi_index
<
a_block_copy_dim
>
(),
...
...
@@ -577,14 +307,108 @@ struct GridwiseGemmAvx2_MxN
make_zero_multi_index
<
b_block_copy_dim
>
(),
c_block_desc
,
c_block_buf
,
make_zero_multi_index
<
2
>
());
make_zero_multi_index
<
2
>
(),
i_kc
!=
0
);
// printf("[%d] 2222 \n",__LINE__);
if
((
i_kc
+
k_per_block
)
<
GemmK
)
{
a_threadwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_move_k_step
);
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_move_k_step
);
}
// printf("[%d] 2222 \n",__LINE__);
// for(auto i_elem = 0; i_elem < (10) ; i_elem++){
// printf("C ==> %3d : %f(0x%08x)\n", i_elem,
// (reinterpret_cast<float*>(c_block_buf.p_data_))[i_elem],
// (reinterpret_cast<uint32_t*>(c_block_buf.p_data_))[i_elem]);
// }
}
// for(auto i_elem = 0; i_elem < (c_block_mem.mMemSize / sizeof(FloatC)) ;
// i_elem++){
// printf("C ==> %3d : %f(0x%08x)\n", i_elem,
// (reinterpret_cast<float*>(c_block_buf.p_data_))[i_elem],
// (reinterpret_cast<uint32_t*>(c_block_buf.p_data_))[i_elem]);
// }
if
constexpr
(
UseCLocalBuffer
)
c_threadwise_copy
.
Run
(
c_block_desc
,
c_block_buf
,
c_grid_desc
,
c_grid_buf
);
}
}
else
if
constexpr
(
std
::
is_same
<
BlockMNKAccessOrder
,
ck
::
Sequence
<
0
,
2
,
1
>>::
value
)
{
auto
a_move_k_step
=
ck
::
make_multi_index
(
0
,
k_per_block
);
auto
b_move_k_step
=
ck
::
make_multi_index
(
math
::
integer_divide_ceil
(
n_per_block
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
0
,
0
);
// only parallel in gemm m dim
#pragma omp parallel for
for
(
ck
::
index_t
i_mc
=
0
;
i_mc
<
GemmM
;
i_mc
+=
m_per_block
)
{
ck
::
index_t
mc_size
=
ck
::
math
::
min
(
GemmM
-
i_mc
,
m_per_block
);
a_threadwise_copy
.
SetSrcSliceOrigin
(
a_grid_desc
,
ck
::
make_multi_index
(
i_mc
,
0
));
for
(
ck
::
index_t
i_kc
=
0
;
i_kc
<
GemmK
;
i_kc
+=
k_per_block
)
{
c_threadwise_copy
.
RunGeneric
(
ck
::
index_t
kc_size
=
ck
::
math
::
min
(
GemmK
-
i_kc
,
k_per_block
);
auto
a_block_desc
=
GetABlockDescriptor
(
mc_size
,
kc_size
);
a_threadwise_copy
.
Run
(
a_grid_desc
,
a_grid_buf
,
a_block_desc
,
a_block_buf
);
b_threadwise_copy
.
SetSrcSliceOrigin
(
b_grid_desc
,
ck
::
make_multi_index
(
0
,
i_kc
,
0
));
// TODO: if use local C buffer, then this nc loop need to loop only once
for
(
ck
::
index_t
i_nc
=
0
;
i_nc
<
GemmN
;
i_nc
+=
n_per_block
)
{
ck
::
index_t
nc_size
=
ck
::
math
::
min
(
GemmN
-
i_nc
,
n_per_block
);
// TODO: nc need be 8x
nc_size
=
math
::
integer_least_multiple
(
nc_size
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
);
auto
b_block_desc
=
GetBBlockDescriptor
(
kc_size
,
nc_size
);
b_threadwise_copy
.
Run
(
b_grid_desc
,
b_grid_buf
,
b_block_desc
,
b_block_buf
);
auto
c_block_desc
=
UseCLocalBuffer
?
GetCBlockDescriptor
(
mc_size
,
nc_size
)
:
c_grid_desc
;
if
constexpr
(
!
UseCLocalBuffer
)
{
c_threadwise_copy
.
SetSrcSliceOrigin
(
c_block_desc
,
ck
::
make_multi_index
(
i_mc
,
i_nc
));
c_threadwise_copy
.
Run
(
c_block_desc
,
c_block_buf
,
c_grid_desc
,
c_grid_buf
);
c_threadwise_copy
.
MoveDstSliceWindow
(
c_grid_desc
,
c_block_move_step
);
}
blockwise_gemm
.
Run
(
a_block_desc
,
a_block_buf
,
make_zero_multi_index
<
a_block_copy_dim
>
(),
b_block_desc
,
b_block_buf
,
make_zero_multi_index
<
b_block_copy_dim
>
(),
c_block_desc
,
c_block_buf
,
make_zero_multi_index
<
2
>
(),
i_kc
!=
0
);
if
((
i_nc
+
n_per_block
)
<
GemmN
)
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_move_k_step
);
if
constexpr
(
UseCLocalBuffer
)
{
c_threadwise_copy
.
SetDstSliceOrigin
(
c_grid_desc
,
ck
::
make_multi_index
(
i_mc
,
i_nc
));
c_threadwise_copy
.
Run
(
c_block_desc
,
c_block_buf
,
c_grid_desc
,
c_grid_buf
);
}
}
if
((
i_kc
+
k_per_block
)
<
GemmK
)
a_threadwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_move_k_step
);
}
}
}
...
...
include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp
View file @
afc7d431
...
...
@@ -7,7 +7,7 @@
#include "common_header.hpp"
#include "../../gpu/device/tensor_layout.hpp"
#include "math.hpp"
#include "threadwise_param.hpp"
#include "threadwise_
gemm_
param.hpp"
namespace
ck
{
namespace
cpu
{
...
...
@@ -294,6 +294,9 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".if (m_Mr > 4)
\n
lea (%%rdx, %%rdi, 1), %%r8
\n
.endif
\n
"
".if (m_Mr > 5)
\n
lea (%%r8, %%rdi, 1), %%r9
\n
.endif
\n
"
"mov 60(%[m_param]), %%edi
\n
"
// accmulate_c
"test %%edi, %%edi
\n
"
"je L_GemmAvx2_MxN_6x16_Store_C%=
\n
"
" vaddps (%%rax), %%ymm0, %%ymm0
\n
"
".if (m_Nr > 8)
\n
vaddps 32(%%rax), %%ymm1, %%ymm1
\n
.endif
\n
"
".if (m_Mr > 1)
\n
vaddps (%%rbx), %%ymm2, %%ymm2
\n
.endif
\n
"
...
...
@@ -307,6 +310,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".if (m_Mr > 5)
\n
vaddps (%%r9), %%ymm10, %%ymm10
\n
.endif
\n
"
".if (m_Mr > 5) && (m_Nr > 8)
\n
vaddps 32(%%r9), %%ymm11, %%ymm11
\n
.endif
\n
"
"L_GemmAvx2_MxN_6x16_Store_C%=:
\n
"
".if m_NTStore == 0
\n
"
" vmovups %%ymm0, (%%rax)
\n
"
".if (m_Nr > 8)
\n
vmovups %%ymm1, 32(%%rax)
\n
.endif
\n
"
...
...
@@ -424,6 +428,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
};
// clang-format off
if
(
param
->
accmulate_c
){
ymm0
=
_mm256_loadu_ps
(
p_c
+
0
*
ldc
+
0
*
8
);
if
constexpr
(
Nr
>
8
)
ymm1
=
_mm256_loadu_ps
(
p_c
+
0
*
ldc
+
1
*
8
);
if
constexpr
(
Mr
>
1
)
ymm2
=
_mm256_loadu_ps
(
p_c
+
1
*
ldc
+
0
*
8
);
...
...
@@ -436,6 +441,20 @@ struct ThreadwiseGemmAvx2_MxN_6x16
if
constexpr
(
Mr
>
4
&&
Nr
>
8
)
ymm9
=
_mm256_loadu_ps
(
p_c
+
4
*
ldc
+
1
*
8
);
if
constexpr
(
Mr
>
5
)
ymm10
=
_mm256_loadu_ps
(
p_c
+
5
*
ldc
+
0
*
8
);
if
constexpr
(
Mr
>
5
&&
Nr
>
8
)
ymm11
=
_mm256_loadu_ps
(
p_c
+
5
*
ldc
+
1
*
8
);
}
else
{
ymm0
=
_mm256_xor_ps
(
ymm0
,
ymm0
);
if
constexpr
(
Nr
>
8
)
ymm1
=
_mm256_xor_ps
(
ymm1
,
ymm1
);
if
constexpr
(
Mr
>
1
)
ymm2
=
_mm256_xor_ps
(
ymm2
,
ymm2
);
if
constexpr
(
Mr
>
1
&&
Nr
>
8
)
ymm3
=
_mm256_xor_ps
(
ymm3
,
ymm3
);
if
constexpr
(
Mr
>
2
)
ymm4
=
_mm256_xor_ps
(
ymm4
,
ymm4
);
if
constexpr
(
Mr
>
2
&&
Nr
>
8
)
ymm5
=
_mm256_xor_ps
(
ymm5
,
ymm5
);
if
constexpr
(
Mr
>
3
)
ymm6
=
_mm256_xor_ps
(
ymm6
,
ymm6
);
if
constexpr
(
Mr
>
3
&&
Nr
>
8
)
ymm7
=
_mm256_xor_ps
(
ymm7
,
ymm7
);
if
constexpr
(
Mr
>
4
)
ymm8
=
_mm256_xor_ps
(
ymm8
,
ymm8
);
if
constexpr
(
Mr
>
4
&&
Nr
>
8
)
ymm9
=
_mm256_xor_ps
(
ymm9
,
ymm9
);
if
constexpr
(
Mr
>
5
)
ymm10
=
_mm256_xor_ps
(
ymm10
,
ymm10
);
if
constexpr
(
Mr
>
5
&&
Nr
>
8
)
ymm11
=
_mm256_xor_ps
(
ymm11
,
ymm11
);
}
while
(
Kr
>
4
){
#pragma unroll
...
...
@@ -532,6 +551,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
}
if
constexpr
(
NonTemporalStore
)
{
_mm256_stream_ps
(
p_c
+
0
*
ldc
+
0
*
8
,
ymm1
);
if
constexpr
(
Nr
>
8
)
_mm256_stream_ps
(
p_c
+
0
*
ldc
+
1
*
8
,
ymm1
);
if
constexpr
(
Mr
>
1
)
_mm256_stream_ps
(
p_c
+
1
*
ldc
+
0
*
8
,
ymm2
);
if
constexpr
(
Mr
>
1
&&
Nr
>
8
)
_mm256_stream_ps
(
p_c
+
1
*
ldc
+
1
*
8
,
ymm3
);
...
...
@@ -830,19 +850,23 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".if (m_Mr > 2)
\n
lea (%%rbx, %%rdi, 1), %%rcx
\n
.endif
\n
"
".if (m_Mr > 3)
\n
lea (%%rcx, %%rdi, 1), %%rdx
\n
.endif
\n
"
// " vaddps (%%rax), %%ymm0, %%ymm0 \n"
// ".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1 \n .endif\n"
// ".if (m_Nr >16)\n vaddps 64(%%rax), %%ymm2, %%ymm2 \n .endif\n"
// ".if (m_Mr > 1) \n vaddps (%%rbx), %%ymm3, %%ymm3 \n .endif\n"
// ".if (m_Mr > 1) && (m_Nr > 8)\n vaddps 32(%%rbx), %%ymm4, %%ymm4 \n .endif\n"
// ".if (m_Mr > 1) && (m_Nr >16)\n vaddps 64(%%rbx), %%ymm5, %%ymm5 \n .endif\n"
// ".if (m_Mr > 2) \n vaddps (%%rcx), %%ymm6, %%ymm6 \n .endif\n"
// ".if (m_Mr > 2) && (m_Nr > 8)\n vaddps 32(%%rcx), %%ymm7, %%ymm7 \n .endif\n"
// ".if (m_Mr > 2) && (m_Nr >16)\n vaddps 64(%%rcx), %%ymm8, %%ymm8 \n .endif\n"
// ".if (m_Mr > 3) \n vaddps (%%rdx), %%ymm9, %%ymm9 \n .endif\n"
// ".if (m_Mr > 3) && (m_Nr > 8)\n vaddps 32(%%rdx), %%ymm10, %%ymm10\n .endif\n"
// ".if (m_Mr > 3) && (m_Nr >16)\n vaddps 64(%%rdx), %%ymm11, %%ymm11\n .endif\n"
"mov 60(%[m_param]), %%edi
\n
"
// accmulate_c
"test %%edi, %%edi
\n
"
"je L_GemmAvx2_MxN_4x24_Store_C%=
\n
"
" vaddps (%%rax), %%ymm0, %%ymm0
\n
"
".if (m_Nr > 8)
\n
vaddps 32(%%rax), %%ymm1, %%ymm1
\n
.endif
\n
"
".if (m_Nr >16)
\n
vaddps 64(%%rax), %%ymm2, %%ymm2
\n
.endif
\n
"
".if (m_Mr > 1)
\n
vaddps (%%rbx), %%ymm3, %%ymm3
\n
.endif
\n
"
".if (m_Mr > 1) && (m_Nr > 8)
\n
vaddps 32(%%rbx), %%ymm4, %%ymm4
\n
.endif
\n
"
".if (m_Mr > 1) && (m_Nr >16)
\n
vaddps 64(%%rbx), %%ymm5, %%ymm5
\n
.endif
\n
"
".if (m_Mr > 2)
\n
vaddps (%%rcx), %%ymm6, %%ymm6
\n
.endif
\n
"
".if (m_Mr > 2) && (m_Nr > 8)
\n
vaddps 32(%%rcx), %%ymm7, %%ymm7
\n
.endif
\n
"
".if (m_Mr > 2) && (m_Nr >16)
\n
vaddps 64(%%rcx), %%ymm8, %%ymm8
\n
.endif
\n
"
".if (m_Mr > 3)
\n
vaddps (%%rdx), %%ymm9, %%ymm9
\n
.endif
\n
"
".if (m_Mr > 3) && (m_Nr > 8)
\n
vaddps 32(%%rdx), %%ymm10, %%ymm10
\n
.endif
\n
"
".if (m_Mr > 3) && (m_Nr >16)
\n
vaddps 64(%%rdx), %%ymm11, %%ymm11
\n
.endif
\n
"
"L_GemmAvx2_MxN_4x24_Store_C%=:
\n
"
".if m_NTStore == 0
\n
"
" vmovups %%ymm0, (%%rax)
\n
"
".if (m_Nr > 8)
\n
vmovups %%ymm1, 32(%%rax)
\n
.endif
\n
"
...
...
@@ -960,6 +984,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
};
// clang-format off
if
(
param
->
accmulate_c
)
{
ymm0
=
_mm256_loadu_ps
(
p_c
+
0
*
ldc
+
0
*
8
);
if
constexpr
(
Nr
>
8
)
ymm1
=
_mm256_loadu_ps
(
p_c
+
0
*
ldc
+
1
*
8
);
if
constexpr
(
Nr
>
16
)
ymm2
=
_mm256_loadu_ps
(
p_c
+
0
*
ldc
+
2
*
8
);
...
...
@@ -972,6 +997,20 @@ struct ThreadwiseGemmAvx2_MxN_4x24
if
constexpr
(
Mr
>
3
)
ymm9
=
_mm256_loadu_ps
(
p_c
+
3
*
ldc
+
0
*
8
);
if
constexpr
(
Mr
>
3
&&
Nr
>
8
)
ymm10
=
_mm256_loadu_ps
(
p_c
+
3
*
ldc
+
1
*
8
);
if
constexpr
(
Mr
>
3
&&
Nr
>
16
)
ymm11
=
_mm256_loadu_ps
(
p_c
+
3
*
ldc
+
2
*
8
);
}
else
{
ymm0
=
_mm256_xor_ps
(
ymm0
,
ymm0
);
if
constexpr
(
Nr
>
8
)
ymm1
=
_mm256_xor_ps
(
ymm1
,
ymm1
);
if
constexpr
(
Nr
>
16
)
ymm2
=
_mm256_xor_ps
(
ymm2
,
ymm2
);
if
constexpr
(
Mr
>
1
)
ymm3
=
_mm256_xor_ps
(
ymm3
,
ymm3
);
if
constexpr
(
Mr
>
1
&&
Nr
>
8
)
ymm4
=
_mm256_xor_ps
(
ymm4
,
ymm4
);
if
constexpr
(
Mr
>
1
&&
Nr
>
16
)
ymm5
=
_mm256_xor_ps
(
ymm5
,
ymm5
);
if
constexpr
(
Mr
>
2
)
ymm6
=
_mm256_xor_ps
(
ymm6
,
ymm6
);
if
constexpr
(
Mr
>
2
&&
Nr
>
8
)
ymm7
=
_mm256_xor_ps
(
ymm7
,
ymm7
);
if
constexpr
(
Mr
>
2
&&
Nr
>
16
)
ymm8
=
_mm256_xor_ps
(
ymm8
,
ymm8
);
if
constexpr
(
Mr
>
3
)
ymm9
=
_mm256_xor_ps
(
ymm9
,
ymm9
);
if
constexpr
(
Mr
>
3
&&
Nr
>
8
)
ymm10
=
_mm256_xor_ps
(
ymm10
,
ymm10
);
if
constexpr
(
Mr
>
3
&&
Nr
>
16
)
ymm11
=
_mm256_xor_ps
(
ymm11
,
ymm11
);
}
while
(
Kr
>
4
){
#pragma unroll
...
...
@@ -1221,33 +1260,36 @@ struct ThreadwiseGemmAvx2_MxN_6x16_Dispatch
static
constexpr
pThreadwiseGemmAvx2Run
dispatch_table
[
6
][
2
]
=
{
{
ThreadwiseGemm_
6x16
_t
::
Run
,
ThreadwiseGemm_
6x8
_t
::
Run
,
ThreadwiseGemm_
1x8
_t
::
Run
,
ThreadwiseGemm_
1x16
_t
::
Run
,
},
{
ThreadwiseGemm_
5x16
_t
::
Run
,
ThreadwiseGemm_
5x8
_t
::
Run
,
ThreadwiseGemm_
2x8
_t
::
Run
,
ThreadwiseGemm_
2x16
_t
::
Run
,
},
{
ThreadwiseGemm_
4x16
_t
::
Run
,
ThreadwiseGemm_
4x8
_t
::
Run
,
ThreadwiseGemm_
3x8
_t
::
Run
,
ThreadwiseGemm_
3x16
_t
::
Run
,
},
{
ThreadwiseGemm_
3x16
_t
::
Run
,
ThreadwiseGemm_
3x8
_t
::
Run
,
ThreadwiseGemm_
4x8
_t
::
Run
,
ThreadwiseGemm_
4x16
_t
::
Run
,
},
{
ThreadwiseGemm_
2x16
_t
::
Run
,
ThreadwiseGemm_
2x8
_t
::
Run
,
ThreadwiseGemm_
5x8
_t
::
Run
,
ThreadwiseGemm_
5x16
_t
::
Run
,
},
{
ThreadwiseGemm_
1x16
_t
::
Run
,
ThreadwiseGemm_
1x8
_t
::
Run
,
ThreadwiseGemm_
6x8
_t
::
Run
,
ThreadwiseGemm_
6x16
_t
::
Run
,
},
};
static
void
Run
(
ThreadwiseGemmParam
*
param
,
index_t
mr
,
index_t
nr
)
{
index_t
im
=
mr
-
1
;
index_t
in
=
(
nr
>>
3
)
-
1
;
assert
(
im
>=
0
&&
im
<=
5
&&
in
>=
0
&&
in
<=
1
);
return
dispatch_table
[
mr
][
nr
](
param
);
}
};
...
...
@@ -1371,30 +1413,33 @@ struct ThreadwiseGemmAvx2_MxN_4x24_Dispatch
static
constexpr
pThreadwiseGemmAvx2Run
dispatch_table
[
4
][
3
]
=
{
{
ThreadwiseGemm_
4x24
_t
::
Run
,
ThreadwiseGemm_
4
x16_t
::
Run
,
ThreadwiseGemm_
4x8
_t
::
Run
,
ThreadwiseGemm_
1x8
_t
::
Run
,
ThreadwiseGemm_
1
x16_t
::
Run
,
ThreadwiseGemm_
1x24
_t
::
Run
,
},
{
ThreadwiseGemm_
3x24
_t
::
Run
,
ThreadwiseGemm_
3
x16_t
::
Run
,
ThreadwiseGemm_
3x8
_t
::
Run
,
ThreadwiseGemm_
2x8
_t
::
Run
,
ThreadwiseGemm_
2
x16_t
::
Run
,
ThreadwiseGemm_
2x24
_t
::
Run
,
},
{
ThreadwiseGemm_
2x24
_t
::
Run
,
ThreadwiseGemm_
2
x16_t
::
Run
,
ThreadwiseGemm_
2x8
_t
::
Run
,
ThreadwiseGemm_
3x8
_t
::
Run
,
ThreadwiseGemm_
3
x16_t
::
Run
,
ThreadwiseGemm_
3x24
_t
::
Run
,
},
{
ThreadwiseGemm_
1x24
_t
::
Run
,
ThreadwiseGemm_
1
x16_t
::
Run
,
ThreadwiseGemm_
1x8
_t
::
Run
,
ThreadwiseGemm_
4x8
_t
::
Run
,
ThreadwiseGemm_
4
x16_t
::
Run
,
ThreadwiseGemm_
4x24
_t
::
Run
,
},
};
static
void
Run
(
ThreadwiseGemmParam
*
param
,
index_t
mr
,
index_t
nr
)
{
return
dispatch_table
[
mr
][
nr
](
param
);
index_t
im
=
mr
-
1
;
index_t
in
=
(
nr
>>
3
)
-
1
;
assert
(
im
>=
0
&&
im
<=
3
&&
in
>=
0
&&
in
<=
2
);
return
dispatch_table
[
im
][
in
](
param
);
}
};
...
...
include/ck/tensor_operation/cpu/thread/threadwise_param.hpp
→
include/ck/tensor_operation/cpu/thread/threadwise_
gemm_
param.hpp
View file @
afc7d431
#ifndef CK_THREADWISE_PARAM_HPP
#define CK_THREADWISE_PARAM_HPP
#ifndef CK_THREADWISE_
GEMM_
PARAM_HPP
#define CK_THREADWISE_
GEMM_
PARAM_HPP
#include "common_header.hpp"
#include "math.hpp"
...
...
@@ -17,7 +17,7 @@ struct ThreadwiseGemmParam
uint64_t
ldb
;
// in unit of byte
uint64_t
ldc
;
// in unit of byte
float
alpha
;
uint32_t
_pack0
;
int
accmulate_c
;
// if 1, need load C and add into current fma. if 0, direct store out c result
}
__attribute__
((
packed
));
}
// namespace cpu
...
...
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2.hpp
View file @
afc7d431
...
...
@@ -53,6 +53,42 @@ struct ThreadwiseTensorSliceTransferAvx2
{
static_assert
(
SliceLengths
::
At
(
Number
<
VectorDim
>
{})
%
ScalarPerVector
==
0
,
"wrong! cannot evenly divide"
);
int
N
=
src_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
0
>
{}];
int
Hi
=
src_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
int
Wi
=
src_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
2
>
{}];
int
C
=
src_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
3
>
{}];
int
Ho
=
src_desc
.
GetTransforms
()[
Number
<
9
>
{}].
low_lengths_
[
Number
<
1
>
{}];
int
Wo
=
src_desc
.
GetTransforms
()[
Number
<
9
>
{}].
low_lengths_
[
Number
<
2
>
{}];
int
Fy
=
src_desc
.
GetTransforms
()[
Number
<
10
>
{}].
low_lengths_
[
Number
<
0
>
{}];
int
Fx
=
src_desc
.
GetTransforms
()[
Number
<
10
>
{}].
low_lengths_
[
Number
<
1
>
{}];
int
Dy
=
src_desc
.
GetTransforms
()[
Number
<
6
>
{}].
coefficients_
[
Number
<
0
>
{}];
int
Sy
=
src_desc
.
GetTransforms
()[
Number
<
6
>
{}].
coefficients_
[
Number
<
1
>
{}];
int
Dx
=
src_desc
.
GetTransforms
()[
Number
<
7
>
{}].
coefficients_
[
Number
<
0
>
{}];
int
Sx
=
src_desc
.
GetTransforms
()[
Number
<
7
>
{}].
coefficients_
[
Number
<
1
>
{}];
int
Py
=
src_desc
.
GetTransforms
()[
Number
<
2
>
{}].
left_pad_length_
;
int
Px
=
src_desc
.
GetTransforms
()[
Number
<
3
>
{}].
left_pad_length_
;
printf
(
"N:%d, Hi:%d, Wi:%d, C:%d, Ho:%d, Wo:%d, Fy:%d, Fx:%d, Dy:%d, Sy:%d, Dx:%d, Sx:%d, "
"Py:%d, Px:%d
\n
"
,
N
,
Hi
,
Wi
,
C
,
Ho
,
Wo
,
Fy
,
Fx
,
Dy
,
Sy
,
Dx
,
Sx
,
Py
,
Px
);
}
void
SetSrcSliceOrigin
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin_idx
)
...
...
@@ -87,6 +123,10 @@ struct ThreadwiseTensorSliceTransferAvx2
// std::cout<<"num_access:"<<num_access<<std::endl;
std
::
cout
<<
"src hidden:"
<<
SrcDesc
::
GetNumOfHiddenDimension
()
<<
std
::
endl
;
std
::
cout
<<
"dst hidden:"
<<
DstDesc
::
GetNumOfHiddenDimension
()
<<
std
::
endl
;
#if 0
static_for<0, num_access, 1>{}([&](auto idx_1d) {
using src_vector_type = ck::cpu::vector_type_maker_t<SrcData, ScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
...
...
@@ -148,6 +188,75 @@ struct ThreadwiseTensorSliceTransferAvx2
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
}
});
#endif
const
auto
src_slice_idx_zeros
=
typename
uniform_sequence_gen
<
nDim
,
0
>::
type
{};
const
auto
src_slice_step
=
make_tensor_coordinate_step
(
src_desc
,
to_multi_index
(
src_slice_idx_zeros
.
Modify
(
Number
<
nDim
-
1
>
{},
Number
<
1
>
{})));
const
auto
dst_slice_idx_zeros
=
typename
uniform_sequence_gen
<
nDim
,
0
>::
type
{};
const
auto
dst_slice_step
=
make_tensor_coordinate_step
(
dst_desc
,
to_multi_index
(
dst_slice_idx_zeros
.
Modify
(
Number
<
nDim
-
1
>
{},
Number
<
1
>
{})));
for
(
auto
idx_id
=
0
;
idx_id
<
num_access
;
idx_id
++
)
{
using
src_vector_type
=
ck
::
cpu
::
vector_type_maker_t
<
SrcData
,
ScalarPerVector
>
;
using
src_vector_t
=
typename
src_vector_type
::
type
;
using
dst_vector_type
=
ck
::
cpu
::
vector_type_maker_t
<
DstData
,
ScalarPerVector
>
;
using
dst_vector_t
=
typename
dst_vector_type
::
type
;
const
bool
is_src_valid
=
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
src_desc
,
src_coord_
);
printf
(
"[%s] "
,
is_src_valid
?
"y"
:
"n"
);
print_multi_index
(
src_coord_
.
GetIndex
());
printf
(
"----"
);
// print_multi_index(src_coord_.GetHiddenIndex());
// printf(":%d", src_coord_.GetOffset());
// printf("\n");
// copy data from src_buf into src_vector_container
auto
src_vector_container
=
src_vector_type
{
src_buf
.
template
Get
<
src_vector_t
>(
src_coord_
.
GetOffset
(),
is_src_valid
)};
auto
dst_vector_container
=
dst_vector_type
{};
// apply pointwise operation
// static_for<0, ScalarPerVector, 1>{}([&](auto i) {
// element_op_(dst_vector_container.template AsType<DstData>()(i),
// src_vector_container.template AsType<SrcData>()[i]);
// });
element_op_
(
dst_vector_container
.
template
AsType
<
dst_vector_t
>(),
src_vector_container
.
template
AsType
<
src_vector_t
>());
const
bool
is_dst_valid
=
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
dst_desc
,
dst_coord_
);
printf
(
" -> "
);
print_multi_index
(
dst_coord_
.
GetIndex
());
// printf(":%d", dst_coord_.GetOffset());
// printf(", src:0x%x, dst:0x%x",
// *reinterpret_cast<uint32_t*>(&src_vector_container.template AsType<src_vector_t>()),
// *reinterpret_cast<uint32_t*>(&dst_vector_container.template
// AsType<dst_vector_t>()));
printf
(
"
\n
"
);
// copy data from dst_vector into dst_buf
dst_buf
.
template
Update
<
DstInMemOp
,
dst_vector_t
>(
dst_coord_
.
GetOffset
(),
is_dst_valid
,
dst_vector_container
.
template
AsType
<
dst_vector_t
>());
// move coordinate
if
(
idx_id
!=
num_access
-
1
)
{
// constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate
(
src_desc
,
src_coord_
,
src_slice_step
);
move_tensor_coordinate
(
dst_desc
,
dst_coord_
,
dst_slice_step
);
}
}
// move coordinate back to slice origin (or not)
if
constexpr
(
SrcResetCoordinateAfterRun
)
...
...
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp
0 → 100644
View file @
afc7d431
#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_AVX2_SPECIALIZED_HPP
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_AVX2_SPECIALIZED_HPP
#include "common_header.hpp"
#include "data_type_cpu.hpp"
#include "../../gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_space_filling_curve.hpp"
#include "dynamic_buffer_cpu.hpp"
#include <immintrin.h>
#include "convolution_forward_specialization_cpu.hpp"
#include <immintrin.h>
namespace
ck
{
namespace
cpu
{
namespace
avx2_util
{
inline
void
memcpy32_avx2
(
void
*
dst
,
const
void
*
src
,
const
ck
::
index_t
n
)
{
// 16-8-4-2-1 pattern
ck
::
index_t
i_n
=
n
;
float
*
p_dst
=
reinterpret_cast
<
float
*>
(
dst
);
const
float
*
p_src
=
reinterpret_cast
<
const
float
*>
(
src
);
while
(
i_n
>=
16
)
{
_mm256_storeu_ps
(
p_dst
+
0
,
_mm256_loadu_ps
(
p_src
+
0
));
_mm256_storeu_ps
(
p_dst
+
8
,
_mm256_loadu_ps
(
p_src
+
8
));
p_dst
+=
16
;
p_src
+=
16
;
i_n
-=
16
;
}
if
(
i_n
&
8
)
{
_mm256_storeu_ps
(
p_dst
,
_mm256_loadu_ps
(
p_src
));
p_dst
+=
8
;
p_src
+=
8
;
}
if
(
i_n
&
4
)
{
_mm_storeu_ps
(
p_dst
,
_mm_loadu_ps
(
p_src
));
p_dst
+=
4
;
p_src
+=
4
;
}
if
(
i_n
&
2
)
{
_mm_storeu_si64
(
p_dst
,
_mm_loadu_si64
(
p_src
));
p_dst
+=
2
;
p_src
+=
2
;
}
if
(
i_n
&
1
)
{
*
p_dst
=
*
p_src
;
}
}
inline
void
memset32_avx2
(
void
*
dst
,
const
int32_t
value
,
const
ck
::
index_t
n
)
{
// 16-8-4-2-1 pattern
ck
::
index_t
i_n
=
n
;
float
*
p_dst
=
reinterpret_cast
<
float
*>
(
dst
);
__m256
ymm
=
_mm256_set1_ps
(
*
reinterpret_cast
<
const
float
*>
(
&
value
));
__m128
xmm
=
_mm_set1_ps
(
*
reinterpret_cast
<
const
float
*>
(
&
value
));
while
(
i_n
>=
16
)
{
_mm256_storeu_ps
(
p_dst
+
0
,
ymm
);
_mm256_storeu_ps
(
p_dst
+
8
,
ymm
);
p_dst
+=
16
;
i_n
-=
16
;
}
if
(
i_n
&
8
)
{
_mm256_storeu_ps
(
p_dst
,
ymm
);
p_dst
+=
8
;
}
if
(
i_n
&
4
)
{
_mm_storeu_ps
(
p_dst
,
xmm
);
p_dst
+=
4
;
}
if
(
i_n
&
2
)
{
_mm_storeu_si64
(
p_dst
,
xmm
);
p_dst
+=
2
;
}
if
(
i_n
&
1
)
{
*
p_dst
=
*
reinterpret_cast
<
const
float
*>
(
&
value
);
}
}
inline
void
transpose8x8_avx2
(
void
*
dst
,
ck
::
index_t
stride_dst
,
const
void
*
src
,
ck
::
index_t
stride_src
)
{
// TODO: use vinsertf128 for better port usage. vpermf128 is slow
__m256
r0
,
r1
,
r2
,
r3
,
r4
,
r5
,
r6
,
r7
;
__m256
t0
,
t1
,
t2
,
t3
,
t4
,
t5
,
t6
,
t7
;
float
*
p_dst
=
reinterpret_cast
<
float
*>
(
dst
);
const
float
*
p_src
=
reinterpret_cast
<
const
float
*>
(
src
);
r0
=
_mm256_loadu_ps
(
p_src
+
0
*
stride_src
);
r1
=
_mm256_loadu_ps
(
p_src
+
1
*
stride_src
);
r2
=
_mm256_loadu_ps
(
p_src
+
2
*
stride_src
);
r3
=
_mm256_loadu_ps
(
p_src
+
3
*
stride_src
);
r4
=
_mm256_loadu_ps
(
p_src
+
4
*
stride_src
);
r5
=
_mm256_loadu_ps
(
p_src
+
5
*
stride_src
);
r6
=
_mm256_loadu_ps
(
p_src
+
6
*
stride_src
);
r7
=
_mm256_loadu_ps
(
p_src
+
7
*
stride_src
);
t0
=
_mm256_unpacklo_ps
(
r0
,
r1
);
t1
=
_mm256_unpackhi_ps
(
r0
,
r1
);
t2
=
_mm256_unpacklo_ps
(
r2
,
r3
);
t3
=
_mm256_unpackhi_ps
(
r2
,
r3
);
t4
=
_mm256_unpacklo_ps
(
r4
,
r5
);
t5
=
_mm256_unpackhi_ps
(
r4
,
r5
);
t6
=
_mm256_unpacklo_ps
(
r6
,
r7
);
t7
=
_mm256_unpackhi_ps
(
r6
,
r7
);
r0
=
_mm256_shuffle_ps
(
t0
,
t2
,
_MM_SHUFFLE
(
1
,
0
,
1
,
0
));
r1
=
_mm256_shuffle_ps
(
t0
,
t2
,
_MM_SHUFFLE
(
3
,
2
,
3
,
2
));
r2
=
_mm256_shuffle_ps
(
t1
,
t3
,
_MM_SHUFFLE
(
1
,
0
,
1
,
0
));
r3
=
_mm256_shuffle_ps
(
t1
,
t3
,
_MM_SHUFFLE
(
3
,
2
,
3
,
2
));
r4
=
_mm256_shuffle_ps
(
t4
,
t6
,
_MM_SHUFFLE
(
1
,
0
,
1
,
0
));
r5
=
_mm256_shuffle_ps
(
t4
,
t6
,
_MM_SHUFFLE
(
3
,
2
,
3
,
2
));
r6
=
_mm256_shuffle_ps
(
t5
,
t7
,
_MM_SHUFFLE
(
1
,
0
,
1
,
0
));
r7
=
_mm256_shuffle_ps
(
t5
,
t7
,
_MM_SHUFFLE
(
3
,
2
,
3
,
2
));
t0
=
_mm256_permute2f128_ps
(
r0
,
r4
,
0x20
);
t1
=
_mm256_permute2f128_ps
(
r1
,
r5
,
0x20
);
t2
=
_mm256_permute2f128_ps
(
r2
,
r6
,
0x20
);
t3
=
_mm256_permute2f128_ps
(
r3
,
r7
,
0x20
);
t4
=
_mm256_permute2f128_ps
(
r0
,
r4
,
0x31
);
t5
=
_mm256_permute2f128_ps
(
r1
,
r5
,
0x31
);
t6
=
_mm256_permute2f128_ps
(
r2
,
r6
,
0x31
);
t7
=
_mm256_permute2f128_ps
(
r3
,
r7
,
0x31
);
_mm256_storeu_ps
(
p_dst
+
0
*
stride_dst
,
t0
);
_mm256_storeu_ps
(
p_dst
+
1
*
stride_dst
,
t1
);
_mm256_storeu_ps
(
p_dst
+
2
*
stride_dst
,
t2
);
_mm256_storeu_ps
(
p_dst
+
3
*
stride_dst
,
t3
);
_mm256_storeu_ps
(
p_dst
+
4
*
stride_dst
,
t4
);
_mm256_storeu_ps
(
p_dst
+
5
*
stride_dst
,
t5
);
_mm256_storeu_ps
(
p_dst
+
6
*
stride_dst
,
t6
);
_mm256_storeu_ps
(
p_dst
+
7
*
stride_dst
,
t7
);
}
}
// namespace avx2_util
using
ConvolutionForwardSpecialization_t
=
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardSpecialization_t
;
using
ConvolutionForwardGemmKSpecialization_t
=
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardGemmKSpecialization_t
;
// assume input -> a matrix
// assume input -> MC * KC
template
<
typename
SrcData
,
typename
DstData
,
typename
SrcDesc
,
typename
DstDesc
,
typename
ElementwiseOperation
,
bool
BypassTransfer
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
>
struct
ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
{
static
constexpr
ck
::
index_t
nDim
=
SrcDesc
::
GetNumOfDimension
();
using
Index
=
MultiIndex
<
nDim
>
;
constexpr
ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
(
const
SrcDesc
&
src_desc
,
const
Index
&
,
const
DstDesc
&
,
const
Index
&
,
const
ElementwiseOperation
&
element_op
)
:
element_op_
(
element_op
)
{
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
N
=
1
;
Hi
=
1
;
Wi
=
src_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
0
>
{}];
// gemm_m
C
=
src_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
// gemm_k
Ho
=
1
;
Wo
=
Wi
;
Fy
=
1
;
Fx
=
1
;
Dy
=
1
;
Sy
=
1
;
Dx
=
1
;
Sx
=
1
;
Py
=
0
;
Px
=
0
;
}
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Pad0
)
{
N
=
src_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
0
>
{}];
Hi
=
src_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
Wi
=
src_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
2
>
{}];
C
=
src_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
3
>
{}];
Ho
=
src_desc
.
GetTransforms
()[
Number
<
2
>
{}].
GetUpperLengths
()[
Number
<
0
>
{}];
Wo
=
src_desc
.
GetTransforms
()[
Number
<
3
>
{}].
GetUpperLengths
()[
Number
<
0
>
{}];
Fy
=
1
;
Fx
=
1
;
Dy
=
1
;
Sy
=
src_desc
.
GetTransforms
()[
Number
<
2
>
{}].
coefficients_
[
Number
<
0
>
{}];
Dx
=
1
;
Sx
=
src_desc
.
GetTransforms
()[
Number
<
3
>
{}].
coefficients_
[
Number
<
0
>
{}];
Py
=
0
;
Px
=
0
;
}
else
{
N
=
src_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
0
>
{}];
Hi
=
src_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
Wi
=
src_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
2
>
{}];
C
=
src_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
3
>
{}];
Ho
=
src_desc
.
GetTransforms
()[
Number
<
9
>
{}].
low_lengths_
[
Number
<
1
>
{}];
Wo
=
src_desc
.
GetTransforms
()[
Number
<
9
>
{}].
low_lengths_
[
Number
<
2
>
{}];
Fy
=
src_desc
.
GetTransforms
()[
Number
<
10
>
{}].
low_lengths_
[
Number
<
0
>
{}];
Fx
=
src_desc
.
GetTransforms
()[
Number
<
10
>
{}].
low_lengths_
[
Number
<
1
>
{}];
Dy
=
src_desc
.
GetTransforms
()[
Number
<
6
>
{}].
coefficients_
[
Number
<
0
>
{}];
Sy
=
src_desc
.
GetTransforms
()[
Number
<
6
>
{}].
coefficients_
[
Number
<
1
>
{}];
Dx
=
src_desc
.
GetTransforms
()[
Number
<
7
>
{}].
coefficients_
[
Number
<
0
>
{}];
Sx
=
src_desc
.
GetTransforms
()[
Number
<
7
>
{}].
coefficients_
[
Number
<
1
>
{}];
Py
=
src_desc
.
GetTransforms
()[
Number
<
2
>
{}].
left_pad_length_
;
Px
=
src_desc
.
GetTransforms
()[
Number
<
3
>
{}].
left_pad_length_
;
}
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
input_offset_acc_wi
=
Sx
*
C
;
input_offset_ovf_wi_acc_hi
=
Sy
*
Wi
*
C
-
Wo
*
Sx
*
C
;
input_offset_ovf_hi_acc_n
=
Hi
*
Wi
*
C
-
Ho
*
Sy
*
Wi
*
C
;
// input_offset_acc_c = 1;
input_offset_ovf_c_acc_x
=
Dx
*
C
-
C
;
input_offset_ovf_x_acc_y
=
Dy
*
Wi
*
C
-
Fx
*
Dx
*
C
;
src_offset
=
-
Py
*
Wi
*
C
-
Px
*
C
;
i_n
=
0
;
i_c
=
0
;
i_hi
=
-
Py
;
i_wi
=
-
Px
;
i_ho
=
0
;
i_wo
=
0
;
i_y
=
0
;
i_x
=
0
;
i_gemm_k
=
0
;
#if 0
printf("N:%d, Hi:%d, Wi:%d, C:%d, Ho:%d, Wo:%d, Fy:%d, Fx:%d, Dy:%d, Sy:%d, Dx:%d, Sx:%d, "
"Py:%d, Px:%d\n",
N,
Hi,
Wi,
C,
Ho,
Wo,
Fy,
Fx,
Dy,
Sy,
Dx,
Sx,
Py,
Px);
#endif
}
void
SetSrcSliceOrigin
(
const
SrcDesc
&
,
const
Index
&
src_slice_origin_idx
)
{
ck
::
index_t
idx_m
=
src_slice_origin_idx
[
Number
<
0
>
{}];
ck
::
index_t
idx_k
=
src_slice_origin_idx
[
Number
<
1
>
{}];
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
i_wi
=
idx_m
;
i_c
=
idx_k
;
src_offset
=
i_wi
*
C
+
i_c
;
// printf("src_offset:%d, i_wi:%d, i_c:%d\n", src_offset, i_wi, i_c);
}
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Pad0
)
{
i_wo
=
idx_m
%
Wo
;
i_ho
=
(
idx_m
/
Wo
)
%
Ho
;
i_n
=
(
idx_m
/
Wo
)
/
Ho
;
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
i_c
=
idx_k
;
i_x
=
0
;
i_y
=
0
;
i_hi
=
i_ho
*
Sy
;
i_wi
=
i_wo
*
Sx
;
src_offset
=
i_n
*
Hi
*
Wi
*
C
+
i_hi
*
Wi
*
C
+
i_wi
*
C
+
i_c
;
i_gemm_k
=
idx_k
;
}
else
{
i_wo
=
idx_m
%
Wo
;
i_ho
=
(
idx_m
/
Wo
)
%
Ho
;
i_n
=
(
idx_m
/
Wo
)
/
Ho
;
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
if
(
idx_k
==
0
)
{
i_c
=
0
;
i_x
=
0
;
i_y
=
0
;
i_hi
=
i_ho
*
Sy
-
Py
;
i_wi
=
i_wo
*
Sx
-
Px
;
}
else
{
i_c
=
idx_k
%
C
;
i_x
=
(
idx_k
/
C
)
%
Fx
;
i_y
=
(
idx_k
/
C
)
/
Fx
;
i_hi
=
i_ho
*
Sy
+
i_y
*
Dy
-
Py
;
i_wi
=
i_wo
*
Sx
+
i_x
*
Dx
-
Px
;
}
src_offset
=
i_n
*
Hi
*
Wi
*
C
+
i_hi
*
Wi
*
C
+
i_wi
*
C
+
i_c
;
i_gemm_k
=
idx_k
;
// printf("[%d] i_wo:%d, i_ho:%d, i_wi:%d, i_hi:%d, src_offset:%d\n",
// __LINE__, i_wo, i_ho, i_wi, i_hi, src_offset);
}
}
void
SetDstSliceOrigin
(
const
DstDesc
&
,
const
Index
&
)
{}
template
<
typename
SrcBuffer
,
typename
DstBuffer
>
void
Run
(
const
SrcDesc
&
src_desc
,
const
SrcBuffer
&
src_buf
,
const
DstDesc
&
dst_desc
,
DstBuffer
&
dst_buf
)
{
if
constexpr
(
BypassTransfer
)
{
float
*
p_src
=
reinterpret_cast
<
float
*>
(
src_buf
.
p_data_
)
+
src_offset
;
dst_buf
.
p_data_
=
p_src
;
}
else
{
const
ck
::
index_t
m_per_block
=
dst_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
0
>
{}];
const
ck
::
index_t
k_per_block
=
dst_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
const
float
*
p_src
=
reinterpret_cast
<
const
float
*>
(
src_buf
.
p_data_
)
+
src_offset
;
float
*
p_dst
=
reinterpret_cast
<
float
*>
(
dst_buf
.
p_data_
);
// printf("src offset:%d, k_per_block:%d, m_per_block:%d\n", src_offset, k_per_block,
// m_per_block);
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
ck
::
index_t
i_m_itr
=
m_per_block
;
// standard 8-4-2-1 pattern
while
(
i_m_itr
>=
8
)
{
avx2_util
::
memcpy32_avx2
(
p_dst
+
0
*
k_per_block
,
p_src
+
0
*
C
,
k_per_block
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
1
*
k_per_block
,
p_src
+
1
*
C
,
k_per_block
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
2
*
k_per_block
,
p_src
+
2
*
C
,
k_per_block
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
3
*
k_per_block
,
p_src
+
3
*
C
,
k_per_block
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
4
*
k_per_block
,
p_src
+
4
*
C
,
k_per_block
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
5
*
k_per_block
,
p_src
+
5
*
C
,
k_per_block
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
6
*
k_per_block
,
p_src
+
6
*
C
,
k_per_block
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
7
*
k_per_block
,
p_src
+
7
*
C
,
k_per_block
);
i_m_itr
-=
8
;
p_dst
+=
8
*
k_per_block
;
p_src
+=
8
*
C
;
}
if
(
i_m_itr
&
4
)
{
avx2_util
::
memcpy32_avx2
(
p_dst
+
0
*
k_per_block
,
p_src
+
0
*
C
,
k_per_block
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
1
*
k_per_block
,
p_src
+
1
*
C
,
k_per_block
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
2
*
k_per_block
,
p_src
+
2
*
C
,
k_per_block
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
3
*
k_per_block
,
p_src
+
3
*
C
,
k_per_block
);
p_dst
+=
4
*
k_per_block
;
p_src
+=
4
*
C
;
}
if
(
i_m_itr
&
2
)
{
avx2_util
::
memcpy32_avx2
(
p_dst
+
0
*
k_per_block
,
p_src
+
0
*
C
,
k_per_block
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
1
*
k_per_block
,
p_src
+
1
*
C
,
k_per_block
);
p_dst
+=
2
*
k_per_block
;
p_src
+=
2
*
C
;
}
if
(
i_m_itr
&
1
)
{
avx2_util
::
memcpy32_avx2
(
p_dst
+
0
*
k_per_block
,
p_src
+
0
*
C
,
k_per_block
);
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Pad0
)
{
ck
::
index_t
i_m_itr
=
m_per_block
;
ck
::
index_t
i_wo_itr
=
i_wo
;
ck
::
index_t
i_ho_itr
=
i_ho
;
while
(
i_m_itr
>
0
)
{
avx2_util
::
memcpy32_avx2
(
p_dst
,
p_src
,
k_per_block
);
p_dst
+=
k_per_block
;
i_wo_itr
++
;
p_src
+=
input_offset_acc_wi
;
if
(
i_wo_itr
>=
Wo
)
{
i_wo_itr
=
0
;
i_ho_itr
++
;
p_src
+=
input_offset_ovf_wi_acc_hi
;
}
if
(
i_ho_itr
>=
Ho
)
{
i_ho_itr
=
0
;
// i_n++;
p_src
+=
input_offset_ovf_hi_acc_n
;
}
i_m_itr
--
;
}
}
else
{
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
if
constexpr
(
GemmKSpecialization
==
ConvolutionForwardGemmKSpecialization_t
::
NHWC_GemmKLoopOverC
)
{
// c % k_per_block == 0, so every time k_per_block here is the same
ck
::
index_t
i_m_itr
=
m_per_block
;
ck
::
index_t
i_wo_itr
=
i_wo
;
ck
::
index_t
i_ho_itr
=
i_ho
;
ck
::
index_t
i_wi_itr
=
i_wi
;
ck
::
index_t
i_hi_itr
=
i_hi
;
// printf("[%d] i_m_itr:%d, i_wo_itr:%d, i_ho_itr:%d, i_wi_itr:%d, i_hi_itr:%d,
// src_offset:%d, input_offset_acc_wi:%d,
// input_offset_ovf_wi_acc_hi:%d,input_offset_ovf_hi_acc_n:%d, %p(%p)\n",
// __LINE__, i_m_itr, i_wo_itr, i_ho_itr, i_wi_itr, i_hi_itr,
// src_offset, input_offset_acc_wi, input_offset_ovf_wi_acc_hi,
// input_offset_ovf_hi_acc_n, src_buf.p_data_, p_src);
// printf("%p %p %p, %d, %x, %p\n",src_buf.p_data_, reinterpret_cast<const
// float*>(src_buf.p_data_) + 1, reinterpret_cast<const float*>(src_buf.p_data_)
// + ck::index_t(-1),
// sizeof(src_offset), *reinterpret_cast<uint32_t*>(&src_offset),
// reinterpret_cast<const float*>(src_buf.p_data_) + (-1088));
while
(
i_m_itr
>
0
)
{
// printf("[%d] i_m_itr:%d, i_wo_itr:%d, i_ho_itr:%d, i_wi_itr:%d,
// i_hi_itr:%d, src_offset:%d -> %p\n",
// __LINE__, i_m_itr, i_wo_itr, i_ho_itr, i_wi_itr, i_hi_itr, src_offset,
// p_src);
if
((
*
reinterpret_cast
<
uint32_t
*>
(
&
i_hi_itr
)
<
Hi
)
&&
(
*
reinterpret_cast
<
uint32_t
*>
(
&
i_wi_itr
)
<
Wi
))
avx2_util
::
memcpy32_avx2
(
p_dst
,
p_src
,
k_per_block
);
else
avx2_util
::
memset32_avx2
(
p_dst
,
0
,
k_per_block
);
p_dst
+=
k_per_block
;
i_wo_itr
++
;
i_wi_itr
+=
Sx
;
p_src
+=
input_offset_acc_wi
;
if
(
i_wo_itr
>=
Wo
)
{
i_wo_itr
=
0
;
i_wi_itr
-=
Wo
*
Sx
;
i_ho_itr
++
;
i_hi_itr
+=
Sy
;
p_src
+=
input_offset_ovf_wi_acc_hi
;
}
if
(
i_ho_itr
>=
Ho
)
{
i_ho_itr
=
0
;
i_hi_itr
-=
Ho
*
Sy
;
// i_n++;
p_src
+=
input_offset_ovf_hi_acc_n
;
}
i_m_itr
--
;
}
// printf("[%d] \n", __LINE__);
}
else
{
ck
::
index_t
i_m_itr
=
m_per_block
;
ck
::
index_t
i_wo_itr
=
i_wo
;
ck
::
index_t
i_ho_itr
=
i_ho
;
ck
::
index_t
i_wi_itr
=
i_wi
;
ck
::
index_t
i_hi_itr
=
i_hi
;
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
while
(
i_m_itr
>
0
)
{
/*** go along Gemm K ***/
const
float
*
p_src_k
=
p_src
;
float
*
p_dst_k
=
p_dst
;
ck
::
index_t
i_wi_itr_k
=
i_wi_itr
;
ck
::
index_t
i_hi_itr_k
=
i_hi_itr
;
ck
::
index_t
i_c_itr_k
=
i_c
;
ck
::
index_t
i_y_itr_k
=
i_y
;
ck
::
index_t
i_x_itr_k
=
i_x
;
ck
::
index_t
i_k_itr
=
k_per_block
;
while
(
i_k_itr
>
0
)
{
ck
::
index_t
current_k_block
=
ck
::
math
::
min
(
C
-
i_c_itr_k
,
k_per_block
);
if
((
*
reinterpret_cast
<
uint32_t
*>
(
&
i_hi_itr_k
)
<
Hi
)
&&
(
*
reinterpret_cast
<
uint32_t
*>
(
&
i_wi_itr_k
)
<
Wi
))
avx2_util
::
memcpy32_avx2
(
p_dst_k
,
p_src_k
,
current_k_block
);
else
avx2_util
::
memset32_avx2
(
p_dst_k
,
0
,
current_k_block
);
p_dst_k
+=
current_k_block
;
p_src_k
+=
current_k_block
;
i_c_itr_k
+=
current_k_block
;
if
(
i_c_itr_k
>=
C
)
{
i_c_itr_k
=
0
;
i_x_itr_k
++
;
i_wi_itr_k
+=
Dx
;
p_src_k
+=
input_offset_ovf_c_acc_x
;
}
if
(
i_x_itr_k
>=
Fx
)
{
i_x_itr_k
=
0
;
i_y_itr_k
++
;
i_hi_itr_k
+=
Dy
;
p_src_k
+=
input_offset_ovf_x_acc_y
;
}
i_k_itr
-=
current_k_block
;
}
/*** go along Gemm K ***/
p_dst
+=
k_per_block
;
i_wo_itr
++
;
i_wi_itr
+=
Sx
;
p_src
+=
input_offset_acc_wi
;
if
(
i_wo_itr
>=
Wo
)
{
i_wo_itr
=
0
;
i_wi_itr
-=
Wo
*
Sx
;
i_ho_itr
++
;
i_hi_itr
+=
Sy
;
p_src
+=
input_offset_ovf_wi_acc_hi
;
}
if
(
i_ho_itr
>=
Ho
)
{
i_ho_itr
=
0
;
i_hi_itr
-=
Ho
*
Sy
;
// i_n++;
p_src
+=
input_offset_ovf_hi_acc_n
;
}
i_m_itr
--
;
}
}
}
}
}
void
MoveSrcSliceWindow
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin_step_idx
)
{
ck
::
index_t
move_k
=
src_slice_origin_step_idx
[
Number
<
1
>
{}];
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
// printf(" => move_k:%d, src offset:%d\n", move_k, src_offset);
i_c
+=
move_k
;
src_offset
+=
move_k
;
}
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Pad0
)
{
i_c
+=
move_k
;
src_offset
+=
move_k
;
}
else
{
if
constexpr
(
GemmKSpecialization
==
ConvolutionForwardGemmKSpecialization_t
::
NHWC_GemmKLoopOverC
)
{
// c % k_per_block == 0, so every time k_per_block here is the same
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
// printf("222222 C:%d, src_offset:%d, i_c:%d, i_x:%d\n", C, src_offset, i_c, i_x);
// fflush(stdout);
// TODO: branch seems weird
i_c
+=
move_k
;
src_offset
+=
move_k
;
// printf("3333[%d] src_offset:%d\n", __LINE__, src_offset);
if
(
i_c
>=
C
)
{
i_c
=
0
;
i_x
++
;
i_wi
+=
Dx
;
src_offset
+=
Dx
*
C
-
C
;
// printf("3333[%d] src_offset:%d\n", __LINE__, src_offset);
}
if
(
i_x
>=
Fx
)
{
i_x
=
0
;
i_y
++
;
i_wi
=
i_wi
-
Fx
*
Dx
;
i_hi
+=
Dy
;
src_offset
+=
Dy
*
Wi
*
C
-
Fx
*
Dx
*
C
;
// printf("3333[%d] src_offset:%d\n", __LINE__, src_offset);
}
// printf("inp move:%d, i_c:%d, i_hi:%d, i_wi:%d src_offset:%d\n", move_k, i_c,
// i_hi, i_wi, src_offset); fflush(stdout);
}
else
{
i_gemm_k
+=
move_k
;
i_c
=
i_gemm_k
%
C
;
i_x
=
(
i_gemm_k
/
C
)
%
Fx
;
i_y
=
(
i_gemm_k
/
C
)
/
Fx
;
i_hi
=
i_ho
*
Sy
+
i_y
*
Dy
-
Py
;
i_wi
=
i_wo
*
Sx
+
i_x
*
Dx
-
Px
;
src_offset
=
i_n
*
Hi
*
Wi
*
C
+
i_hi
*
Wi
*
C
+
i_wi
*
C
+
i_c
;
}
}
}
void
MoveDstSliceWindow
(
const
DstDesc
&
,
const
Index
&
)
{}
private:
const
ElementwiseOperation
element_op_
;
ck
::
index_t
i_n
;
ck
::
index_t
i_c
;
ck
::
index_t
i_hi
;
ck
::
index_t
i_wi
;
ck
::
index_t
i_ho
;
ck
::
index_t
i_wo
;
ck
::
index_t
i_y
;
ck
::
index_t
i_x
;
ck
::
index_t
i_gemm_k
;
ck
::
index_t
N
;
// ck::index_t K;
ck
::
index_t
C
;
ck
::
index_t
Hi
;
ck
::
index_t
Wi
;
ck
::
index_t
Ho
;
ck
::
index_t
Wo
;
ck
::
index_t
Sy
;
ck
::
index_t
Sx
;
ck
::
index_t
Dy
;
ck
::
index_t
Dx
;
ck
::
index_t
Py
;
ck
::
index_t
Px
;
ck
::
index_t
Fy
;
ck
::
index_t
Fx
;
intptr_t
input_offset_acc_wi
;
intptr_t
input_offset_ovf_wi_acc_hi
;
intptr_t
input_offset_ovf_hi_acc_n
;
// intptr_t input_offset_acc_c;
intptr_t
input_offset_ovf_c_acc_x
;
intptr_t
input_offset_ovf_x_acc_y
;
intptr_t
src_offset
;
// keep this as pointer type in case we have negative offset
};
template
<
typename
SrcData
,
typename
DstData
,
typename
SrcDesc
,
typename
DstDesc
,
typename
ElementwiseOperation
,
bool
BypassTransfer
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
>
struct
ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
{
static
constexpr
ck
::
index_t
nDim
=
SrcDesc
::
GetNumOfDimension
();
using
Index
=
MultiIndex
<
nDim
>
;
// using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
// using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
constexpr
ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin
,
const
DstDesc
&
dst_desc
,
const
Index
&
dst_slice_origin
,
const
ElementwiseOperation
&
element_op
)
:
element_op_
(
element_op
)
{
GemmN1
=
src_desc
.
GetTransforms
()[
Number
<
3
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
GemmN
=
src_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
0
>
{}];
GemmK
=
src_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
}
void
SetSrcSliceOrigin
(
const
SrcDesc
&
,
const
Index
&
src_slice_origin_idx
)
{
ck
::
index_t
idx_n0
=
src_slice_origin_idx
[
Number
<
0
>
{}];
ck
::
index_t
idx_k
=
src_slice_origin_idx
[
Number
<
1
>
{}];
ck
::
index_t
idx_n1
=
src_slice_origin_idx
[
Number
<
2
>
{}];
i_gemm_n
=
idx_n0
*
GemmN1
+
idx_n1
;
// i_gemm_k = idx_k;
src_offset
=
idx_n0
*
GemmK
*
GemmN1
+
idx_k
+
idx_n1
*
GemmN1
;
// Note we transpose here
// printf("xxxx i_gemm_n:%d, i_gemm_k:%d, src_offset:%d\n", i_gemm_n, i_gemm_k,
// src_offset);
}
void
SetDstSliceOrigin
(
const
DstDesc
&
,
const
Index
&
)
{}
template
<
typename
SrcBuffer
,
typename
DstBuffer
>
void
Run
(
const
SrcDesc
&
,
const
SrcBuffer
&
src_buf
,
const
DstDesc
&
dst_desc
,
DstBuffer
&
dst_buf
)
{
if
constexpr
(
BypassTransfer
)
{
// TODO: weight NHWC not support this
}
else
{
const
ck
::
index_t
n_per_block
=
dst_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
0
>
{}]
*
dst_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
2
>
{}];
const
ck
::
index_t
k_per_block
=
dst_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
// printf(" >>>> %d, %d, %d -> %d(%dx%d), %d\n", GemmN, GemmK, GemmN1, n_per_block,
// dst_desc.GetTransforms()[Number<0>{}]
// .GetUpperLengths()[Number<0>{}],
// dst_desc.GetTransforms()[Number<0>{}]
// .GetUpperLengths()[Number<2>{}],
// k_per_block);
const
float
*
p_src
=
reinterpret_cast
<
const
float
*>
(
src_buf
.
p_data_
)
+
src_offset
;
float
*
p_dst
=
reinterpret_cast
<
float
*>
(
dst_buf
.
p_data_
);
// n * k -> n0 * k * n1, n1 = 8, n0 = n/8
for
(
index_t
i_n_itr
=
0
;
i_n_itr
<
n_per_block
;
i_n_itr
+=
8
)
{
ck
::
index_t
current_n_8
=
ck
::
math
::
min
(
GemmN
-
(
i_n_itr
+
i_gemm_n
),
8
);
ck
::
index_t
i_k_itr
=
k_per_block
;
if
(
current_n_8
==
8
)
{
const
float
*
p_src_k
=
p_src
;
float
*
p_dst_k
=
p_dst
;
while
(
i_k_itr
>=
8
)
{
avx2_util
::
transpose8x8_avx2
(
p_dst_k
,
8
,
p_src_k
,
GemmK
);
p_dst_k
+=
8
*
8
;
p_src_k
+=
8
;
i_k_itr
-=
8
;
}
if
(
i_k_itr
&
4
)
{
p_dst_k
[
0
*
8
+
0
]
=
p_src_k
[
0
*
GemmK
+
0
];
p_dst_k
[
0
*
8
+
1
]
=
p_src_k
[
1
*
GemmK
+
0
];
p_dst_k
[
0
*
8
+
2
]
=
p_src_k
[
2
*
GemmK
+
0
];
p_dst_k
[
0
*
8
+
3
]
=
p_src_k
[
3
*
GemmK
+
0
];
p_dst_k
[
0
*
8
+
4
]
=
p_src_k
[
4
*
GemmK
+
0
];
p_dst_k
[
0
*
8
+
5
]
=
p_src_k
[
5
*
GemmK
+
0
];
p_dst_k
[
0
*
8
+
6
]
=
p_src_k
[
6
*
GemmK
+
0
];
p_dst_k
[
0
*
8
+
7
]
=
p_src_k
[
7
*
GemmK
+
0
];
p_dst_k
[
1
*
8
+
0
]
=
p_src_k
[
0
*
GemmK
+
1
];
p_dst_k
[
1
*
8
+
1
]
=
p_src_k
[
1
*
GemmK
+
1
];
p_dst_k
[
1
*
8
+
2
]
=
p_src_k
[
2
*
GemmK
+
1
];
p_dst_k
[
1
*
8
+
3
]
=
p_src_k
[
3
*
GemmK
+
1
];
p_dst_k
[
1
*
8
+
4
]
=
p_src_k
[
4
*
GemmK
+
1
];
p_dst_k
[
1
*
8
+
5
]
=
p_src_k
[
5
*
GemmK
+
1
];
p_dst_k
[
1
*
8
+
6
]
=
p_src_k
[
6
*
GemmK
+
1
];
p_dst_k
[
1
*
8
+
7
]
=
p_src_k
[
7
*
GemmK
+
1
];
p_dst_k
[
2
*
8
+
0
]
=
p_src_k
[
0
*
GemmK
+
2
];
p_dst_k
[
2
*
8
+
1
]
=
p_src_k
[
1
*
GemmK
+
2
];
p_dst_k
[
2
*
8
+
2
]
=
p_src_k
[
2
*
GemmK
+
2
];
p_dst_k
[
2
*
8
+
3
]
=
p_src_k
[
3
*
GemmK
+
2
];
p_dst_k
[
2
*
8
+
4
]
=
p_src_k
[
4
*
GemmK
+
2
];
p_dst_k
[
2
*
8
+
5
]
=
p_src_k
[
5
*
GemmK
+
2
];
p_dst_k
[
2
*
8
+
6
]
=
p_src_k
[
6
*
GemmK
+
2
];
p_dst_k
[
2
*
8
+
7
]
=
p_src_k
[
7
*
GemmK
+
2
];
p_dst_k
[
3
*
8
+
0
]
=
p_src_k
[
0
*
GemmK
+
3
];
p_dst_k
[
3
*
8
+
1
]
=
p_src_k
[
1
*
GemmK
+
3
];
p_dst_k
[
3
*
8
+
2
]
=
p_src_k
[
2
*
GemmK
+
3
];
p_dst_k
[
3
*
8
+
3
]
=
p_src_k
[
3
*
GemmK
+
3
];
p_dst_k
[
3
*
8
+
4
]
=
p_src_k
[
4
*
GemmK
+
3
];
p_dst_k
[
3
*
8
+
5
]
=
p_src_k
[
5
*
GemmK
+
3
];
p_dst_k
[
3
*
8
+
6
]
=
p_src_k
[
6
*
GemmK
+
3
];
p_dst_k
[
3
*
8
+
7
]
=
p_src_k
[
7
*
GemmK
+
3
];
p_dst_k
+=
4
*
8
;
p_src_k
+=
4
;
}
if
(
i_k_itr
&
2
)
{
p_dst_k
[
0
*
8
+
0
]
=
p_src_k
[
0
*
GemmK
+
0
];
p_dst_k
[
0
*
8
+
1
]
=
p_src_k
[
1
*
GemmK
+
0
];
p_dst_k
[
0
*
8
+
2
]
=
p_src_k
[
2
*
GemmK
+
0
];
p_dst_k
[
0
*
8
+
3
]
=
p_src_k
[
3
*
GemmK
+
0
];
p_dst_k
[
0
*
8
+
4
]
=
p_src_k
[
4
*
GemmK
+
0
];
p_dst_k
[
0
*
8
+
5
]
=
p_src_k
[
5
*
GemmK
+
0
];
p_dst_k
[
0
*
8
+
6
]
=
p_src_k
[
6
*
GemmK
+
0
];
p_dst_k
[
0
*
8
+
7
]
=
p_src_k
[
7
*
GemmK
+
0
];
p_dst_k
[
1
*
8
+
0
]
=
p_src_k
[
0
*
GemmK
+
1
];
p_dst_k
[
1
*
8
+
1
]
=
p_src_k
[
1
*
GemmK
+
1
];
p_dst_k
[
1
*
8
+
2
]
=
p_src_k
[
2
*
GemmK
+
1
];
p_dst_k
[
1
*
8
+
3
]
=
p_src_k
[
3
*
GemmK
+
1
];
p_dst_k
[
1
*
8
+
4
]
=
p_src_k
[
4
*
GemmK
+
1
];
p_dst_k
[
1
*
8
+
5
]
=
p_src_k
[
5
*
GemmK
+
1
];
p_dst_k
[
1
*
8
+
6
]
=
p_src_k
[
6
*
GemmK
+
1
];
p_dst_k
[
1
*
8
+
7
]
=
p_src_k
[
7
*
GemmK
+
1
];
p_dst_k
+=
2
*
8
;
p_src_k
+=
2
;
}
if
(
i_k_itr
&
1
)
{
p_dst_k
[
0
*
8
+
0
]
=
p_src_k
[
0
*
GemmK
+
0
];
p_dst_k
[
0
*
8
+
1
]
=
p_src_k
[
1
*
GemmK
+
0
];
p_dst_k
[
0
*
8
+
2
]
=
p_src_k
[
2
*
GemmK
+
0
];
p_dst_k
[
0
*
8
+
3
]
=
p_src_k
[
3
*
GemmK
+
0
];
p_dst_k
[
0
*
8
+
4
]
=
p_src_k
[
4
*
GemmK
+
0
];
p_dst_k
[
0
*
8
+
5
]
=
p_src_k
[
5
*
GemmK
+
0
];
p_dst_k
[
0
*
8
+
6
]
=
p_src_k
[
6
*
GemmK
+
0
];
p_dst_k
[
0
*
8
+
7
]
=
p_src_k
[
7
*
GemmK
+
0
];
}
}
else
{
const
float
*
p_src_k
=
p_src
;
float
*
p_dst_k
=
p_dst
;
for
(
index_t
i_sub_n
=
0
;
i_sub_n
<
8
;
i_sub_n
++
)
{
for
(
index_t
i_sub_k
=
0
;
i_sub_k
<
k_per_block
;
i_sub_k
++
)
{
ck
::
index_t
i_current_n_itr
=
i_n_itr
+
i_sub_n
+
i_gemm_n
;
float
v
=
i_current_n_itr
<
GemmN
?
p_src_k
[
i_sub_n
*
GemmK
+
i_sub_k
]
:
.0
f
;
p_dst_k
[
i_sub_k
*
8
+
i_sub_n
]
=
v
;
}
}
}
p_dst
+=
8
*
k_per_block
;
p_src
+=
8
*
GemmK
;
}
}
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
void
MoveSrcSliceWindow
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin_step_idx
)
{
ck
::
index_t
move_k
=
src_slice_origin_step_idx
[
Number
<
1
>
{}];
ck
::
index_t
move_n0
=
src_slice_origin_step_idx
[
Number
<
0
>
{}];
// i_gemm_k += move_k;
// printf("wei move:%d\n", move_k); fflush(stdout);
src_offset
+=
move_k
+
move_n0
*
GemmK
*
GemmN1
;
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
void
MoveDstSliceWindow
(
const
DstDesc
&
,
const
Index
&
)
{}
private:
const
ElementwiseOperation
element_op_
;
ck
::
index_t
i_gemm_n
;
// ck::index_t i_gemm_k;
// ck::index_t GemmN0;
ck
::
index_t
GemmN1
;
ck
::
index_t
GemmN
;
ck
::
index_t
GemmK
;
intptr_t
src_offset
;
};
template
<
typename
SrcData
,
typename
DstData
,
typename
SrcDesc
,
typename
DstDesc
,
typename
ElementwiseOperation
,
bool
BypassTransfer
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
>
struct
ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
{
static
constexpr
ck
::
index_t
nDim
=
SrcDesc
::
GetNumOfDimension
();
using
Index
=
MultiIndex
<
nDim
>
;
constexpr
ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
(
const
SrcDesc
&
src_desc
,
const
Index
&
,
const
DstDesc
&
dst_desc
,
const
Index
&
,
const
ElementwiseOperation
&
element_op
)
:
element_op_
(
element_op
)
{
DstGemmM
=
dst_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
0
>
{}];
DstGemmN
=
dst_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
src_offset
=
0
;
dst_offset
=
0
;
}
void
SetSrcSliceOrigin
(
const
SrcDesc
&
,
const
Index
&
src_slice_origin_idx
)
{
if
constexpr
(
BypassTransfer
)
{
auto
i_src_gemm_m
=
src_slice_origin_idx
[
Number
<
0
>
{}];
auto
i_src_gemm_n
=
src_slice_origin_idx
[
Number
<
1
>
{}];
src_offset
=
i_src_gemm_m
*
DstGemmN
+
i_src_gemm_n
;
}
}
void
SetDstSliceOrigin
(
const
DstDesc
&
,
const
Index
&
dst_slice_origin_idx
)
{
i_dst_gemm_m
=
dst_slice_origin_idx
[
Number
<
0
>
{}];
i_dst_gemm_n
=
dst_slice_origin_idx
[
Number
<
1
>
{}];
dst_offset
=
i_dst_gemm_m
*
DstGemmN
+
i_dst_gemm_n
;
}
template
<
typename
SrcBuffer
,
typename
DstBuffer
>
void
Run
(
const
SrcDesc
&
src_desc
,
SrcBuffer
&
src_buf
,
const
DstDesc
&
dst_desc
,
DstBuffer
&
dst_buf
)
{
if
constexpr
(
BypassTransfer
)
{
src_buf
.
p_data_
=
reinterpret_cast
<
float
*>
(
dst_buf
.
p_data_
)
+
src_offset
;
}
else
{
const
ck
::
index_t
m_per_block
=
src_desc
.
GetTransforms
()[
Number
<
0
>
{}]
.
GetUpperLengths
()[
Number
<
0
>
{}];
// must be multiple of 8
const
ck
::
index_t
n_per_block
=
src_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
const
ck
::
index_t
current_n
=
ck
::
math
::
min
(
DstGemmN
-
i_dst_gemm_n
,
n_per_block
);
const
float
*
p_src
=
reinterpret_cast
<
float
*>
(
src_buf
.
p_data_
)
+
src_offset
;
float
*
p_dst
=
reinterpret_cast
<
float
*>
(
dst_buf
.
p_data_
)
+
dst_offset
;
ck
::
index_t
i_m_itr
=
m_per_block
;
// printf("xxxx %d, current_n:%d, DstGemmN:%d, n_per_block:%d\n",__LINE__, current_n,
// DstGemmN, n_per_block);fflush(stdout);
// standard 8-4-2-1 pattern
while
(
i_m_itr
>=
8
)
{
avx2_util
::
memcpy32_avx2
(
p_dst
+
0
*
DstGemmN
,
p_src
+
0
*
n_per_block
,
current_n
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
1
*
DstGemmN
,
p_src
+
1
*
n_per_block
,
current_n
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
2
*
DstGemmN
,
p_src
+
2
*
n_per_block
,
current_n
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
3
*
DstGemmN
,
p_src
+
3
*
n_per_block
,
current_n
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
4
*
DstGemmN
,
p_src
+
4
*
n_per_block
,
current_n
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
5
*
DstGemmN
,
p_src
+
5
*
n_per_block
,
current_n
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
6
*
DstGemmN
,
p_src
+
6
*
n_per_block
,
current_n
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
7
*
DstGemmN
,
p_src
+
7
*
n_per_block
,
current_n
);
i_m_itr
-=
8
;
p_dst
+=
8
*
DstGemmN
;
p_src
+=
8
*
n_per_block
;
}
if
(
i_m_itr
&
4
)
{
avx2_util
::
memcpy32_avx2
(
p_dst
+
0
*
DstGemmN
,
p_src
+
0
*
n_per_block
,
current_n
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
1
*
DstGemmN
,
p_src
+
1
*
n_per_block
,
current_n
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
2
*
DstGemmN
,
p_src
+
2
*
n_per_block
,
current_n
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
3
*
DstGemmN
,
p_src
+
3
*
n_per_block
,
current_n
);
p_dst
+=
4
*
DstGemmN
;
p_src
+=
4
*
n_per_block
;
}
if
(
i_m_itr
&
2
)
{
avx2_util
::
memcpy32_avx2
(
p_dst
+
0
*
DstGemmN
,
p_src
+
0
*
n_per_block
,
current_n
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
1
*
DstGemmN
,
p_src
+
1
*
n_per_block
,
current_n
);
p_dst
+=
2
*
DstGemmN
;
p_src
+=
2
*
n_per_block
;
}
if
(
i_m_itr
&
1
)
{
avx2_util
::
memcpy32_avx2
(
p_dst
+
0
*
DstGemmN
,
p_src
+
0
*
n_per_block
,
current_n
);
}
// printf("xxxx %d\n",__LINE__);fflush(stdout);
}
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
void
MoveSrcSliceWindow
(
const
SrcDesc
&
,
const
Index
&
)
{}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
void
MoveDstSliceWindow
(
const
DstDesc
&
,
const
Index
&
)
{}
private:
const
ElementwiseOperation
element_op_
;
ck
::
index_t
i_dst_gemm_m
;
ck
::
index_t
i_dst_gemm_n
;
ck
::
index_t
DstGemmM
;
ck
::
index_t
DstGemmN
;
intptr_t
src_offset
;
intptr_t
dst_offset
;
};
}
// namespace cpu
}
// namespace ck
#endif
library/include/ck/library/host_tensor/device.hpp
View file @
afc7d431
...
...
@@ -121,6 +121,10 @@ template <typename... Args, typename F>
float
launch_and_time_cpu_kernel
(
F
kernel
,
int
nrepeat
,
Args
...
args
)
{
WallTimer
timer
;
int
nwarmup
=
3
;
for
(
int
i
=
0
;
i
<
nwarmup
;
i
++
)
kernel
(
args
...);
timer
.
Start
();
...
...
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
View file @
afc7d431
...
...
@@ -19,7 +19,7 @@ using InLayout = ck::tensor_layout::gemm::RowMajor; /
using
WeiLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
// KYXC
static
constexpr
bool
NonTemporalStore
=
false
;
using
P
assThrough
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
using
P
T
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
using
ThreadwiseGemmAvx2_MxN_4x24_Dispatch
=
ck
::
cpu
::
ThreadwiseGemmAvx2_MxN_4x24_Dispatch
<
InType
,
WeiType
,
...
...
@@ -37,53 +37,37 @@ static constexpr auto ConvFwd1x1P0 =
static
constexpr
auto
ConvFwd1x1S1P0
=
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
;
static
constexpr
auto
DefaultGemmKLoop
=
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardGemmKSpecialization_t
::
DefaultGemmKLoop
;
static
constexpr
auto
GemmKLoopOverC
=
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardGemmKSpecialization_t
::
NHWC_GemmKLoopOverC
;
static
constexpr
auto
LoopOver_MNK
=
ck
::
tensor_operation
::
cpu
::
device
::
LoopOver_MNK
;
static
constexpr
auto
LoopOver_MKN
=
ck
::
tensor_operation
::
cpu
::
device
::
LoopOver_MKN
;
// clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
GemmKLoopOverC
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
a_local_buf
,
b_local_buf
,
c_local_buf
>
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1P0
,
GemmKLoopOverC
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
a_local_buf
,
b_local_buf
,
c_local_buf
>
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
a_local_buf
,
b_local_buf
,
c_local_buf
>
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1P0
,
DefaultGemmKLoop
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
a_local_buf
,
b_local_buf
,
c_local_buf
>
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
GemmKLoopOverC
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
a_local_buf
,
b_local_buf
,
c_local_buf
>
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1P0
,
GemmKLoopOverC
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
a_local_buf
,
b_local_buf
,
c_local_buf
>
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
a_local_buf
,
b_local_buf
,
c_local_buf
>
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1P0
,
DefaultGemmKLoop
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
a_local_buf
,
b_local_buf
,
c_local_buf
>
// clang-format on
using
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances
=
std
::
tuple
<
//#################################################################|InDataType|WeiDataType|OutDataType|AccDataType|InElementwiseOp|WeiElementwiseOp|OutElementwiseOp|ConvForwardSp|NumDimSpatial|MPerBlock|NPerBlock|KPerBlock|ThreadwiseGemm_Dispatch
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
2
,
256
,
128
,
64
,
ThreadwiseGemmAvx2_MxN_4x24_Dispatch
>
,
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
2
,
512
,
256
,
128
,
ThreadwiseGemmAvx2_MxN_4x24_Dispatch
>
,
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
2
,
1024
,
144
,
128
,
ThreadwiseGemmAvx2_MxN_4x24_Dispatch
>>
;
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
120
,
64
,
4
,
24
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
512
,
144
,
128
,
4
,
24
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
false
),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 192, 128, 4, 24, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
768
,
288
,
128
,
4
,
24
,
true
,
true
,
false
)
>
;
// clang-format on
void
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
void
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
PT
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances
{});
...
...
test/convnd_fwd_cpu/conv2d_fwd_cpu.cpp
View file @
afc7d431
...
...
@@ -9,8 +9,11 @@
#include "reference_conv_fwd.hpp"
#include "element_wise_operation_cpu.hpp"
#include "dynamic_buffer_cpu.hpp"
#include <omp.h>
#define AVX2_DATA_ALIGNMENT 32
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -18,6 +21,8 @@ namespace cpu {
namespace
device
{
namespace
device_conv2d_fwd_avx2_instance
{
using
PassThrough
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
void
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
);
...
...
@@ -34,6 +39,7 @@ using OutElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
template
<
typename
T
>
static
bool
check_out
(
const
Tensor
<
T
>&
ref
,
const
Tensor
<
T
>&
result
)
{
int
error_count
=
0
;
float
max_diff
=
1e-6
;
for
(
int
i
=
0
;
i
<
ref
.
mData
.
size
();
++
i
)
...
...
@@ -41,28 +47,35 @@ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
float
diff
=
std
::
abs
(
double
(
ref
.
mData
[
i
])
-
double
(
result
.
mData
[
i
]));
if
(
max_diff
<
diff
)
{
return
false
;
error_count
++
;
printf
(
"idx:%3d, ref:%f, res:%f (diff:%f)
\n
"
,
i
,
double
(
ref
.
mData
[
i
]),
double
(
result
.
mData
[
i
]),
diff
);
}
}
return
true
;
return
error_count
==
0
;
}
float
calculate_gflops
()
{}
int
main
(
int
argc
,
char
*
argv
[])
{
int
data_type
=
0
;
int
init_method
=
0
;
// Conv shape
ck
::
index_t
N
=
128
;
ck
::
index_t
N
=
2
;
ck
::
index_t
K
=
256
;
ck
::
index_t
C
=
192
;
ck
::
index_t
Y
=
3
;
ck
::
index_t
X
=
3
;
ck
::
index_t
Hi
=
71
;
ck
::
index_t
Wi
=
71
;
ck
::
index_t
conv_stride_h
=
2
;
ck
::
index_t
conv_stride_w
=
2
;
ck
::
index_t
conv_stride_h
=
1
;
ck
::
index_t
conv_stride_w
=
1
;
ck
::
index_t
conv_dilation_h
=
1
;
ck
::
index_t
conv_dilation_w
=
1
;
ck
::
index_t
in_left_pad_h
=
1
;
...
...
@@ -72,7 +85,7 @@ int main(int argc, char* argv[])
if
(
argc
==
1
)
{
data_type
=
1
;
data_type
=
0
;
init_method
=
1
;
}
else
if
(
argc
==
3
)
...
...
@@ -110,17 +123,14 @@ int main(int argc, char* argv[])
exit
(
1
);
}
auto
Run
=
[
&
](
auto
input_type
,
auto
wei_type
,
auto
out_type
,
auto
acc_type
)
{
auto
Run
=
[
&
](
auto
input_type
,
auto
wei_type
,
auto
out_type
)
{
using
InDataType
=
decltype
(
input_type
);
using
WeiDataType
=
decltype
(
wei_type
);
using
OutDataType
=
decltype
(
out_type
);
using
AccDataType
=
decltype
(
acc_type
);
using
ReferenceConvBwdInstance
=
ck
::
tensor_operation
::
host
::
ReferenceConvBwdData
<
InDataType
,
using
ReferenceConvFwdInstance
=
ck
::
tensor_operation
::
host
::
ReferenceConvFwd
<
InDataType
,
WeiDataType
,
OutDataType
,
AccDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
>
;
...
...
@@ -147,53 +157,93 @@ int main(int argc, char* argv[])
std
::
vector
<
std
::
size_t
>
({
C_
*
H_
*
W_
,
1
,
W_
*
C_
,
C_
}));
};
Tensor
<
Out
DataType
>
out_n_ho_wo_k
(
f_host_tensor_descriptor
(
N
,
K
,
H
o
,
W
o
));
Tensor
<
WeiDataType
>
wei_k_y_x
_c
(
f_host_tensor_descriptor
(
K
,
C
,
Y
,
X
));
Tensor
<
In
DataType
>
in_n
_h
i
_w
i_c
_host_result
(
f_host_tensor_descriptor
(
N
,
C
,
H
i
,
W
i
));
Tensor
<
In
DataType
>
in_n
_h
i
_w
i_c
_device_result
(
f_host_tensor_descriptor
(
N
,
C
,
H
i
,
W
i
));
Tensor
<
In
DataType
>
in_n_c_hi_wi
(
f_host_tensor_descriptor
(
N
,
C
,
H
i
,
W
i
));
Tensor
<
WeiDataType
>
wei_k_
c_
y_x
(
f_host_tensor_descriptor
(
K
,
C
,
Y
,
X
));
Tensor
<
Out
DataType
>
out_n_k
_h
o
_w
o
_host_result
(
f_host_tensor_descriptor
(
N
,
K
,
H
o
,
W
o
));
Tensor
<
Out
DataType
>
out_n_k
_h
o
_w
o
_device_result
(
f_host_tensor_descriptor
(
N
,
K
,
H
o
,
W
o
));
std
::
cout
<<
"in (N, C, Hi, Wi): "
<<
in_n_hi_wi_c_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"wei(K, C, Y, X): "
<<
wei_k_y_x_c
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"out(N, K, Ho, Wo): "
<<
out_n_ho_wo_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"in (N, C, Hi, Wi): "
<<
in_n_c_hi_wi
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"wei(K, C, Y, X): "
<<
wei_k_c_y_x
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"out(N, K, Ho, Wo): "
<<
out_n_k_ho_wo_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"LPad(H, W):"
<<
in_left_pad_h
<<
","
<<
in_left_pad_w
<<
", RPad(H, W):"
<<
in_right_pad_h
<<
","
<<
in_right_pad_w
<<
", Stride(H, W):"
<<
conv_stride_h
<<
", "
<<
conv_stride_w
<<
", Dilation(H, W):"
<<
conv_dilation_h
<<
", "
<<
conv_dilation_w
<<
", Threads:"
<<
omp_get_max_threads
()
<<
std
::
endl
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
out_n_ho_wo_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
OutDataType
>
{
-
5
,
5
});
wei_k_y_x_c
.
GenerateTensorValue
(
GeneratorTensor_2
<
WeiDataType
>
{
-
5
,
5
});
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
5
,
5
});
// in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
wei_k_c_y_x
.
GenerateTensorValue
(
GeneratorTensor_2
<
WeiDataType
>
{
-
5
,
5
});
// wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
break
;
case
2
:
out_n_ho_wo_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
OutDataType
>
{
0.0
,
1.0
});
wei_k_y_x_c
.
GenerateTensorValue
(
GeneratorTensor_3
<
WeiDataType
>
{
-
0.5
,
0.5
});
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_1
<
InDataType
>
{});
wei_k_c_y_x
.
GenerateTensorValue
(
GeneratorTensor_1
<
WeiDataType
>
{});
break
;
case
3
:
#define PACK_32(v24, v16, v8, v0) \
(((v24 & 0xff) << 24) | ((v16 & 0xff) << 16) | ((v8 & 0xff) << 8) | ((v0 & 0xff) << 0))
for
(
auto
i_n
=
0
;
i_n
<
N
;
i_n
++
)
{
for
(
auto
i_c
=
0
;
i_c
<
C
;
i_c
++
)
{
for
(
auto
i_hi
=
0
;
i_hi
<
Hi
;
i_hi
++
)
{
for
(
auto
i_wi
=
0
;
i_wi
<
Wi
;
i_wi
++
)
{
uint32_t
v
=
PACK_32
(
i_n
,
i_c
,
i_hi
,
i_wi
);
in_n_c_hi_wi
(
i_n
,
i_c
,
i_hi
,
i_wi
)
=
*
reinterpret_cast
<
float
*>
(
&
v
);
}
}
}
}
for
(
auto
i_k
=
0
;
i_k
<
K
;
i_k
++
)
{
for
(
auto
i_c
=
0
;
i_c
<
C
;
i_c
++
)
{
for
(
auto
i_y
=
0
;
i_y
<
Y
;
i_y
++
)
{
for
(
auto
i_x
=
0
;
i_x
<
X
;
i_x
++
)
{
uint32_t
v
=
PACK_32
(
i_k
,
i_c
,
i_y
,
i_x
);
wei_k_c_y_x
(
i_k
,
i_c
,
i_y
,
i_x
)
=
*
reinterpret_cast
<
float
*>
(
&
v
);
}
}
}
}
break
;
default:
out_n
_h
o
_w
o_k
.
GenerateTensorValue
(
GeneratorTensor_
1
<
Out
DataType
>
{
1
});
wei_k_y_x
_c
.
GenerateTensorValue
(
GeneratorTensor_
1
<
WeiDataType
>
{
1
});
in_n_c
_h
i
_w
i
.
GenerateTensorValue
(
GeneratorTensor_
3
<
In
DataType
>
{
0
,
1
});
wei_k_
c_
y_x
.
GenerateTensorValue
(
GeneratorTensor_
3
<
WeiDataType
>
{
-
1
,
1
});
}
DeviceAlignedMemCPU
in_device_buf
(
sizeof
(
InDataType
)
*
in_n_hi_wi_c_device_result
.
mDesc
.
GetElementSpace
(),
DeviceAlignedMemCPU
in_device_buf
(
sizeof
(
InDataType
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
(),
AVX2_DATA_ALIGNMENT
);
DeviceAlignedMemCPU
wei_device_buf
(
sizeof
(
WeiDataType
)
*
wei_k_y_x_c
.
mDesc
.
GetElementSpace
(),
AVX2_DATA_ALIGNMENT
);
DeviceAlignedMemCPU
out_device_buf
(
sizeof
(
OutDataType
)
*
out_n_ho_wo_k
.
mDesc
.
GetElementSpace
(),
AVX2_DATA_ALIGNMENT
);
sizeof
(
WeiDataType
)
*
wei_k_c_y_x
.
mDesc
.
GetElementSpace
(),
AVX2_DATA_ALIGNMENT
);
DeviceAlignedMemCPU
out_device_buf
(
sizeof
(
OutDataType
)
*
out_n_k_ho_wo_host_result
.
mDesc
.
GetElementSpace
(),
AVX2_DATA_ALIGNMENT
);
out_device_buf
.
ToDevice
(
out_n_ho_wo_k
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei_k_y_x_c
.
mData
.
data
());
// reset input to zero
in_n_hi_wi_c_device_result
.
GenerateTensorValue
(
GeneratorTensor_1
<
InDataType
>
{
0
});
in_device_buf
.
ToDevice
(
in_n_hi_wi_c_device_result
.
mData
.
data
());
in_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
// get host result
{
auto
ref_conv
=
ReferenceConvFwdInstance
{};
auto
ref_invoker
=
ref_conv
.
MakeInvoker
();
auto
ref_argument
=
ref_conv
.
MakeArgument
(
in_n_hi_wi
_c_host_result
,
wei_k_y_x
_c
,
out_n_ho_wo_
k
,
auto
ref_argument
=
ref_conv
.
MakeArgument
(
in_n_
c_
hi_wi
,
wei_k_
c_
y_x
,
out_n_
k_
ho_wo_
host_result
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
...
...
@@ -205,8 +255,8 @@ int main(int argc, char* argv[])
}
using
PassThrough
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
using
DeviceConvFwdNoOpPtr
=
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>
;
using
DeviceConvFwdNoOpPtr
=
ck
::
tensor_operation
::
cpu
::
device
::
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>
;
// add device Conv instances
std
::
vector
<
DeviceConvFwdNoOpPtr
>
conv_ptrs
;
...
...
@@ -226,6 +276,9 @@ int main(int argc, char* argv[])
// profile device Conv instances
bool
success
=
true
;
double
fastest_kernel_time
=
std
::
numeric_limits
<
double
>::
max
();
std
::
string
fastest_kernel_name
=
""
;
double
fastest_kernel_gflops
=
0
;
for
(
auto
&
conv_ptr
:
conv_ptrs
)
{
auto
argument_ptr
=
conv_ptr
->
MakeArgumentPointer
(
...
...
@@ -249,18 +302,30 @@ int main(int argc, char* argv[])
if
(
conv_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
auto
invoker_ptr
=
conv_ptr
->
MakeInvokerPointer
();
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
1
);
double
time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
10
);
double
total_flop
=
static_cast
<
double
>
(
2
)
*
N
*
C
*
Ho
*
Wo
*
K
*
Y
*
X
;
in_device_buf
.
FromDevice
(
in_n_hi_wi_c_device_result
.
mData
.
data
());
double
gflops
=
(
total_flop
*
1e-6
)
/
time
;
if
(
!
check_out
(
in_n_hi_wi_c_host_result
,
in_n_hi_wi_c_device_result
))
out_device_buf
.
FromDevice
(
out_n_k_ho_wo_device_result
.
mData
.
data
());
if
(
!
check_out
(
out_n_k_ho_wo_host_result
,
out_n_k_ho_wo_device_result
))
{
std
::
cout
<<
"Fail Info: "
<<
conv_ptr
->
GetTypeString
()
<<
std
::
endl
;
success
=
false
;
}
else
{
std
::
cout
<<
"Pass Info: "
<<
conv_ptr
->
GetTypeString
()
<<
std
::
endl
;
std
::
cout
<<
"Pass Info: "
<<
conv_ptr
->
GetTypeString
()
<<
", Time:"
<<
time
<<
"ms, Gflops:"
<<
gflops
<<
std
::
endl
;
if
(
time
<
fastest_kernel_time
)
{
fastest_kernel_time
=
time
;
fastest_kernel_name
=
conv_ptr
->
GetTypeString
();
fastest_kernel_gflops
=
gflops
;
}
}
}
else
...
...
@@ -269,25 +334,27 @@ int main(int argc, char* argv[])
}
}
if
(
success
)
if
(
fastest_kernel_time
!=
std
::
numeric_limits
<
double
>::
max
())
{
std
::
cout
<<
"test conv2d fwd cpu : Pass"
<<
std
::
endl
;
return
0
;
}
else
{
std
::
cout
<<
"test conv2d fwd cpu: Fail "
<<
std
::
endl
;
return
-
1
;
std
::
cout
<<
" fastest:"
<<
fastest_kernel_name
<<
", time:"
<<
fastest_kernel_time
<<
"ms, Gflops:"
<<
fastest_kernel_gflops
<<
std
::
endl
;
}
return
0
;
// if(success)
// {
// std::cout << "test conv2d fwd cpu : Pass" << std::endl;
// return 0;
// }
// else
// {
// std::cout << "test conv2d fwd cpu: Fail " << std::endl;
// return -1;
// }
};
if
(
data_type
==
0
)
{
return
Run
(
F32
(),
F32
(),
F32
(),
F32
());
}
else
if
(
data_type
==
1
)
{
return
Run
(
F16
(),
F16
(),
F16
(),
F32
());
return
Run
(
F32
(),
F32
(),
F32
());
}
else
{
...
...
test/cpu_threadwise_transfer/cpu_threadwise_transfer.cpp
View file @
afc7d431
...
...
@@ -226,6 +226,8 @@ int main(int argc, char** argv)
static
constexpr
ck
::
index_t
nDim
=
ck
::
remove_reference_t
<
decltype
(
input_desc
)
>::
GetNumOfDimension
();
input_desc
.
Print
();
auto
threadwise_transfer
=
threadwise_transfer_t
{
input_desc
,
ck
::
make_zero_multi_index
<
nDim
>
(),
input_cblock_desc
,
...
...
test/cpu_ukernel/cpu_gemm_uk.cpp
View file @
afc7d431
...
...
@@ -321,6 +321,7 @@ void test_ukernel(ukenrel_t uk,
param
.
ldb
=
(
std
::
is_same
<
Row
,
BLayout
>::
value
?
n
:
k
*
8
)
*
sizeof
(
FloatB
);
param
.
ldc
=
n
*
sizeof
(
float
);
param
.
alpha
=
alpha
;
param
.
accmulate_c
=
0
;
memset
(
private_c
,
0
,
m
*
n
*
sizeof
(
float
));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment