gaoqiong / composable_kernel_onnxruntime / Commits / afc7d431
"vscode:/vscode.git/clone" did not exist on "d799084a9a7d5196ae708da72b46dd8e84604194"
Commit afc7d431, authored Apr 24, 2022 by carlushuang
Parent: 07af8343

    avx2 gemm now works for single thread
Showing 13 changed files with 2308 additions and 962 deletions (+2308 -962):
  include/ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp                                        +160   -85
  include/ck/tensor_operation/cpu/device/convolution_forward_specialization_cpu.hpp                     +13    -0
  include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp                      +174   -36
  include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp                                           +248  -424
  include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp                                       +108   -63
  include/ck/tensor_operation/cpu/thread/threadwise_gemm_param.hpp                                        +3    -3
  include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2.hpp                      +109    -0
  include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp      +1084    -0
  library/include/ck/library/host_tensor/device.hpp                                                       +5    -1
  library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp  +30  -46
  test/convnd_fwd_cpu/conv2d_fwd_cpu.cpp                                                                +363  -296
  test/cpu_threadwise_transfer/cpu_threadwise_transfer.cpp                                                 +2    -0
  test/cpu_ukernel/cpu_gemm_uk.cpp                                                                         +9    -8
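Taken together, the changed files implement a cache-blocked GEMM path for the CPU: the gridwise layer tiles GEMM M/N/K into MPerBlock/NPerBlock/KPerBlock blocks, the blockwise layer walks a block in micro-tiles, and the threadwise layer is the AVX2 micro-kernel. The sketch below is not code from this repository; it is a minimal stand-alone illustration (scalar micro-kernel, made-up block and tile sizes) of that blocking structure, including the accumulate-over-K-blocks behaviour that this commit threads through as an explicit flag.

    #include <algorithm>
    #include <cstddef>

    // Scalar stand-in for the AVX2 micro-kernel: computes an mr x nr tile of
    // row-major C, either overwriting it (accumulate = false, first K block)
    // or adding to it (accumulate = true, later K blocks).
    static void micro_kernel(const float* a, const float* b, float* c,
                             std::size_t mr, std::size_t nr, std::size_t kc,
                             std::size_t lda, std::size_t ldb, std::size_t ldc,
                             bool accumulate)
    {
        for(std::size_t i = 0; i < mr; ++i)
            for(std::size_t j = 0; j < nr; ++j)
            {
                float acc = accumulate ? c[i * ldc + j] : 0.0f;
                for(std::size_t k = 0; k < kc; ++k)
                    acc += a[i * lda + k] * b[k * ldb + j];
                c[i * ldc + j] = acc;
            }
    }

    // Single-threaded blocked GEMM: C[M x N] = A[M x K] * B[K x N], all row-major.
    void gemm_blocked(const float* a, const float* b, float* c,
                      std::size_t M, std::size_t N, std::size_t K)
    {
        constexpr std::size_t MPerBlock = 256, NPerBlock = 128, KPerBlock = 64; // illustrative
        constexpr std::size_t Mr = 6, Nr = 16; // micro-tile, like the 6x16 dispatch

        for(std::size_t i_mc = 0; i_mc < M; i_mc += MPerBlock)
            for(std::size_t i_nc = 0; i_nc < N; i_nc += NPerBlock)
                for(std::size_t i_kc = 0; i_kc < K; i_kc += KPerBlock)
                {
                    const std::size_t mc = std::min(MPerBlock, M - i_mc);
                    const std::size_t nc = std::min(NPerBlock, N - i_nc);
                    const std::size_t kc = std::min(KPerBlock, K - i_kc);
                    const bool accumulate = (i_kc != 0); // accumulate across K blocks

                    for(std::size_t i_m = 0; i_m < mc; i_m += Mr)
                        for(std::size_t i_n = 0; i_n < nc; i_n += Nr)
                            micro_kernel(a + (i_mc + i_m) * K + i_kc,
                                         b + i_kc * N + (i_nc + i_n),
                                         c + (i_mc + i_m) * N + (i_nc + i_n),
                                         std::min(Mr, mc - i_m), std::min(Nr, nc - i_n),
                                         kc, K, N, N, accumulate);
                }
    }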
include/ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp  (+160 -85)

@@ -13,21 +13,10 @@ namespace cpu {
 template <typename FloatA,
           typename FloatB,
           typename FloatC,
-          typename AccDataType,
           typename ABlockDesc,
           typename BBlockDesc,
-          typename CBlockDesc,
+          typename CDesc,
-          typename ABlockSliceLengths,
-          typename BBlockSliceLengths,
-          typename CBlockSliceLengths,
-          typename AThreadSliceLength,
-          typename BThreadSliceLength,
-          ck::index_t AThreadLoopOverDim, // thread slice loop over on block slice. 1d is enough for
-                                          // now
-          ck::index_t BThreadLoopOverDim,
           ck::index_t KPerBlock,
 ...
@@ -47,24 +36,14 @@ struct BlockwiseGemmAvx2_MxN
     static constexpr index_t nDimA = ABlockDesc::GetNumOfDimension();
     static constexpr index_t nDimB = BBlockDesc::GetNumOfDimension();
-    static constexpr index_t nDimC = CBlockDesc::GetNumOfDimension();
+    static constexpr index_t nDimC = CDesc::GetNumOfDimension();

     using IndexA = MultiIndex<nDimA>;
     using IndexB = MultiIndex<nDimB>;
     using IndexC = MultiIndex<nDimC>;

     using ACoord = decltype(make_tensor_coordinate(ABlockDesc{}, IndexA{}));
     using BCoord = decltype(make_tensor_coordinate(BBlockDesc{}, IndexB{}));
-    using CCoord = decltype(make_tensor_coordinate(CBlockDesc{}, IndexC{}));
+    using CCoord = decltype(make_tensor_coordinate(CDesc{}, IndexC{}));

-#if 0
-    constexpr BlockwiseGemmAvx2_MxN(const ABlockDesc & a_block_desc, const IndexA& a_thread_origin,
-                                    const BBlockDesc & b_block_desc, const IndexB& b_thread_origin)
-        : a_thread_coord_(make_tensor_coordinate(a_block_desc, a_thread_origin)),
-          b_thread_coord_(make_tensor_coordinate(b_block_desc, b_thread_origin)),
-    {
-    }
-#endif

     template <typename TensorDesc>
     constexpr auto GetLeadingElement(const TensorDesc& desc)
 ...
@@ -84,79 +63,175 @@ struct BlockwiseGemmAvx2_MxN
         }
     }

-    template <typename ABlockBuffer, typename BBlockBuffer, typename CBlockBuffer>
+    ck::index_t GetALeadingElement(const ABlockDesc& a_block_desc) const
+    {
+        return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+    }
+
+    ck::index_t GetBLeadingElement(const BBlockDesc& b_block_desc) const
+    {
+        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
+                                  ck::tensor_layout::gemm::RowMajor>::value)
+        {
+            // K * N
+            return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+        }
+        else
+        {
+            // N/8 * K * 8
+            return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}] *
+                   b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
+        }
+    }
+
+    ck::index_t GetCLeadingElement(const CDesc& c_desc) const
+    {
+        return c_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+    }
+
+    ck::index_t GetMPerBlock(const ABlockDesc& a_block_desc) const
+    {
+        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
+                                  ck::tensor_layout::gemm::RowMajor>::value)
+        {
+            // M * K
+            return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
+        }
+        else
+        {
+            // K * M
+            return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+        }
+    }
+
+    ck::index_t GetKPerBlock(const ABlockDesc& a_block_desc) const
+    {
+        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
+                                  ck::tensor_layout::gemm::RowMajor>::value)
+        {
+            // M * K
+            return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+        }
+        else
+        {
+            // K * M
+            return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
+        }
+    }
+
+    ck::index_t GetNPerBlock(const BBlockDesc& b_block_desc) const
+    {
+        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
+                                  ck::tensor_layout::gemm::RowMajor>::value)
+        {
+            // K * N
+            return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+        }
+        else
+        {
+            // N/8 * K * 8
+            return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}] *
+                   b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
+        }
+    }
+
+    ck::index_t GetABlockStartOffset(const ABlockDesc& a_block_desc,
+                                     const index_t i_m,
+                                     const index_t) const
+    {
+        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
+                                  ck::tensor_layout::gemm::RowMajor>::value)
+        {
+            return i_m * a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+        }
+        else
+        {
+            return i_m;
+        }
+    }
+
+    ck::index_t GetBBlockStartOffset(const BBlockDesc& b_block_desc,
+                                     const index_t,
+                                     const index_t i_n) const
+    {
+        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
+                                  ck::tensor_layout::gemm::RowMajor>::value)
+        {
+            // K * N
+            return i_n;
+        }
+        else
+        {
+            // N/8 * K * 8
+            return i_n * b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+        }
+    }
+
+    ck::index_t GetCBlockStartOffset(const CDesc& c_desc,
+                                     const index_t i_m,
+                                     const index_t i_n) const
+    {
+        return i_m * c_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}] + i_n;
+    }
+
+    template <typename ABlockBuffer, typename BBlockBuffer, typename CBuffer>
     void Run(const ABlockDesc& a_block_desc,
              const ABlockBuffer& a_block_buf,
-             const IndexA& a_origin,
+             const IndexA& /* a_origin */,
              const BBlockDesc& b_block_desc,
              const BBlockBuffer& b_block_buf,
-             const IndexB& b_origin,
-             const CBlockDesc& c_block_desc,
-             const CBlockBuffer& c_block_buf,
-             const IndexC& c_origin) const
+             const IndexB& /* b_origin */,
+             const CDesc& c_desc,
+             CBuffer& c_buf,
+             const IndexC& /* c_origin */,
+             bool is_accumulate_c = true) const
     {
-        constexpr auto m_n_block_length =
-            ck::Sequence<ABlockSliceLengths::At(AThreadLoopOverDim),
-                         BBlockSliceLengths::At(BThreadLoopOverDim)>{};
-        constexpr auto m_n_thread_length =
-            ck::Sequence<AThreadSliceLength::At(AThreadLoopOverDim),
-                         BThreadSliceLength::At(BThreadLoopOverDim)>{};
-        constexpr auto m_n_access_length = m_n_block_length / m_n_thread_length;
-        constexpr auto ordered_m_n_access_length =
-            container_reorder_given_new2old(m_n_access_length, ThreadMNAccessOrder{});
-
-        constexpr auto a_block_idx_zeros = typename uniform_sequence_gen<nDimA, 0>::type{};
-        constexpr auto b_block_idx_zeros = typename uniform_sequence_gen<nDimB, 0>::type{};
-
-        constexpr auto lda = GetLeadingElement(a_block_desc) * sizeof(FloatA);
-        constexpr auto ldb = GetLeadingElement(b_block_desc) * sizeof(FloatB);
-        constexpr auto ldc = GetLeadingElement(c_block_desc) * sizeof(FloatC);
-
-        ck::cpu::ThreadwiseGemmParam param;
-        param.Kr    = KPerBlock;
-        param.lda   = lda;
-        param.ldb   = ldb;
-        param.ldc   = ldc;
-        param.alpha = 1.0f; // TODO
-
-        static_ford<decltype(ordered_m_n_access_length)>{}([&](auto ordered_idx) {
-            constexpr auto origin_m_n_idx = ordered_idx.ReorderGivenOld2New(ThreadMNAccessOrder{});
-            constexpr auto current_m_idx =
-                origin_m_n_idx.At(0) * AThreadSliceLength::At(AThreadLoopOverDim);
-            constexpr auto current_n_idx =
-                origin_m_n_idx.At(1) * BThreadSliceLength::At(BThreadLoopOverDim);
-            constexpr auto current_mr =
-                ck::math::min(m_n_block_length.At(0) - current_m_idx, m_n_thread_length.At(0));
-            constexpr auto current_nr =
-                ck::math::min(m_n_block_length.At(1) - current_n_idx, m_n_thread_length.At(1));
-
-            constexpr auto a_block_idx = a_block_idx_zeros.Modify(AThreadLoopOverDim, current_m_idx);
-            constexpr auto a_block_coord =
-                make_tensor_coordinate(a_block_desc, to_multi_index(a_origin + a_block_idx));
-            constexpr auto b_block_idx = b_block_idx_zeros.Modify(BThreadLoopOverDim, current_n_idx);
-            constexpr auto b_block_coord =
-                make_tensor_coordinate(b_block_desc, to_multi_index(b_origin + b_block_idx));
-            constexpr auto c_block_coord =
-                make_tensor_coordinate(c_block_desc, to_multi_index(c_origin + origin_m_n_idx));
-
-            param.p_a = &a_block_buf.p_data_[a_block_coord.GetOffset()];
-            param.p_b = &b_block_buf.p_data_[b_block_coord.GetOffset()];
-            param.p_c = &c_block_buf.p_data_[c_block_coord.GetOffset()];
-
-            ThreadwiseGemm_Dispatch::Run(&param, current_mr, current_nr);
-        });
+        auto lda = GetALeadingElement(a_block_desc) * sizeof(FloatA);
+        auto ldb = GetBLeadingElement(b_block_desc) * sizeof(FloatB);
+        auto ldc = GetCLeadingElement(c_desc) * sizeof(FloatC);
+
+        // printf("lda:%d, ldb:%d, ldc:%d\n", lda, ldb, ldc);
+
+        const auto k_per_block  = GetKPerBlock(a_block_desc);
+        const auto m_per_block  = GetMPerBlock(a_block_desc);
+        const auto n_per_block  = GetNPerBlock(b_block_desc);
+        const auto m_per_thread = ThreadwiseGemm_Dispatch::ThreadMaxMr;
+        const auto n_per_thread = ThreadwiseGemm_Dispatch::ThreadMaxNr;
+
+        ck::cpu::ThreadwiseGemmParam param;
+        param.Kr          = k_per_block;
+        param.lda         = lda;
+        param.ldb         = ldb;
+        param.ldc         = ldc;
+        param.alpha       = 1.0f; // TODO
+        param.accmulate_c = is_accumulate_c ? 1 : 0;
+
+        if constexpr(std::is_same<ThreadMNAccessOrder, ck::Sequence<0, 1>>::value)
+        {
+            for(ck::index_t i_m = 0; i_m < m_per_block; i_m += m_per_thread)
+            {
+                auto current_mr = ck::math::min(m_per_block - i_m, m_per_thread);
+                param.p_a = &a_block_buf.p_data_[GetABlockStartOffset(a_block_desc, i_m, 0)];
+                // printf("YYYY: %d, i_m:%d, current_mr:%d, %d, %p\n",__LINE__, i_m, current_mr,
+                // GetABlockStartOffset(a_block_desc, i_m, 0), param.p_a);fflush(stdout);
+
+                for(ck::index_t i_n = 0; i_n < n_per_block; i_n += n_per_thread)
+                {
+                    auto current_nr = ck::math::min(n_per_block - i_n, n_per_thread);
+                    param.p_b = &b_block_buf.p_data_[GetBBlockStartOffset(b_block_desc, 0, i_n)];
+                    param.p_c = &c_buf.p_data_[GetCBlockStartOffset(c_desc, i_m, i_n)];
+                    // printf("YYYY: %d, i_n:%d, current_nr:%d, %d, %p, C:%d, %p\n",__LINE__, i_n,
+                    // current_nr, GetBBlockStartOffset(b_block_desc, 0, i_n), param.p_b,
+                    // GetCBlockStartOffset(c_desc, i_m, i_n),
+                    // param.p_c);fflush(stdout);
+
+                    ThreadwiseGemm_Dispatch::Run(&param, current_mr, current_nr);
+                }
+            }
+        }
     }
 };
 ...
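The rewritten Run() above packages everything the micro-kernel needs into a plain parameter block and walks the block in ThreadMaxMr x ThreadMaxNr steps. The sketch below mirrors that pattern with hypothetical names (MicroGemmParam is a stand-in for ck::cpu::ThreadwiseGemmParam, whose exact field list is not shown in this diff); note that the leading dimensions are byte strides, which is why Run() multiplies by sizeof(FloatA/B/C).

    #include <algorithm>
    #include <cstdint>

    // Hypothetical mirror of the micro-kernel parameter block filled by Run().
    struct MicroGemmParam
    {
        const void* p_a;            // top-left of the current Mr x K panel of A
        const void* p_b;            // start of the current packed K x Nr panel of B
        void* p_c;                  // top-left of the current Mr x Nr tile of C
        std::int32_t Kr;            // K extent of this block
        std::int32_t lda, ldb, ldc; // leading dimensions in BYTES, not elements
        float alpha;
        std::int32_t accmulate_c;   // 0: overwrite C, 1: accumulate into C
    };

    // Sketch of the M/N walk: row-major A block, N/8 x K x 8 packed B block,
    // row-major C block; Dispatch::Run stands for the dispatched micro-kernel.
    template <typename Dispatch>
    void run_blockwise(MicroGemmParam param, // Kr/lda/ldb/ldc/alpha/accmulate_c already set
                       const float* a_block, const float* b_block, float* c_block,
                       std::int32_t m_per_block, std::int32_t n_per_block,
                       std::int32_t m_per_thread, std::int32_t n_per_thread)
    {
        for(std::int32_t i_m = 0; i_m < m_per_block; i_m += m_per_thread)
        {
            const std::int32_t mr = std::min(m_per_block - i_m, m_per_thread);
            for(std::int32_t i_n = 0; i_n < n_per_block; i_n += n_per_thread)
            {
                const std::int32_t nr = std::min(n_per_block - i_n, n_per_thread);
                param.p_a = a_block + i_m * (param.lda / sizeof(float)); // row-major A offset
                param.p_b = b_block + i_n * param.Kr;                    // packed-B offset (i_n % 8 == 0)
                param.p_c = c_block + i_m * (param.ldc / sizeof(float)) + i_n;
                Dispatch::Run(&param, mr, nr); // micro-kernel honors accmulate_c
            }
        }
    }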
include/ck/tensor_operation/cpu/device/convolution_forward_specialization_cpu.hpp  (+13 -0)

@@ -14,6 +14,19 @@ enum ConvolutionForwardSpecialization_t
     OddC,
 };

+enum ConvolutionForwardGemmKSpecialization_t
+{
+    DefaultGemmKLoop,
+    NHWC_GemmKLoopOverC, // not merge c*y*x, and c % k_per_block == 0
+};
+
+enum ConvolutionForwardBlockLoopOverSpecialization_t
+{
+    DefaultBlockLoopOver,
+    LoopOver_MNK,
+    LoopOver_MKN,
+};
+
 } // namespace device
 } // namespace cpu
 } // namespace tensor_operation
 ...
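The two new enums let the device pick a GEMM-K strategy and a block traversal order at compile time. As a rough illustration, the helper below is hypothetical (only the enum comes from the header above); it expresses the same constraint the commit adds to IsSupportedArgument(): looping GEMM-K directly over C is only taken when C divides evenly into KPerBlock chunks.

    #include <cstdint>

    // Enum as added in the header above.
    enum ConvolutionForwardGemmKSpecialization_t
    {
        DefaultGemmKLoop,
        NHWC_GemmKLoopOverC, // not merge c*y*x, and c % k_per_block == 0
    };

    // Hypothetical gate: NHWC_GemmKLoopOverC keeps C as the GEMM-K loop dimension
    // (no C*Y*X merge), which is only valid when C is a multiple of KPerBlock.
    bool is_gemm_k_spec_supported(ConvolutionForwardGemmKSpecialization_t spec,
                                  std::int32_t conv_c,
                                  std::int32_t k_per_block)
    {
        return spec != NHWC_GemmKLoopOverC || (conv_c % k_per_block == 0);
    }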
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp  (+174 -36)

@@ -13,6 +13,8 @@
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "gridwise_gemm_avx2.hpp"
+#include "threadwise_gemm_avx2.hpp"
+#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"

 namespace ck {
 namespace tensor_operation {
 ...
@@ -23,20 +25,21 @@ namespace device {
 template <typename InDataType,
           typename WeiDataType,
           typename OutDataType,
-          typename AccDataType,
           typename InElementwiseOperation,
           typename WeiElementwiseOperation,
           typename OutElementwiseOperation,
           ConvolutionForwardSpecialization_t ConvForwardSpecialization,
+          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
+          ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
           ck::index_t NumDimSpatial,
           ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
           ck::index_t NPerBlock,
           ck::index_t KPerBlock,
-          typename ThreadwiseGemm_Dispatch>
-          // bool IsGemmMPadded,
-          // bool IsGemmNPadded,
-          // bool IsGemmKPadded>
+          ck::index_t MPerThread,
+          ck::index_t NPerThread,
+          bool UseALocalBuffer,
+          bool UseBLocalBuffer,
+          bool UseCLocalBuffer>
 struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
     : public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
 {
 ...
@@ -60,18 +63,89 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};

+    static constexpr bool NonTemporalStore = false;
+
+    static constexpr auto GetBlockMNKAccessOrder()
+    {
+        if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
+                     BlockLoopOverSpecialization == LoopOver_MNK)
+            return ck::Sequence<0, 1, 2>{};
+        else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
+            return ck::Sequence<0, 2, 1>{};
+    }
+
+    using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
+
+    static constexpr auto GetThreadwiseGemm_Dispatch()
+    {
+        if constexpr(MPerThread == 4 && NPerThread == 24)
+        {
+            return ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InDataType,
+                                                                 WeiDataType,
+                                                                 OutDataType,
+                                                                 ck::tensor_layout::gemm::RowMajor,
+                                                                 ck::tensor_layout::gemm::ColumnMajor,
+                                                                 NonTemporalStore>{};
+        }
+        else if constexpr(MPerThread == 6 && NPerThread == 16)
+        {
+            return ck::cpu::ThreadwiseGemmAvx2_MxN_6x16_Dispatch<InDataType,
+                                                                 WeiDataType,
+                                                                 OutDataType,
+                                                                 ck::tensor_layout::gemm::RowMajor,
+                                                                 ck::tensor_layout::gemm::ColumnMajor,
+                                                                 NonTemporalStore>{};
+        }
+        else
+        {
+            // static_assert(false, "invalid Mr/Nr");
+        }
+    }
+
+    using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
+
+    static constexpr auto GetInputBlockDescriptor()
+    {
+        return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
+    }
+
+    static constexpr auto GetWeightBlockDescriptor()
+    {
+        return make_naive_tensor_descriptor_packed(
+            make_tuple(math::integer_divide_ceil(NPerBlock,
+                                                 ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
+                       KPerBlock,
+                       ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
+    }
+
+    static constexpr auto GetOutputBlockDescriptor()
+    {
+        return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
+    }
+
     static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
     {
+        ck::index_t gemm_n_padded =
+            math::integer_least_multiple(gemm_n, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
+
         const auto wei_gemm_n_k_grid_desc =
             make_naive_tensor_descriptor_packed(make_tuple(gemm_n, gemm_k));

-        const auto wei_gemm_n0_k_n1_grid_desc = transform_tensor_descriptor(
+        const auto wei_gemm_padn_k_grid_desc = transform_tensor_descriptor(
             wei_gemm_n_k_grid_desc,
-            ck::make_tuple(ck::make_unmerge_transform(
-                               make_tuple(wei_gemm_n_k_grid_desc.GetLength(I0) /
-                                              ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
-                                          ThreadwiseGemm_Dispatch::MatrixBMinVectorSize)),
-                           ck::make_pass_through_transform(wei_gemm_n_k_grid_desc.GetLength(I1))),
-            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
-            ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
+            ck::make_tuple(make_right_pad_transform(gemm_n, gemm_n_padded - gemm_n),
+                           make_pass_through_transform(gemm_k)),
+            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
+            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
+
+        const auto wei_gemm_n0_k_n1_grid_desc = transform_tensor_descriptor(
+            wei_gemm_padn_k_grid_desc,
+            ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(
+                               wei_gemm_padn_k_grid_desc.GetLength(I0) /
+                                   ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
+                               ThreadwiseGemm_Dispatch::MatrixBMinVectorSize)),
+                           ck::make_pass_through_transform(wei_gemm_padn_k_grid_desc.GetLength(I1))),
+            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
+            ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
 ...
@@ -409,6 +483,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
                            std::multiplies<ck::index_t>());
     }

+    static index_t GetGemmN(ck::index_t K)
+    {
+        // return ck::math::integer_least_multiple(K,
+        // ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
+        return K;
+    }
+
     static auto MakeABCGridDescriptor(ck::index_t N,
                                       ck::index_t K,
                                       ck::index_t C,
 ...
@@ -423,7 +504,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
         using namespace ck;

         const index_t GemmM = GetGemmM(N, output_spatial_lengths);
-        const index_t GemmN = K;
+        const index_t GemmN = GetGemmN(K);
         const index_t GemmK = GetGemmK(C, filter_spatial_lengths);

         // A:
 ...
@@ -474,13 +555,44 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
     using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
     using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;

-    static constexpr bool UseCLocalBuffer = true;
-    // static constexpr bool UseCLocalBuffer = false;
+    using AThreadwiseCopy =
+        ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC<
+            InDataType,
+            InDataType,
+            AGridDesc,
+            decltype(GetInputBlockDescriptor()),
+            InElementwiseOperation,
+            false,
+            ConvForwardSpecialization,
+            GemmKSpecialization>;
+
+    using BThreadwiseCopy =
+        ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC<
+            WeiDataType,
+            WeiDataType,
+            BGridDesc,
+            decltype(GetWeightBlockDescriptor()),
+            WeiElementwiseOperation,
+            false,
+            ConvForwardSpecialization,
+            GemmKSpecialization>;
+
+    using CThreadwiseCopy =
+        ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN<
+            OutDataType,
+            OutDataType,
+            CGridDesc,
+            decltype(GetOutputBlockDescriptor()),
+            OutElementwiseOperation,
+            !UseCLocalBuffer,
+            ConvForwardSpecialization,
+            GemmKSpecialization>;

     using GridwiseGemm =
         ck::cpu::GridwiseGemmAvx2_MxN<InDataType,  // InDataType,
                                       WeiDataType, // WeiDataType,
                                       OutDataType, // OutDataType,
-                                      AccDataType, // AccDataType,
                                       AGridDesc,   // AGridDesc,
                                       BGridDesc,   // BGridDesc,
                                       CGridDesc,   // CGridDesc,
 ...
@@ -491,8 +603,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
                                       NPerBlock,               // NPerBlock,
                                       KPerBlock,               // KPerBlock,
                                       ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
-                                      ck::Sequence<0, 1, 2>,   // BlockMNKAccessOrder,
+                                      AThreadwiseCopy,         // AThreadwiseCopy
+                                      BThreadwiseCopy,         // BThreadwiseCopy
+                                      CThreadwiseCopy,         // CThreadwiseCopy
+                                      BlockMNKAccessOrder,     // BlockMNKAccessOrder,
                                       ck::Sequence<0, 1>,      // ThreadMNAccessOrder
+                                      UseALocalBuffer,         // UseALocalBuffer
+                                      UseBLocalBuffer,         // UseBLocalBuffer
                                       UseCLocalBuffer          // UseCLocalBuffer
                                       >;
 ...
@@ -580,6 +697,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
             throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
         }

+        memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
+
         const auto kernel = ck::cpu::kernel_gemm_avx_mxn<GridwiseGemm,
                                                          InDataType,
                                                          WeiDataType,
 ...
@@ -591,21 +710,24 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
                                                          BElementwiseOperation,
                                                          CElementwiseOperation>;

-        float ave_time = launch_and_time_cpu_kernel(kernel,
-                                                    nrepeat,
-                                                    arg.p_a_grid_,
-                                                    arg.p_b_grid_,
-                                                    arg.p_c_grid_,
-                                                    arg.a_grid_desc_,
-                                                    arg.b_grid_desc_,
-                                                    arg.c_grid_desc_,
-                                                    arg.a_element_op_,
-                                                    arg.b_element_op_,
-                                                    arg.c_element_op_);
+        float ave_time = 0;
+
+        if(nrepeat != 1)
+            ave_time = launch_and_time_cpu_kernel(kernel,
+                                                  nrepeat,
+                                                  arg.p_a_grid_,
+                                                  arg.p_b_grid_,
+                                                  arg.p_c_grid_,
+                                                  arg.a_grid_desc_,
+                                                  arg.b_grid_desc_,
+                                                  arg.c_grid_desc_,
+                                                  arg.a_element_op_,
+                                                  arg.b_element_op_,
+                                                  arg.c_element_op_);

         // TODO: this is for benchmark purpose, so last time we clear c buffer and calculate the
         // result
-        memset(arg.p_c_grid_, 0, arg.a_grid_desc_.GetElementSpaceSize());
+        memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
         launch_cpu_kernel(kernel,
                           arg.p_a_grid_,
 ...
@@ -659,6 +781,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
             }
         }

+        if constexpr(GemmKSpecialization ==
+                     ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
+        {
+            if(!(arg.Conv_C_ % KPerBlock == 0))
+                return false;
+        }
+
         // Gridwise GEMM size
         return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
     }
 ...
@@ -748,16 +877,25 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
     std::string GetTypeString() const override
     {
         auto str = std::stringstream();

+        auto string_local_buffer = [](bool is_local_buffer) {
+            if(is_local_buffer)
+                return "L";
+            else
+                return "G";
+        };
+
         // clang-format off
         str << "DeviceConv" << std::to_string(NumDimSpatial)
-            << "DFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
-            << "<"
-            << MPerBlock << ", "
-            << NPerBlock << ", "
-            << KPerBlock
-            << ">";
+            << "DFwdAvx2_NHWC_KYXC"
+            << "_FS" << static_cast<int>(ConvForwardSpecialization)
+            << "_KS" << static_cast<int>(GemmKSpecialization)
+            << "_BS" << static_cast<int>(BlockLoopOverSpecialization)
+            << "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock
+            << "_TT" << MPerThread << "x" << NPerThread
+            << "_A" << string_local_buffer(UseALocalBuffer)
+            << "_B" << string_local_buffer(UseBLocalBuffer)
+            << "_C" << string_local_buffer(UseCLocalBuffer);
         // clang-format on

         return str.str();
 ...
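Several of the additions above revolve around keeping the GEMM-N dimension of the weight tensor a multiple of the AVX2 vector width: GetWeightTensorDescriptor() right-pads N and then unmerges it into an N0 x K x N1 view (N1 = MatrixBMinVectorSize). The snippet below is an illustrative, descriptor-free repacking that performs the same index mapping on a plain row-major K x N matrix; it assumes N1 = 8 and zero padding, which matches the AVX2 float path but is not taken verbatim from the repository.

    #include <cstddef>
    #include <vector>

    // Repack a K x N row-major B matrix into an N0 x K x N1 layout (N1 = 8),
    // padding N up to a multiple of 8 with zeros.
    std::vector<float> pack_b_n0_k_n1(const float* b, std::size_t K, std::size_t N)
    {
        constexpr std::size_t N1 = 8;
        const std::size_t n_padded = (N + N1 - 1) / N1 * N1; // integer_least_multiple(N, 8)
        const std::size_t N0       = n_padded / N1;

        std::vector<float> packed(N0 * K * N1, 0.0f); // zero padding for the tail columns
        for(std::size_t n = 0; n < N; ++n)
            for(std::size_t k = 0; k < K; ++k)
            {
                const std::size_t n0 = n / N1, n1 = n % N1;
                packed[(n0 * K + k) * N1 + n1] = b[k * N + n];
            }
        return packed;
    }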
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp
View file @
afc7d431
...
@@ -7,7 +7,9 @@
...
@@ -7,7 +7,9 @@
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_avx2.hpp"
#include "blockwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
#include "dynamic_buffer_cpu.hpp"
#include "dynamic_buffer_cpu.hpp"
#include <unistd.h>
namespace
ck
{
namespace
ck
{
namespace
cpu
{
namespace
cpu
{
...
@@ -46,7 +48,6 @@ void kernel_gemm_avx_mxn(const FloatA* __restrict__ p_a_grid,
...
@@ -46,7 +48,6 @@ void kernel_gemm_avx_mxn(const FloatA* __restrict__ p_a_grid,
template
<
typename
FloatA
,
template
<
typename
FloatA
,
typename
FloatB
,
typename
FloatB
,
typename
FloatC
,
typename
FloatC
,
typename
AccDataType
,
typename
AGridDesc
,
typename
AGridDesc
,
typename
BGridDesc
,
typename
BGridDesc
,
typename
CGridDesc
,
typename
CGridDesc
,
...
@@ -57,334 +58,92 @@ template <typename FloatA,
...
@@ -57,334 +58,92 @@ template <typename FloatA,
ck
::
index_t
NPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
KPerBlock
,
typename
ThreadwiseGemm_Dispatch
,
typename
ThreadwiseGemm_Dispatch
,
typename
AThreadwiseCopy
,
typename
BThreadwiseCopy
,
typename
CThreadwiseCopy
,
typename
BlockMNKAccessOrder
,
// how we accss gemm MNK to better fit in cache
typename
BlockMNKAccessOrder
,
// how we accss gemm MNK to better fit in cache
typename
ThreadMNAccessOrder
,
// how we acces gemm MN to utilize micro kernel
typename
ThreadMNAccessOrder
,
// how we acces gemm MN to utilize micro kernel
bool
UseALocalBuffer
,
bool
UseBLocalBuffer
,
bool
UseCLocalBuffer
// if true, will allocate a buffer and write to it in kernel, then
bool
UseCLocalBuffer
// if true, will allocate a buffer and write to it in kernel, then
// copy back to block buffer. if false, will write to C directly
// copy back to block buffer (need CThreadwiseCopy).
// if false, will write to C directly (no need CThreadwiseCopy)
>
>
struct
GridwiseGemmAvx2_MxN
struct
GridwiseGemmAvx2_MxN
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
// static constexpr auto Avx2RegisterVector = 8; // 8 floats
// static constexpr auto Avx2RegisterVector = 8; // 8 floats
static
constexpr
index_t
MemAlignmentByte
=
32
;
// 256bit
static
constexpr
index_t
MemAlignmentByte
=
32
;
// 256bit
static
constexpr
auto
GetABlockDescriptor
()
static
auto
GetABlockDescriptor
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
k_per_blk
)
{
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
{
// A : M, K
// A : M, K
constexpr
auto
a_block_desc_m_k
=
auto
a_block_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
KPerBloc
k
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
m_per_blk
,
k_per_bl
k
));
return
a_block_desc_m_k
;
return
a_block_desc_m_k
;
}
}
else
else
{
{
// A : K, M
// A : K, M
constexpr
auto
a_block_desc_k_m
=
make_naive_tensor_descriptor_packed
(
auto
a_block_desc_k_m
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
KPerBloc
k
,
make_tuple
(
k_per_bl
k
,
math
::
integer_least_multiple
(
math
::
integer_least_multiple
(
MPerBloc
k
,
ThreadwiseGemm_Dispatch
::
MatrixAMinVectorSize
)));
m_per_bl
k
,
ThreadwiseGemm_Dispatch
::
MatrixAMinVectorSize
)));
return
a_block_desc_k_m
;
return
a_block_desc_k_m
;
}
}
}
}
static
constexpr
auto
GetBBlockDescriptor
()
static
auto
GetBBlockDescriptor
(
const
ck
::
index_t
k_per_blk
,
const
ck
::
index_t
n_per_blk
)
{
{
// n_per_blk should be 8x
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
{
// B : K, N
// B : K, N
constexpr
auto
b_block_desc_k_n
=
make_naive_tensor_descriptor_packed
(
auto
b_block_desc_k_n
=
make_tuple
(
KPerBlock
,
make_naive_tensor_descriptor_packed
(
make_tuple
(
k_per_blk
,
n_per_blk
));
math
::
integer_least_multiple
(
NPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
)));
return
b_block_desc_k_n
;
return
b_block_desc_k_n
;
}
}
else
else
{
{
// B : N/8, K, N8
// B : N/8, K, N8
constexpr
auto
b_block_desc_n0_k_n1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
auto
b_block_desc_n0_k_n1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
math
::
integer_divide_ceil
(
NPerBloc
k
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
math
::
integer_divide_ceil
(
n_per_bl
k
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
KPerBloc
k
,
k_per_bl
k
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
));
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
));
return
b_block_desc_n0_k_n1
;
return
b_block_desc_n0_k_n1
;
}
}
}
}
static
constexpr
auto
GetABlockSliceLength
()
static
auto
GetCBlockDescriptor
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
n_per_blk
)
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// A : M, K
return
ck
::
Sequence
<
MPerBlock
,
KPerBlock
>
{};
}
else
{
// A : K, M
return
ck
::
Sequence
<
KPerBlock
,
MPerBlock
>
{};
}
}
static
constexpr
auto
GetBBlockSliceLength
()
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// B : K, N
return
ck
::
Sequence
<
KPerBlock
,
NPerBlock
>
{};
}
else
{
// B : N/8, K, N88;
return
ck
::
Sequence
<
NPerBlock
/
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
,
KPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
>
{};
}
}
static
constexpr
auto
GetABlockDimAccessOrder
()
{
return
ck
::
Sequence
<
0
,
1
>
{};
}
static
constexpr
auto
GetBBlockDimAccessOrder
()
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// B : K, N
return
ck
::
Sequence
<
0
,
1
>
{};
}
else
{
// B : N/8, K, N88;
return
ck
::
Sequence
<
0
,
1
,
2
>
{};
}
}
static
constexpr
auto
GetABlockMoveFwdStep
()
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// A : M, K
return
ck
::
make_multi_index
(
0
,
KPerBlock
);
}
else
{
// A : K, M
return
ck
::
make_multi_index
(
KPerBlock
,
0
);
}
}
static
constexpr
auto
GetBBlockMoveFwdStep
()
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// B : K, N
return
ck
::
make_multi_index
(
KPerBlock
,
0
);
}
else
{
// B : N/8, K, N88;
return
ck
::
make_multi_index
(
0
,
KPerBlock
,
0
);
}
}
#if 0
static constexpr auto GetAThreadDiscriptor()
{
if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout, ck::tensor_layout::gemm::RowMajor>::value){
// A : M, K
constexpr auto a_thread_desc_m_k = make_naive_tensor_descriptor_packed(make_tuple(ThreadwiseGemm_Dispatch::ThreadMaxMr, KPerBlock));
return a_thread_desc_m_k;
} else {
// A : K, M
constexpr auto a_thread_desc_k_m = make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, ThreadwiseGemm_Dispatch::ThreadMaxMr));
return a_thread_desc_k_m;
}
}
static constexpr auto GetBThreadDescriptor()
{
if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout, ck::tensor_layout::gemm::RowMajor>::value){
// B : K, N
constexpr auto b_thread_desc_k_n = make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, ThreadwiseGemm_Dispatch::ThreadMaxNr));
return b_thread_desc_k_n;
} else {
// B : N/8, K, N8
constexpr auto b_thread_desc_n_k_n8 = make_naive_tensor_descriptor_packed(make_tuple(math::integer_divide_ceil(ThreadwiseGemm_Dispatch::ThreadMaxNr, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize), KPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
return b_thread_desc_n_k_n8;
}
}
#endif
static
constexpr
auto
GetAThreadSliceLength
()
{
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
m_per_blk
,
n_per_blk
));
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// A : M, K
return
ck
::
Sequence
<
ThreadwiseGemm_Dispatch
::
ThreadMaxMr
,
KPerBlock
>
{};
}
else
{
// A : K, M
return
ck
::
Sequence
<
KPerBlock
,
ThreadwiseGemm_Dispatch
::
ThreadMaxMr
>
{};
}
}
static
constexpr
auto
GetBThreadSliceLength
()
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// B : K, N
return
ck
::
Sequence
<
KPerBlock
,
ThreadwiseGemm_Dispatch
::
ThreadMaxNr
>
{};
}
else
{
// B : N/8, K, N88;
return
ck
::
Sequence
<
ThreadwiseGemm_Dispatch
::
ThreadMaxNr
/
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
,
KPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
>
{};
}
}
}
static
constexpr
auto
GetAThreadMoveFwdStep
()
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// A : M, K
return
ck
::
make_multi_index
(
ThreadwiseGemm_Dispatch
::
ThreadMaxMr
,
0
);
}
else
{
// A : K, M
return
ck
::
make_multi_index
(
0
,
ThreadwiseGemm_Dispatch
::
ThreadMaxMr
);
}
}
static
constexpr
auto
GetBThreadMoveFwdStep
()
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// B : K, N
return
ck
::
make_multi_index
(
0
,
ThreadwiseGemm_Dispatch
::
ThreadMaxNr
);
}
else
{
// B : N/8, K, N88;
return
ck
::
Sequence
<
ThreadwiseGemm_Dispatch
::
ThreadMaxNr
/
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
,
0
,
0
>
{};
}
}
static
constexpr
ck
::
index_t
GetAThreadLoopOverDim
()
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// A : M, K
return
0
;
}
else
{
// A : K, M
return
1
;
}
}
static
constexpr
ck
::
index_t
GetBThreadLoopOverDim
()
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// B : K, N
return
1
;
}
else
{
// B : N/8, K, N88;
return
0
;
}
}
static
constexpr
auto
GetCBlockDescriptor
()
{
if
constexpr
(
UseCLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
NPerBlock
));
}
else
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
NPerBlock
));
// TODO:
}
}
static
constexpr
auto
GetCBlockSliceLength
()
{
return
ck
::
Sequence
<
MPerBlock
,
NPerBlock
>
{};
}
static
constexpr
bool
CheckValidity
(
const
AGridDesc
&
a_grid_desc
,
static
constexpr
bool
CheckValidity
(
const
AGridDesc
&
a_grid_desc
,
const
BGridDesc
&
b_grid_desc
,
const
BGridDesc
&
b_grid_desc
,
const
CGridDesc
&
c_grid_desc
)
const
CGridDesc
&
c_grid_desc
)
{
{
#if 0
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
bool
is_valid
=
true
;
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const
auto
GemmN
=
c_grid_desc
.
GetLength
(
I1
);
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
if
constexpr
(
UseCLocalBuffer
)
if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
return false;
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
return false;
// check NumPrefetch
if constexpr(NumPrefetch == 1)
{
{
// 1-stage prefetch always supported
if
(
std
::
is_same
<
BlockMNKAccessOrder
,
ck
::
Sequence
<
0
,
2
,
1
>>::
value
&&
NPerBlock
<
GemmN
)
}
is_valid
&=
false
;
else if constexpr(NumPrefetch == 2)
{
// 2-stage prefetch currently only support even number of K0 loop
// TODO: add support for odd number of K0 loop
if(!((K0 / K0PerBlock) % 2 == 0))
{
return false;
}
}
}
else
else
{
{
return false;
// TODO: need check c grid is simple transform?
if
(
GemmN
%
8
!=
0
)
is_valid
&=
false
;
}
}
return
is_valid
;
// check M01, N01
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
if(!(M0 % M01 == 0 && N0 % N01 == 0))
return false;
#endif
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
}
}
static
void
Run
(
const
FloatA
*
__restrict__
p_a_grid
,
static
void
Run
(
const
FloatA
*
__restrict__
p_a_grid
,
...
@@ -397,178 +156,149 @@ struct GridwiseGemmAvx2_MxN
...
@@ -397,178 +156,149 @@ struct GridwiseGemmAvx2_MxN
const
BElementwiseOperation
&
b_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
)
const
CElementwiseOperation
&
c_element_op
)
{
{
ck
::
index_t
m_per_block
;
ck
::
index_t
m_per_block
=
MPerBlock
;
ck
::
index_t
n_per_block
;
ck
::
index_t
n_per_block
=
NPerBlock
;
ck
::
index_t
k_per_block
;
ck
::
index_t
k_per_block
=
KPerBlock
;
if
constexpr
(
MPerBlock
==
0
&&
NPerBlock
==
0
&&
KPerBlock
==
0
)
{}
const
auto
GemmM
=
c_grid_desc
.
GetLength
(
I0
);
else
const
auto
GemmN
=
c_grid_desc
.
GetLength
(
I1
);
{
const
auto
GemmK
=
a_grid_desc
.
GetLength
(
I1
);
m_per_block
=
MPerBlock
;
n_per_block
=
NPerBlock
;
constexpr
auto
a_block_copy_dim
=
AGridDesc
::
GetNumOfDimension
();
k_per_block
=
KPerBlock
;
}
constexpr
auto
b_block_copy_dim
=
BGridDesc
::
GetNumOfDimension
();
const
auto
M
=
a_grid_desc
.
GetLength
(
I0
);
auto
a_threadwise_copy
=
AThreadwiseCopy
(
a_grid_desc
,
const
auto
N
=
b_grid_desc
.
GetLength
(
I1
);
ck
::
make_zero_multi_index
<
a_block_copy_dim
>
(),
const
auto
K
=
b_grid_desc
.
GetLength
(
I0
);
GetABlockDescriptor
(
m_per_block
,
k_per_block
),
ck
::
make_zero_multi_index
<
a_block_copy_dim
>
(),
const
ck
::
index_t
grid_m
=
math
::
integer_divide_ceil
(
M
,
m_per_block
);
AElementwiseOperation
{});
const
ck
::
index_t
grid_n
=
math
::
integer_divide_ceil
(
N
,
n_per_block
);
auto
b_threadwise_copy
=
BThreadwiseCopy
(
b_grid_desc
,
const
ck
::
index_t
grid_size
=
grid_m
*
grid_n
;
ck
::
make_zero_multi_index
<
b_block_copy_dim
>
(),
GetBBlockDescriptor
(
k_per_block
,
n_per_block
),
constexpr
auto
a_block_desc
=
GetABlockDescriptor
();
ck
::
make_zero_multi_index
<
b_block_copy_dim
>
(),
constexpr
auto
a_block_slice_length
=
GetABlockSliceLength
();
BElementwiseOperation
{});
constexpr
auto
a_block_copy_dim
=
decltype
(
a_block_slice_length
)
::
Size
();
constexpr
auto
a_dim_access_order
=
GetABlockDimAccessOrder
();
auto
c_threadwise_copy
=
CThreadwiseCopy
(
GetCBlockDescriptor
(
m_per_block
,
n_per_block
),
constexpr
auto
a_block_move_step
=
GetABlockMoveFwdStep
();
ck
::
make_zero_multi_index
<
2
>
(),
constexpr
auto
a_thread_slice_length
=
GetAThreadSliceLength
();
c_grid_desc
,
constexpr
auto
a_thread_loop_over_dim
=
GetAThreadLoopOverDim
();
ck
::
make_zero_multi_index
<
2
>
(),
CElementwiseOperation
{});
constexpr
auto
b_block_desc
=
GetBBlockDescriptor
();
constexpr
auto
b_block_slice_length
=
GetBBlockSliceLength
();
DeviceAlignedMemCPU
a_block_mem
(
m_per_block
*
k_per_block
*
sizeof
(
FloatA
),
constexpr
auto
b_block_copy_dim
=
decltype
(
b_block_slice_length
)
::
Size
();
MemAlignmentByte
);
constexpr
auto
b_dim_access_order
=
GetBBlockDimAccessOrder
();
DeviceAlignedMemCPU
b_block_mem
(
k_per_block
*
n_per_block
*
sizeof
(
FloatB
),
constexpr
auto
b_block_move_step
=
GetBBlockMoveFwdStep
();
MemAlignmentByte
);
constexpr
auto
b_thread_slice_length
=
GetBThreadSliceLength
();
DeviceAlignedMemCPU
c_block_mem
(
m_per_block
*
n_per_block
*
sizeof
(
FloatC
),
constexpr
auto
b_thread_loop_over_dim
=
GetBThreadLoopOverDim
();
MemAlignmentByte
);
constexpr
auto
c_block_desc
=
GetCBlockDescriptor
();
auto
a_grid_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
constexpr
auto
c_block_slice_length
=
GetCBlockSliceLength
();
constexpr
auto
c_block_move_step
=
ck
::
make_multi_index
(
0
,
NPerBlock
);
auto
a_threadwise_copy
=
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2
<
FloatA
,
// SrcData
FloatA
,
// DstData
decltype
(
a_grid_desc
),
// SrcDesc
decltype
(
a_block_desc
),
// DstDesc
AElementwiseOperation
,
// ElementwiseOperation
decltype
(
a_block_slice_length
),
// SliceLengths
decltype
(
a_dim_access_order
),
// DimAccessOrder
1
,
// VectorDim
1
,
// ScalarPerVector
ck
::
InMemoryDataOperationEnum_t
::
Set
,
// InMemoryDataOperationEnum_t
false
,
// SrcResetCoordinateAfterRun
true
// DstResetCoordinateAfterRun
>
(
a_grid_desc
,
ck
::
make_zero_multi_index
<
a_block_copy_dim
>
(),
a_block_desc
,
ck
::
make_zero_multi_index
<
a_block_copy_dim
>
(),
AElementwiseOperation
{});
auto
b_threadwise_copy
=
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2
<
FloatB
,
// SrcData
FloatB
,
// DstData
decltype
(
b_grid_desc
),
// SrcDesc
decltype
(
b_block_desc
),
// DstDesc
BElementwiseOperation
,
// ElementwiseOperation
decltype
(
b_block_slice_length
),
// SliceLengths
decltype
(
b_dim_access_order
),
// DimAccessOrder
1
,
// VectorDim
1
,
// ScalarPerVector
ck
::
InMemoryDataOperationEnum_t
::
Set
,
// InMemoryDataOperationEnum_t
false
,
// SrcResetCoordinateAfterRun
true
// DstResetCoordinateAfterRun
>
(
b_grid_desc
,
ck
::
make_zero_multi_index
<
b_block_copy_dim
>
(),
b_block_desc
,
ck
::
make_zero_multi_index
<
b_block_copy_dim
>
(),
BElementwiseOperation
{});
auto
c_threadwise_copy
=
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2
<
FloatC
,
// SrcData
FloatC
,
// DstData
decltype
(
c_block_desc
),
// SrcDesc
decltype
(
c_grid_desc
),
// DstDesc
BElementwiseOperation
,
// ElementwiseOperation
ck
::
Sequence
<
MPerBlock
,
NPerBlock
>
,
// SliceLengths
ck
::
Sequence
<
0
,
1
>
,
// DimAccessOrder
1
,
// VectorDim
1
,
// ScalarPerVector
ck
::
InMemoryDataOperationEnum_t
::
Set
,
// InMemoryDataOperationEnum_t
true
,
// SrcResetCoordinateAfterRun
false
// DstResetCoordinateAfterRun
>
(
c_block_desc
,
ck
::
make_zero_multi_index
<
2
>
(),
c_grid_desc
,
ck
::
make_zero_multi_index
<
2
>
(),
CElementwiseOperation
{});
DeviceAlignedMemCPU
a_block_mem
(
MPerBlock
*
KPerBlock
*
sizeof
(
FloatA
),
MemAlignmentByte
);
DeviceAlignedMemCPU
b_block_mem
(
KPerBlock
*
NPerBlock
*
sizeof
(
FloatB
),
MemAlignmentByte
);
DeviceAlignedMemCPU
c_block_mem
(
MPerBlock
*
NPerBlock
*
sizeof
(
FloatC
),
MemAlignmentByte
);
auto
a_grid_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum_t
::
Global
>
(
reinterpret_cast
<
const
FloatA
*>
(
p_a_grid
),
a_grid_desc
.
GetElementSpaceSize
());
reinterpret_cast
<
const
FloatA
*>
(
p_a_grid
),
a_grid_desc
.
GetElementSpaceSize
());
auto
b_grid_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
_t
::
Global
>
(
auto
b_grid_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpret_cast
<
const
FloatB
*>
(
p_b_grid
),
b_grid_desc
.
GetElementSpaceSize
());
reinterpret_cast
<
const
FloatB
*>
(
p_b_grid
),
b_grid_desc
.
GetElementSpaceSize
());
auto
c_grid_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
_t
::
Global
>
(
auto
c_grid_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpret_cast
<
FloatC
*>
(
p_c_grid
),
c_grid_desc
.
GetElementSpaceSize
());
reinterpret_cast
<
FloatC
*>
(
p_c_grid
),
c_grid_desc
.
GetElementSpaceSize
());
auto
a_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
_t
::
Global
>
(
auto
a_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpret_cast
<
FloatA
*>
(
a_block_mem
.
mpDeviceBuf
),
reinterpret_cast
<
FloatA
*>
(
a_block_mem
.
mpDeviceBuf
),
a_block_mem
.
mMemSize
/
sizeof
(
FloatA
));
a_block_mem
.
mMemSize
/
sizeof
(
FloatA
));
auto
b_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
_t
::
Global
>
(
auto
b_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpret_cast
<
FloatB
*>
(
b_block_mem
.
mpDeviceBuf
),
reinterpret_cast
<
FloatB
*>
(
b_block_mem
.
mpDeviceBuf
),
b_block_mem
.
mMemSize
/
sizeof
(
FloatB
));
b_block_mem
.
mMemSize
/
sizeof
(
FloatB
));
auto
c_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum_t
::
Global
>
(
auto
c_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpret_cast
<
FloatC
*>
(
c_block_mem
.
mpDeviceBuf
),
UseCLocalBuffer
?
reinterpret_cast
<
FloatC
*>
(
c_block_mem
.
mpDeviceBuf
)
c_block_mem
.
mMemSize
/
sizeof
(
FloatC
));
:
reinterpret_cast
<
FloatC
*>
(
p_c_grid
),
UseCLocalBuffer
?
c_block_mem
.
mMemSize
/
sizeof
(
FloatC
)
auto
blockwise_gemm
=
:
c_grid_desc
.
GetElementSpaceSize
());
BlockwiseGemmAvx2_MxN
<
FloatA
,
// FloatA,
FloatB
,
// FloatB,
auto
blockwise_gemm
=
BlockwiseGemmAvx2_MxN
<
FloatC
,
// FloatC,
FloatA
,
// FloatA,
AccDataType
,
// AccDataType,
FloatB
,
// FloatB,
decltype
(
a_block_desc
),
// ABlockDesc,
FloatC
,
// FloatC,
decltype
(
b_block_desc
),
// BBlockDesc,
decltype
(
GetABlockDescriptor
(
m_per_block
,
k_per_block
)),
// ABlockDesc,
decltype
(
c_block_desc
),
// CBlockDesc,
decltype
(
GetBBlockDescriptor
(
k_per_block
,
n_per_block
)),
// BBlockDesc,
decltype
(
a_block_slice_length
),
// ABlockSliceLengths,
decltype
(
GetCBlockDescriptor
(
m_per_block
,
n_per_block
)),
// CBlockDesc,
decltype
(
b_block_slice_length
),
// BBlockSliceLengths,
KPerBlock
,
// KPerBlock,
decltype
(
c_block_slice_length
),
// CBlockSliceLengths,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
decltype
(
a_thread_slice_length
),
// AThreadSliceLength,
ThreadMNAccessOrder
>
{};
// ThreadMNAccessOrder // how we acces
decltype
(
b_thread_slice_length
),
// BThreadSliceLength,
// gemm MN to utilize micro kernel>{};
a_thread_loop_over_dim
,
// AThreadLoopOverDim, // thread slice
// loop over on block slice. 1d is enough
// for now
b_thread_loop_over_dim
,
// BThreadLoopOverDim,
KPerBlock
,
// KPerBlock,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
ThreadMNAccessOrder
>
{};
// ThreadMNAccessOrder // how we acces
// gemm MN to utilize micro kernel>{};
// TODO: openmp aware ordering
// TODO: openmp aware ordering
//
if
constexpr
(
std
::
is_same
<
BlockMNKAccessOrder
,
ck
::
Sequence
<
0
,
1
,
2
>>::
value
)
if
constexpr
(
std
::
is_same
<
BlockMNKAccessOrder
,
ck
::
Sequence
<
0
,
1
,
2
>>::
value
)
{
{
auto
a_move_k_step
=
ck
::
make_multi_index
(
0
,
k_per_block
);
auto
b_move_k_step
=
ck
::
make_multi_index
(
0
,
k_per_block
,
0
);
const
ck
::
index_t
grid_m
=
math
::
integer_divide_ceil
(
GemmM
,
m_per_block
);
const
ck
::
index_t
grid_n
=
math
::
integer_divide_ceil
(
GemmN
,
n_per_block
);
const
ck
::
index_t
grid_size
=
grid_m
*
grid_n
;
// This version does not consider K panel re-usage. simple for openmp
#pragma omp parallel for
#pragma omp parallel for
for
(
ck
::
index_t
gid
=
0
;
gid
<
grid_size
;
gid
++
)
for
(
ck
::
index_t
gid
=
0
;
gid
<
grid_size
;
gid
++
)
{
{
ck
::
index_t
i_mc
=
(
gid
/
grid_n
)
*
m_per_block
;
ck
::
index_t
i_mc
=
(
gid
/
grid_n
)
*
m_per_block
;
ck
::
index_t
i_nc
=
(
gid
%
grid_n
)
*
n_per_block
;
ck
::
index_t
i_nc
=
(
gid
%
grid_n
)
*
n_per_block
;
ck
::
index_t
mc_size
=
ck
::
math
::
min
(
M
-
i_mc
,
m_per_block
);
ck
::
index_t
mc_size
=
ck
::
math
::
min
(
GemmM
-
i_mc
,
m_per_block
);
ck
::
index_t
nc_size
=
ck
::
math
::
min
(
N
-
i_nc
,
n_per_block
);
ck
::
index_t
nc_size
=
ck
::
math
::
min
(
GemmN
-
i_nc
,
n_per_block
);
// TODO: nc need be 8x
// pack_b
nc_size
=
math
::
integer_least_multiple
(
b_threadwise_copy
.
RunGeneric
(
b_grid_desc
,
b_grid_buf
,
b_block_desc
,
b_block_buf
);
nc_size
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
);
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_move_step
);
if
(
i_nc
==
0
)
a_threadwise_copy
.
SetSrcSliceOrigin
(
a_grid_desc
,
ck
::
make_multi_index
(
i_mc
,
0
));
b_threadwise_copy
.
SetSrcSliceOrigin
(
b_grid_desc
,
ck
::
make_multi_index
(
math
::
integer_divide_ceil
(
i_nc
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
0
,
0
));
auto
c_block_desc
=
UseCLocalBuffer
?
GetCBlockDescriptor
(
mc_size
,
nc_size
)
:
c_grid_desc
;
if
constexpr
(
UseCLocalBuffer
)
{
c_threadwise_copy
.
SetDstSliceOrigin
(
c_grid_desc
,
ck
::
make_multi_index
(
i_mc
,
i_nc
));
}
else
{
{
// pack_a
c_threadwise_copy
.
SetSrcSliceOrigin
(
c_block_desc
,
a_threadwise_copy
.
RunGeneric
(
ck
::
make_multi_index
(
i_mc
,
i_nc
));
a_grid_desc
,
a_grid_buf
,
a_block_desc
,
a_block_buf
);
c_threadwise_copy
.
Run
(
c_block_desc
,
c_block_buf
,
c_grid_desc
,
c_grid_buf
);
                    a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_move_step);
                }
            }

-           for(ck::index_t i_kc = 0; i_kc < K; i_kc += k_per_block)
+           for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
            {
-               ck::index_t kc_size = ck::math::min(K - i_kc, k_per_block);
+               ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);

                auto a_block_desc = GetABlockDescriptor(mc_size, kc_size);
                auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);

                // printf("==> i_m:%d, i_n:%d, i_k:%d, mc:%d, nc:%d, kc:%d(%d, %d)\n", i_mc,
                //        i_nc, i_kc, mc_size, nc_size, kc_size, KPerBlock, GemmK); fflush(stdout);

                a_threadwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
                b_threadwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);

                // for(auto i_elem = 0; i_elem < (mc_size * kc_size); i_elem++){
                //     printf("A ==> %3d : %f(0x%08x)\n", i_elem,
                //            (reinterpret_cast<float*>(a_block_buf.p_data_))[i_elem],
                //            (reinterpret_cast<uint32_t*>(a_block_buf.p_data_))[i_elem]);
                // }
                // for(auto i_elem = 0; i_elem < (kc_size * nc_size); i_elem++){
                //     printf("B ==> %3d : %f(0x%08x)\n", i_elem,
                //            (reinterpret_cast<float*>(b_block_buf.p_data_))[i_elem],
                //            (reinterpret_cast<uint32_t*>(b_block_buf.p_data_))[i_elem]);
                // }
                // printf("[%d] 2222 \n",__LINE__);

                blockwise_gemm.Run(a_block_desc,
                                   a_block_buf,
                                   make_zero_multi_index<a_block_copy_dim>(),
...
@@ -577,14 +307,108 @@ struct GridwiseGemmAvx2_MxN
                                   make_zero_multi_index<b_block_copy_dim>(),
                                   c_block_desc,
                                   c_block_buf,
-                                  make_zero_multi_index<2>());
+                                  make_zero_multi_index<2>(),
+                                  i_kc != 0);

                // printf("[%d] 2222 \n",__LINE__);
                if((i_kc + k_per_block) < GemmK)
                {
                    a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc, a_move_k_step);
                    b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_k_step);
                }

                // printf("[%d] 2222 \n",__LINE__);
                // for(auto i_elem = 0; i_elem < (10); i_elem++){
                //     printf("C ==> %3d : %f(0x%08x)\n", i_elem,
                //            (reinterpret_cast<float*>(c_block_buf.p_data_))[i_elem],
                //            (reinterpret_cast<uint32_t*>(c_block_buf.p_data_))[i_elem]);
                // }
            }
        }

        // for(auto i_elem = 0; i_elem < (c_block_mem.mMemSize / sizeof(FloatC)); i_elem++){
        //     printf("C ==> %3d : %f(0x%08x)\n", i_elem,
        //            (reinterpret_cast<float*>(c_block_buf.p_data_))[i_elem],
        //            (reinterpret_cast<uint32_t*>(c_block_buf.p_data_))[i_elem]);
        // }

        if constexpr(UseCLocalBuffer)
            c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
    }
    }
    else if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value)
    {
        auto a_move_k_step = ck::make_multi_index(0, k_per_block);
        auto b_move_k_step = ck::make_multi_index(
            math::integer_divide_ceil(n_per_block, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize), 0, 0);

        // only parallel in gemm m dim
#pragma omp parallel for
        for(ck::index_t i_mc = 0; i_mc < GemmM; i_mc += m_per_block)
        {
            ck::index_t mc_size = ck::math::min(GemmM - i_mc, m_per_block);

            a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, ck::make_multi_index(i_mc, 0));

            for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
            {
-               c_threadwise_copy.RunGeneric(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
-               c_threadwise_copy.MoveDstSliceWindow(c_grid_desc, c_block_move_step);
+               ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);

                auto a_block_desc = GetABlockDescriptor(mc_size, kc_size);
                a_threadwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
                b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc, ck::make_multi_index(0, i_kc, 0));

                // TODO: if use local C buffer, then this nc loop need to loop only once
                for(ck::index_t i_nc = 0; i_nc < GemmN; i_nc += n_per_block)
                {
                    ck::index_t nc_size = ck::math::min(GemmN - i_nc, n_per_block);
                    // TODO: nc need be 8x
                    nc_size = math::integer_least_multiple(nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);

                    auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);
                    b_threadwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);

                    auto c_block_desc = UseCLocalBuffer ? GetCBlockDescriptor(mc_size, nc_size) : c_grid_desc;

                    if constexpr(!UseCLocalBuffer)
                    {
                        c_threadwise_copy.SetSrcSliceOrigin(c_block_desc, ck::make_multi_index(i_mc, i_nc));
                        c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
                    }

                    blockwise_gemm.Run(a_block_desc,
                                       a_block_buf,
                                       make_zero_multi_index<a_block_copy_dim>(),
                                       b_block_desc,
                                       b_block_buf,
                                       make_zero_multi_index<b_block_copy_dim>(),
                                       c_block_desc,
                                       c_block_buf,
                                       make_zero_multi_index<2>(),
                                       i_kc != 0);

                    if((i_nc + n_per_block) < GemmN)
                        b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_k_step);

                    if constexpr(UseCLocalBuffer)
                    {
                        c_threadwise_copy.SetDstSliceOrigin(c_grid_desc, ck::make_multi_index(i_mc, i_nc));
                        c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
                    }
                }

                if((i_kc + k_per_block) < GemmK)
                    a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc, a_move_k_step);
            }
        }
    }
}
...
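For readers following the blocked GEMM loop above, here is a minimal, self-contained sketch of the same cache-blocking idea on plain float arrays. Every name in it (blocked_gemm, micro_gemm, the lda/ldb/ldc parameters) is illustrative only and not part of this repository; the point is simply to show why the micro-kernel receives an accumulate flag that is false on the first K block and true afterwards, mirroring the i_kc != 0 argument passed to blockwise_gemm.Run.

// Sketch only: plain-array M/N/K blocking with a first-K-block accumulate flag.
#include <algorithm>

static void micro_gemm(const float* a, const float* b, float* c,
                       int mc, int nc, int kc, int lda, int ldb, int ldc,
                       bool accumulate_c)
{
    for(int i = 0; i < mc; ++i)
        for(int j = 0; j < nc; ++j)
        {
            float acc = accumulate_c ? c[i * ldc + j] : 0.0f; // first K block overwrites C
            for(int k = 0; k < kc; ++k)
                acc += a[i * lda + k] * b[k * ldb + j];
            c[i * ldc + j] = acc;
        }
}

void blocked_gemm(const float* A, const float* B, float* C, int M, int N, int K,
                  int m_per_block, int n_per_block, int k_per_block)
{
    for(int i_mc = 0; i_mc < M; i_mc += m_per_block)
        for(int i_nc = 0; i_nc < N; i_nc += n_per_block)
            for(int i_kc = 0; i_kc < K; i_kc += k_per_block)
            {
                int mc = std::min(M - i_mc, m_per_block);
                int nc = std::min(N - i_nc, n_per_block);
                int kc = std::min(K - i_kc, k_per_block);
                micro_gemm(A + i_mc * K + i_kc, B + i_kc * N + i_nc, C + i_mc * N + i_nc,
                           mc, nc, kc, K, N, N, /*accumulate_c=*/i_kc != 0);
            }
}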
include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp  View file @ afc7d431

@@ -7,7 +7,7 @@
 #include "common_header.hpp"
 #include "../../gpu/device/tensor_layout.hpp"
 #include "math.hpp"
-#include "threadwise_param.hpp"
+#include "threadwise_gemm_param.hpp"

 namespace ck {
 namespace cpu {

@@ -294,6 +294,9 @@ struct ThreadwiseGemmAvx2_MxN_6x16
 ".if (m_Mr > 4)\n lea (%%rdx, %%rdi, 1), %%r8\n .endif\n"
 ".if (m_Mr > 5)\n lea (%%r8, %%rdi, 1), %%r9\n .endif\n"
+"mov 60(%[m_param]), %%edi\n"               // accmulate_c
+"test %%edi, %%edi\n"
+"je L_GemmAvx2_MxN_6x16_Store_C%=\n"
 " vaddps (%%rax), %%ymm0, %%ymm0\n"
 ".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1\n .endif\n"
 ".if (m_Mr > 1)\n vaddps (%%rbx), %%ymm2, %%ymm2\n .endif\n"

@@ -307,6 +310,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
 ".if (m_Mr > 5)\n vaddps (%%r9), %%ymm10, %%ymm10\n .endif\n"
 ".if (m_Mr > 5) && (m_Nr > 8)\n vaddps 32(%%r9), %%ymm11, %%ymm11\n .endif\n"
+"L_GemmAvx2_MxN_6x16_Store_C%=:\n"
 ".if m_NTStore == 0\n"
 " vmovups %%ymm0, (%%rax)\n"
 ".if (m_Nr > 8)\n vmovups %%ymm1, 32(%%rax)\n .endif\n"

@@ -424,18 +428,33 @@ struct ThreadwiseGemmAvx2_MxN_6x16
 };
 // clang-format off
-ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
-if constexpr(Nr > 8)            ymm1  = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
-if constexpr(Mr > 1)            ymm2  = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8);
-if constexpr(Mr > 1 && Nr > 8)  ymm3  = _mm256_loadu_ps(p_c + 1 * ldc + 1 * 8);
-if constexpr(Mr > 2)            ymm4  = _mm256_loadu_ps(p_c + 2 * ldc + 0 * 8);
-if constexpr(Mr > 2 && Nr > 8)  ymm5  = _mm256_loadu_ps(p_c + 2 * ldc + 1 * 8);
-if constexpr(Mr > 3)            ymm6  = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8);
-if constexpr(Mr > 3 && Nr > 8)  ymm7  = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8);
-if constexpr(Mr > 4)            ymm8  = _mm256_loadu_ps(p_c + 4 * ldc + 0 * 8);
-if constexpr(Mr > 4 && Nr > 8)  ymm9  = _mm256_loadu_ps(p_c + 4 * ldc + 1 * 8);
-if constexpr(Mr > 5)            ymm10 = _mm256_loadu_ps(p_c + 5 * ldc + 0 * 8);
-if constexpr(Mr > 5 && Nr > 8)  ymm11 = _mm256_loadu_ps(p_c + 5 * ldc + 1 * 8);
+if(param->accmulate_c){
+    ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
+    if constexpr(Nr > 8)            ymm1  = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
+    if constexpr(Mr > 1)            ymm2  = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8);
+    if constexpr(Mr > 1 && Nr > 8)  ymm3  = _mm256_loadu_ps(p_c + 1 * ldc + 1 * 8);
+    if constexpr(Mr > 2)            ymm4  = _mm256_loadu_ps(p_c + 2 * ldc + 0 * 8);
+    if constexpr(Mr > 2 && Nr > 8)  ymm5  = _mm256_loadu_ps(p_c + 2 * ldc + 1 * 8);
+    if constexpr(Mr > 3)            ymm6  = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8);
+    if constexpr(Mr > 3 && Nr > 8)  ymm7  = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8);
+    if constexpr(Mr > 4)            ymm8  = _mm256_loadu_ps(p_c + 4 * ldc + 0 * 8);
+    if constexpr(Mr > 4 && Nr > 8)  ymm9  = _mm256_loadu_ps(p_c + 4 * ldc + 1 * 8);
+    if constexpr(Mr > 5)            ymm10 = _mm256_loadu_ps(p_c + 5 * ldc + 0 * 8);
+    if constexpr(Mr > 5 && Nr > 8)  ymm11 = _mm256_loadu_ps(p_c + 5 * ldc + 1 * 8);
+}
+else {
+    ymm0 = _mm256_xor_ps(ymm0, ymm0);
+    if constexpr(Nr > 8)            ymm1  = _mm256_xor_ps(ymm1, ymm1);
+    if constexpr(Mr > 1)            ymm2  = _mm256_xor_ps(ymm2, ymm2);
+    if constexpr(Mr > 1 && Nr > 8)  ymm3  = _mm256_xor_ps(ymm3, ymm3);
+    if constexpr(Mr > 2)            ymm4  = _mm256_xor_ps(ymm4, ymm4);
+    if constexpr(Mr > 2 && Nr > 8)  ymm5  = _mm256_xor_ps(ymm5, ymm5);
+    if constexpr(Mr > 3)            ymm6  = _mm256_xor_ps(ymm6, ymm6);
+    if constexpr(Mr > 3 && Nr > 8)  ymm7  = _mm256_xor_ps(ymm7, ymm7);
+    if constexpr(Mr > 4)            ymm8  = _mm256_xor_ps(ymm8, ymm8);
+    if constexpr(Mr > 4 && Nr > 8)  ymm9  = _mm256_xor_ps(ymm9, ymm9);
+    if constexpr(Mr > 5)            ymm10 = _mm256_xor_ps(ymm10, ymm10);
+    if constexpr(Mr > 5 && Nr > 8)  ymm11 = _mm256_xor_ps(ymm11, ymm11);
+}
 while(Kr > 4){
 #pragma unroll

@@ -532,6 +551,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
 }
 if constexpr(NonTemporalStore) {
+    _mm256_stream_ps(p_c + 0 * ldc + 0 * 8, ymm1);
     if constexpr(Nr > 8)            _mm256_stream_ps(p_c + 0 * ldc + 1 * 8, ymm1);
     if constexpr(Mr > 1)            _mm256_stream_ps(p_c + 1 * ldc + 0 * 8, ymm2);
     if constexpr(Mr > 1 && Nr > 8)  _mm256_stream_ps(p_c + 1 * ldc + 1 * 8, ymm3);
...

@@ -830,19 +850,23 @@ struct ThreadwiseGemmAvx2_MxN_4x24
 ".if (m_Mr > 2)\n lea (%%rbx, %%rdi, 1), %%rcx\n .endif\n"
 ".if (m_Mr > 3)\n lea (%%rcx, %%rdi, 1), %%rdx\n .endif\n"
-// " vaddps (%%rax), %%ymm0, %%ymm0 \n"
-// ".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1 \n .endif\n"
-// ".if (m_Nr >16)\n vaddps 64(%%rax), %%ymm2, %%ymm2 \n .endif\n"
-// ".if (m_Mr > 1) \n vaddps (%%rbx), %%ymm3, %%ymm3 \n .endif\n"
-// ".if (m_Mr > 1) && (m_Nr > 8)\n vaddps 32(%%rbx), %%ymm4, %%ymm4 \n .endif\n"
-// ".if (m_Mr > 1) && (m_Nr >16)\n vaddps 64(%%rbx), %%ymm5, %%ymm5 \n .endif\n"
-// ".if (m_Mr > 2) \n vaddps (%%rcx), %%ymm6, %%ymm6 \n .endif\n"
-// ".if (m_Mr > 2) && (m_Nr > 8)\n vaddps 32(%%rcx), %%ymm7, %%ymm7 \n .endif\n"
-// ".if (m_Mr > 2) && (m_Nr >16)\n vaddps 64(%%rcx), %%ymm8, %%ymm8 \n .endif\n"
-// ".if (m_Mr > 3) \n vaddps (%%rdx), %%ymm9, %%ymm9 \n .endif\n"
-// ".if (m_Mr > 3) && (m_Nr > 8)\n vaddps 32(%%rdx), %%ymm10, %%ymm10\n .endif\n"
-// ".if (m_Mr > 3) && (m_Nr >16)\n vaddps 64(%%rdx), %%ymm11, %%ymm11\n .endif\n"
+"mov 60(%[m_param]), %%edi\n"               // accmulate_c
+"test %%edi, %%edi\n"
+"je L_GemmAvx2_MxN_4x24_Store_C%=\n"
+" vaddps (%%rax), %%ymm0, %%ymm0\n"
+".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1\n .endif\n"
+".if (m_Nr >16)\n vaddps 64(%%rax), %%ymm2, %%ymm2\n .endif\n"
+".if (m_Mr > 1)\n vaddps (%%rbx), %%ymm3, %%ymm3\n .endif\n"
+".if (m_Mr > 1) && (m_Nr > 8)\n vaddps 32(%%rbx), %%ymm4, %%ymm4\n .endif\n"
+".if (m_Mr > 1) && (m_Nr >16)\n vaddps 64(%%rbx), %%ymm5, %%ymm5\n .endif\n"
+".if (m_Mr > 2)\n vaddps (%%rcx), %%ymm6, %%ymm6\n .endif\n"
+".if (m_Mr > 2) && (m_Nr > 8)\n vaddps 32(%%rcx), %%ymm7, %%ymm7\n .endif\n"
+".if (m_Mr > 2) && (m_Nr >16)\n vaddps 64(%%rcx), %%ymm8, %%ymm8\n .endif\n"
+".if (m_Mr > 3)\n vaddps (%%rdx), %%ymm9, %%ymm9\n .endif\n"
+".if (m_Mr > 3) && (m_Nr > 8)\n vaddps 32(%%rdx), %%ymm10, %%ymm10\n .endif\n"
+".if (m_Mr > 3) && (m_Nr >16)\n vaddps 64(%%rdx), %%ymm11, %%ymm11\n .endif\n"
+"L_GemmAvx2_MxN_4x24_Store_C%=:\n"
 ".if m_NTStore == 0\n"
 " vmovups %%ymm0, (%%rax)\n"
 ".if (m_Nr > 8)\n vmovups %%ymm1, 32(%%rax)\n .endif\n"

@@ -960,18 +984,33 @@ struct ThreadwiseGemmAvx2_MxN_4x24
 };
 // clang-format off
-ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
-if constexpr(Nr > 8)            ymm1  = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
-if constexpr(Nr > 16)           ymm2  = _mm256_loadu_ps(p_c + 0 * ldc + 2 * 8);
-if constexpr(Mr > 1)            ymm3  = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8);
-if constexpr(Mr > 1 && Nr > 8)  ymm4  = _mm256_loadu_ps(p_c + 1 * ldc + 1 * 8);
-if constexpr(Mr > 1 && Nr > 16) ymm5  = _mm256_loadu_ps(p_c + 1 * ldc + 2 * 8);
-if constexpr(Mr > 2)            ymm6  = _mm256_loadu_ps(p_c + 2 * ldc + 0 * 8);
-if constexpr(Mr > 2 && Nr > 8)  ymm7  = _mm256_loadu_ps(p_c + 2 * ldc + 1 * 8);
-if constexpr(Mr > 2 && Nr > 16) ymm8  = _mm256_loadu_ps(p_c + 2 * ldc + 2 * 8);
-if constexpr(Mr > 3)            ymm9  = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8);
-if constexpr(Mr > 3 && Nr > 8)  ymm10 = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8);
-if constexpr(Mr > 3 && Nr > 16) ymm11 = _mm256_loadu_ps(p_c + 3 * ldc + 2 * 8);
+if(param->accmulate_c) {
+    ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
+    if constexpr(Nr > 8)            ymm1  = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
+    if constexpr(Nr > 16)           ymm2  = _mm256_loadu_ps(p_c + 0 * ldc + 2 * 8);
+    if constexpr(Mr > 1)            ymm3  = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8);
+    if constexpr(Mr > 1 && Nr > 8)  ymm4  = _mm256_loadu_ps(p_c + 1 * ldc + 1 * 8);
+    if constexpr(Mr > 1 && Nr > 16) ymm5  = _mm256_loadu_ps(p_c + 1 * ldc + 2 * 8);
+    if constexpr(Mr > 2)            ymm6  = _mm256_loadu_ps(p_c + 2 * ldc + 0 * 8);
+    if constexpr(Mr > 2 && Nr > 8)  ymm7  = _mm256_loadu_ps(p_c + 2 * ldc + 1 * 8);
+    if constexpr(Mr > 2 && Nr > 16) ymm8  = _mm256_loadu_ps(p_c + 2 * ldc + 2 * 8);
+    if constexpr(Mr > 3)            ymm9  = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8);
+    if constexpr(Mr > 3 && Nr > 8)  ymm10 = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8);
+    if constexpr(Mr > 3 && Nr > 16) ymm11 = _mm256_loadu_ps(p_c + 3 * ldc + 2 * 8);
+}
+else {
+    ymm0 = _mm256_xor_ps(ymm0, ymm0);
+    if constexpr(Nr > 8)            ymm1  = _mm256_xor_ps(ymm1, ymm1);
+    if constexpr(Nr > 16)           ymm2  = _mm256_xor_ps(ymm2, ymm2);
+    if constexpr(Mr > 1)            ymm3  = _mm256_xor_ps(ymm3, ymm3);
+    if constexpr(Mr > 1 && Nr > 8)  ymm4  = _mm256_xor_ps(ymm4, ymm4);
+    if constexpr(Mr > 1 && Nr > 16) ymm5  = _mm256_xor_ps(ymm5, ymm5);
+    if constexpr(Mr > 2)            ymm6  = _mm256_xor_ps(ymm6, ymm6);
+    if constexpr(Mr > 2 && Nr > 8)  ymm7  = _mm256_xor_ps(ymm7, ymm7);
+    if constexpr(Mr > 2 && Nr > 16) ymm8  = _mm256_xor_ps(ymm8, ymm8);
+    if constexpr(Mr > 3)            ymm9  = _mm256_xor_ps(ymm9, ymm9);
+    if constexpr(Mr > 3 && Nr > 8)  ymm10 = _mm256_xor_ps(ymm10, ymm10);
+    if constexpr(Mr > 3 && Nr > 16) ymm11 = _mm256_xor_ps(ymm11, ymm11);
+}
 while(Kr > 4){
 #pragma unroll

@@ -1221,33 +1260,36 @@ struct ThreadwiseGemmAvx2_MxN_6x16_Dispatch
 static constexpr pThreadwiseGemmAvx2Run dispatch_table[6][2] = {
-    {ThreadwiseGemm_6x16_t::Run, ThreadwiseGemm_6x8_t::Run,},
-    {ThreadwiseGemm_5x16_t::Run, ThreadwiseGemm_5x8_t::Run,},
-    {ThreadwiseGemm_4x16_t::Run, ThreadwiseGemm_4x8_t::Run,},
-    {ThreadwiseGemm_3x16_t::Run, ThreadwiseGemm_3x8_t::Run,},
-    {ThreadwiseGemm_2x16_t::Run, ThreadwiseGemm_2x8_t::Run,},
-    {ThreadwiseGemm_1x16_t::Run, ThreadwiseGemm_1x8_t::Run,},
+    {ThreadwiseGemm_1x8_t::Run, ThreadwiseGemm_1x16_t::Run,},
+    {ThreadwiseGemm_2x8_t::Run, ThreadwiseGemm_2x16_t::Run,},
+    {ThreadwiseGemm_3x8_t::Run, ThreadwiseGemm_3x16_t::Run,},
+    {ThreadwiseGemm_4x8_t::Run, ThreadwiseGemm_4x16_t::Run,},
+    {ThreadwiseGemm_5x8_t::Run, ThreadwiseGemm_5x16_t::Run,},
+    {ThreadwiseGemm_6x8_t::Run, ThreadwiseGemm_6x16_t::Run,},
 };

 static void Run(ThreadwiseGemmParam* param, index_t mr, index_t nr)
 {
+    index_t im = mr - 1;
+    index_t in = (nr >> 3) - 1;
+    assert(im >= 0 && im <= 5 && in >= 0 && in <= 1);
     return dispatch_table[mr][nr](param);
 }
 };

@@ -1371,30 +1413,33 @@ struct ThreadwiseGemmAvx2_MxN_4x24_Dispatch
 static constexpr pThreadwiseGemmAvx2Run dispatch_table[4][3] = {
-    {ThreadwiseGemm_4x24_t::Run, ThreadwiseGemm_4x16_t::Run, ThreadwiseGemm_4x8_t::Run,},
-    {ThreadwiseGemm_3x24_t::Run, ThreadwiseGemm_3x16_t::Run, ThreadwiseGemm_3x8_t::Run,},
-    {ThreadwiseGemm_2x24_t::Run, ThreadwiseGemm_2x16_t::Run, ThreadwiseGemm_2x8_t::Run,},
-    {ThreadwiseGemm_1x24_t::Run, ThreadwiseGemm_1x16_t::Run, ThreadwiseGemm_1x8_t::Run,},
+    {ThreadwiseGemm_1x8_t::Run, ThreadwiseGemm_1x16_t::Run, ThreadwiseGemm_1x24_t::Run,},
+    {ThreadwiseGemm_2x8_t::Run, ThreadwiseGemm_2x16_t::Run, ThreadwiseGemm_2x24_t::Run,},
+    {ThreadwiseGemm_3x8_t::Run, ThreadwiseGemm_3x16_t::Run, ThreadwiseGemm_3x24_t::Run,},
+    {ThreadwiseGemm_4x8_t::Run, ThreadwiseGemm_4x16_t::Run, ThreadwiseGemm_4x24_t::Run,},
 };

 static void Run(ThreadwiseGemmParam* param, index_t mr, index_t nr)
 {
-    return dispatch_table[mr][nr](param);
+    index_t im = mr - 1;
+    index_t in = (nr >> 3) - 1;
+    assert(im >= 0 && im <= 3 && in >= 0 && in <= 2);
+    return dispatch_table[im][in](param);
 }
 };
...
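The dispatch-table change above reorders both tables so that the row index grows with mr and the column index with nr, and the 4x24 dispatcher now derives its indices as im = mr - 1 and in = (nr >> 3) - 1 before the lookup. A small sketch of that mapping, with illustrative names that are not the repository's:

// Sketch: mapping a residual tile (mr rows, nr = 8/16/24 columns) to dispatch-table
// indices the way the new 4x24 dispatcher does it.
#include <cassert>

using KernelFn = void (*)(void* /*param*/);

inline void dispatch_4x24(KernelFn table[4][3], void* param, int mr, int nr)
{
    int im = mr - 1;        // mr in 1..4        -> row 0..3
    int in = (nr >> 3) - 1; // nr in {8, 16, 24} -> column 0..2
    assert(im >= 0 && im <= 3 && in >= 0 && in <= 2);
    table[im][in](param);
}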
include/ck/tensor_operation/cpu/thread/threadwise_param.hpp → include/ck/tensor_operation/cpu/thread/threadwise_gemm_param.hpp  View file @ afc7d431

-#ifndef CK_THREADWISE_PARAM_HPP
-#define CK_THREADWISE_PARAM_HPP
+#ifndef CK_THREADWISE_GEMM_PARAM_HPP
+#define CK_THREADWISE_GEMM_PARAM_HPP

 #include "common_header.hpp"
 #include "math.hpp"
...
@@ -17,7 +17,7 @@ struct ThreadwiseGemmParam
 uint64_t ldb; // in unit of byte
 uint64_t ldc; // in unit of byte
 float alpha;
-uint32_t _pack0;
+int accmulate_c; // if 1, need load C and add into current fma. if 0, direct store out c result
 } __attribute__((packed));

 } // namespace cpu
...
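The new accmulate_c field is read from the inline assembly above as "mov 60(%[m_param]), %%edi", so its byte offset inside the packed ThreadwiseGemmParam has to stay at 60. One way to keep that hard-coded literal honest is a compile-time check; in this sketch the fields before alpha are only an assumption for illustration (the real layout lives in threadwise_gemm_param.hpp), but the pattern applies as-is:

// Sketch: pinning the byte offset that the inline asm hard-codes.
// Only ldb/ldc/alpha/accmulate_c are visible in the hunk above; the leading
// pointer/stride fields here are assumed for illustration.
#include <cstddef>
#include <cstdint>

struct ThreadwiseGemmParamSketch
{
    const void* p_a;   // assumed
    const void* p_b;   // assumed
    void*       p_c;   // assumed
    uint64_t    Kr;    // assumed
    uint64_t    lda;   // assumed
    uint64_t    ldb;   // in unit of byte
    uint64_t    ldc;   // in unit of byte
    float       alpha;
    int         accmulate_c; // 1: load C and accumulate, 0: overwrite
} __attribute__((packed));

static_assert(offsetof(ThreadwiseGemmParamSketch, accmulate_c) == 60,
              "asm reads mov 60(%[m_param]); keep this offset in sync");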
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2.hpp  View file @ afc7d431
...
@@ -53,6 +53,42 @@ struct ThreadwiseTensorSliceTransferAvx2
 {
     static_assert(SliceLengths::At(Number<VectorDim>{}) % ScalarPerVector == 0,
                   "wrong! cannot evenly divide");
+    int N  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
+    int Hi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+    int Wi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
+    int C  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<3>{}];
+    int Ho = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<1>{}];
+    int Wo = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<2>{}];
+    int Fy = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<0>{}];
+    int Fx = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<1>{}];
+    int Dy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<0>{}];
+    int Sy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<1>{}];
+    int Dx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<0>{}];
+    int Sx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<1>{}];
+    int Py = src_desc.GetTransforms()[Number<2>{}].left_pad_length_;
+    int Px = src_desc.GetTransforms()[Number<3>{}].left_pad_length_;
+    printf("N:%d, Hi:%d, Wi:%d, C:%d, Ho:%d, Wo:%d, Fy:%d, Fx:%d, Dy:%d, Sy:%d, Dx:%d, Sx:%d, "
+           "Py:%d, Px:%d\n",
+           N, Hi, Wi, C, Ho, Wo, Fy, Fx, Dy, Sy, Dx, Sx, Py, Px);
 }

 void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
...
@@ -87,6 +123,10 @@ struct ThreadwiseTensorSliceTransferAvx2
 // std::cout<<"num_access:"<<num_access<<std::endl;
+std::cout << "src hidden:" << SrcDesc::GetNumOfHiddenDimension() << std::endl;
+std::cout << "dst hidden:" << DstDesc::GetNumOfHiddenDimension() << std::endl;
+#if 0
 static_for<0, num_access, 1>{}([&](auto idx_1d) {
     using src_vector_type = ck::cpu::vector_type_maker_t<SrcData, ScalarPerVector>;
     using src_vector_t    = typename src_vector_type::type;
...
@@ -148,6 +188,75 @@ struct ThreadwiseTensorSliceTransferAvx2
             dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
     }
 });
+#endif
+const auto src_slice_idx_zeros = typename uniform_sequence_gen<nDim, 0>::type{};
+const auto src_slice_step      = make_tensor_coordinate_step(
+    src_desc, to_multi_index(src_slice_idx_zeros.Modify(Number<nDim - 1>{}, Number<1>{})));
+const auto dst_slice_idx_zeros = typename uniform_sequence_gen<nDim, 0>::type{};
+const auto dst_slice_step      = make_tensor_coordinate_step(
+    dst_desc, to_multi_index(dst_slice_idx_zeros.Modify(Number<nDim - 1>{}, Number<1>{})));
+
+for(auto idx_id = 0; idx_id < num_access; idx_id++)
+{
+    using src_vector_type = ck::cpu::vector_type_maker_t<SrcData, ScalarPerVector>;
+    using src_vector_t    = typename src_vector_type::type;
+    using dst_vector_type = ck::cpu::vector_type_maker_t<DstData, ScalarPerVector>;
+    using dst_vector_t    = typename dst_vector_type::type;
+
+    const bool is_src_valid =
+        coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
+
+    printf("[%s] ", is_src_valid ? "y" : "n");
+    print_multi_index(src_coord_.GetIndex());
+    printf("----");
+    // print_multi_index(src_coord_.GetHiddenIndex());
+    // printf(":%d", src_coord_.GetOffset());
+    // printf("\n");
+
+    // copy data from src_buf into src_vector_container
+    auto src_vector_container = src_vector_type{
+        src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
+    auto dst_vector_container = dst_vector_type{};
+
+    // apply pointwise operation
+    // static_for<0, ScalarPerVector, 1>{}([&](auto i) {
+    //     element_op_(dst_vector_container.template AsType<DstData>()(i),
+    //                 src_vector_container.template AsType<SrcData>()[i]);
+    // });
+    element_op_(dst_vector_container.template AsType<dst_vector_t>(),
+                src_vector_container.template AsType<src_vector_t>());
+
+    const bool is_dst_valid =
+        coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
+
+    printf(" -> ");
+    print_multi_index(dst_coord_.GetIndex());
+    // printf(":%d", dst_coord_.GetOffset());
+    // printf(", src:0x%x, dst:0x%x",
+    //        *reinterpret_cast<uint32_t*>(&src_vector_container.template AsType<src_vector_t>()),
+    //        *reinterpret_cast<uint32_t*>(&dst_vector_container.template AsType<dst_vector_t>()));
+    printf("\n");
+
+    // copy data from dst_vector into dst_buf
+    dst_buf.template Update<DstInMemOp, dst_vector_t>(
+        dst_coord_.GetOffset(), is_dst_valid, dst_vector_container.template AsType<dst_vector_t>());
+
+    // move coordinate
+    if(idx_id != num_access - 1)
+    {
+        // constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
+        move_tensor_coordinate(src_desc, src_coord_, src_slice_step);
+        move_tensor_coordinate(dst_desc, dst_coord_, dst_slice_step);
+    }
+}
 // move coordinate back to slice origin (or not)
 if constexpr(SrcResetCoordinateAfterRun)
...
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp  0 → 100644  View file @ afc7d431

#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_AVX2_SPECIALIZED_HPP
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_AVX2_SPECIALIZED_HPP

#include "common_header.hpp"
#include "data_type_cpu.hpp"
#include "../../gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_space_filling_curve.hpp"
#include "dynamic_buffer_cpu.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include <immintrin.h>

namespace ck {
namespace cpu {
namespace avx2_util {

inline void memcpy32_avx2(void* dst, const void* src, const ck::index_t n)
{
    // 16-8-4-2-1 pattern
    ck::index_t i_n    = n;
    float* p_dst       = reinterpret_cast<float*>(dst);
    const float* p_src = reinterpret_cast<const float*>(src);
    while(i_n >= 16)
    {
        _mm256_storeu_ps(p_dst + 0, _mm256_loadu_ps(p_src + 0));
        _mm256_storeu_ps(p_dst + 8, _mm256_loadu_ps(p_src + 8));
        p_dst += 16; p_src += 16; i_n -= 16;
    }
    if(i_n & 8) { _mm256_storeu_ps(p_dst, _mm256_loadu_ps(p_src)); p_dst += 8; p_src += 8; }
    if(i_n & 4) { _mm_storeu_ps(p_dst, _mm_loadu_ps(p_src)); p_dst += 4; p_src += 4; }
    if(i_n & 2) { _mm_storeu_si64(p_dst, _mm_loadu_si64(p_src)); p_dst += 2; p_src += 2; }
    if(i_n & 1) { *p_dst = *p_src; }
}

inline void memset32_avx2(void* dst, const int32_t value, const ck::index_t n)
{
    // 16-8-4-2-1 pattern
    ck::index_t i_n = n;
    float* p_dst    = reinterpret_cast<float*>(dst);
    __m256 ymm      = _mm256_set1_ps(*reinterpret_cast<const float*>(&value));
    __m128 xmm      = _mm_set1_ps(*reinterpret_cast<const float*>(&value));
    while(i_n >= 16)
    {
        _mm256_storeu_ps(p_dst + 0, ymm);
        _mm256_storeu_ps(p_dst + 8, ymm);
        p_dst += 16; i_n -= 16;
    }
    if(i_n & 8) { _mm256_storeu_ps(p_dst, ymm); p_dst += 8; }
    if(i_n & 4) { _mm_storeu_ps(p_dst, xmm); p_dst += 4; }
    if(i_n & 2) { _mm_storeu_si64(p_dst, xmm); p_dst += 2; }
    if(i_n & 1) { *p_dst = *reinterpret_cast<const float*>(&value); }
}

inline void transpose8x8_avx2(void* dst, ck::index_t stride_dst, const void* src, ck::index_t stride_src)
{
    // TODO: use vinsertf128 for better port usage. vpermf128 is slow
    __m256 r0, r1, r2, r3, r4, r5, r6, r7;
    __m256 t0, t1, t2, t3, t4, t5, t6, t7;

    float* p_dst       = reinterpret_cast<float*>(dst);
    const float* p_src = reinterpret_cast<const float*>(src);

    r0 = _mm256_loadu_ps(p_src + 0 * stride_src);
    r1 = _mm256_loadu_ps(p_src + 1 * stride_src);
    r2 = _mm256_loadu_ps(p_src + 2 * stride_src);
    r3 = _mm256_loadu_ps(p_src + 3 * stride_src);
    r4 = _mm256_loadu_ps(p_src + 4 * stride_src);
    r5 = _mm256_loadu_ps(p_src + 5 * stride_src);
    r6 = _mm256_loadu_ps(p_src + 6 * stride_src);
    r7 = _mm256_loadu_ps(p_src + 7 * stride_src);

    t0 = _mm256_unpacklo_ps(r0, r1);
    t1 = _mm256_unpackhi_ps(r0, r1);
    t2 = _mm256_unpacklo_ps(r2, r3);
    t3 = _mm256_unpackhi_ps(r2, r3);
    t4 = _mm256_unpacklo_ps(r4, r5);
    t5 = _mm256_unpackhi_ps(r4, r5);
    t6 = _mm256_unpacklo_ps(r6, r7);
    t7 = _mm256_unpackhi_ps(r6, r7);

    r0 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(1, 0, 1, 0));
    r1 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(3, 2, 3, 2));
    r2 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(1, 0, 1, 0));
    r3 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(3, 2, 3, 2));
    r4 = _mm256_shuffle_ps(t4, t6, _MM_SHUFFLE(1, 0, 1, 0));
    r5 = _mm256_shuffle_ps(t4, t6, _MM_SHUFFLE(3, 2, 3, 2));
    r6 = _mm256_shuffle_ps(t5, t7, _MM_SHUFFLE(1, 0, 1, 0));
    r7 = _mm256_shuffle_ps(t5, t7, _MM_SHUFFLE(3, 2, 3, 2));

    t0 = _mm256_permute2f128_ps(r0, r4, 0x20);
    t1 = _mm256_permute2f128_ps(r1, r5, 0x20);
    t2 = _mm256_permute2f128_ps(r2, r6, 0x20);
    t3 = _mm256_permute2f128_ps(r3, r7, 0x20);
    t4 = _mm256_permute2f128_ps(r0, r4, 0x31);
    t5 = _mm256_permute2f128_ps(r1, r5, 0x31);
    t6 = _mm256_permute2f128_ps(r2, r6, 0x31);
    t7 = _mm256_permute2f128_ps(r3, r7, 0x31);

    _mm256_storeu_ps(p_dst + 0 * stride_dst, t0);
    _mm256_storeu_ps(p_dst + 1 * stride_dst, t1);
    _mm256_storeu_ps(p_dst + 2 * stride_dst, t2);
    _mm256_storeu_ps(p_dst + 3 * stride_dst, t3);
    _mm256_storeu_ps(p_dst + 4 * stride_dst, t4);
    _mm256_storeu_ps(p_dst + 5 * stride_dst, t5);
    _mm256_storeu_ps(p_dst + 6 * stride_dst, t6);
    _mm256_storeu_ps(p_dst + 7 * stride_dst, t7);
}

} // namespace avx2_util

using ConvolutionForwardSpecialization_t =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t;
using ConvolutionForwardGemmKSpecialization_t =
    ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t;

// assume input -> a matrix
// assume input -> MC * KC
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename ElementwiseOperation,
          bool BypassTransfer,
          ConvolutionForwardSpecialization_t ConvForwardSpecialization,
          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
{
    static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
    using Index = MultiIndex<nDim>;

    constexpr ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC(
        const SrcDesc& src_desc,
        const Index&,
        const DstDesc&,
        const Index&,
        const ElementwiseOperation& element_op)
        : element_op_(element_op)
    {
        if constexpr(ConvForwardSpecialization ==
                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            N  = 1;
            Hi = 1;
            Wi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}]; // gemm_m
            C  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}]; // gemm_k
            Ho = 1;
            Wo = Wi;
            Fy = 1; Fx = 1; Dy = 1; Sy = 1; Dx = 1; Sx = 1; Py = 0; Px = 0;
        }
        else if constexpr(ConvForwardSpecialization ==
                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
        {
            N  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
            Hi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
            Wi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
            C  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<3>{}];
            Ho = src_desc.GetTransforms()[Number<2>{}].GetUpperLengths()[Number<0>{}];
            Wo = src_desc.GetTransforms()[Number<3>{}].GetUpperLengths()[Number<0>{}];
            Fy = 1;
            Fx = 1;
            Dy = 1;
            Sy = src_desc.GetTransforms()[Number<2>{}].coefficients_[Number<0>{}];
            Dx = 1;
            Sx = src_desc.GetTransforms()[Number<3>{}].coefficients_[Number<0>{}];
            Py = 0;
            Px = 0;
        }
        else
        {
            N  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
            Hi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
            Wi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
            C  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<3>{}];
            Ho = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<1>{}];
            Wo = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<2>{}];
            Fy = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<0>{}];
            Fx = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<1>{}];
            Dy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<0>{}];
            Sy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<1>{}];
            Dx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<0>{}];
            Sx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<1>{}];
            Py = src_desc.GetTransforms()[Number<2>{}].left_pad_length_;
            Px = src_desc.GetTransforms()[Number<3>{}].left_pad_length_;
        }

        // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
        // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
        input_offset_acc_wi        = Sx * C;
        input_offset_ovf_wi_acc_hi = Sy * Wi * C - Wo * Sx * C;
        input_offset_ovf_hi_acc_n  = Hi * Wi * C - Ho * Sy * Wi * C;
        // input_offset_acc_c = 1;
        input_offset_ovf_c_acc_x   = Dx * C - C;
        input_offset_ovf_x_acc_y   = Dy * Wi * C - Fx * Dx * C;

        src_offset = -Py * Wi * C - Px * C;

        i_n  = 0;
        i_c  = 0;
        i_hi = -Py;
        i_wi = -Px;
        i_ho = 0;
        i_wo = 0;
        i_y  = 0;
        i_x  = 0;
        i_gemm_k = 0;

#if 0
        printf("N:%d, Hi:%d, Wi:%d, C:%d, Ho:%d, Wo:%d, Fy:%d, Fx:%d, Dy:%d, Sy:%d, Dx:%d, Sx:%d, "
               "Py:%d, Px:%d\n",
               N, Hi, Wi, C, Ho, Wo, Fy, Fx, Dy, Sy, Dx, Sx, Py, Px);
#endif
    }

    void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
    {
        ck::index_t idx_m = src_slice_origin_idx[Number<0>{}];
        ck::index_t idx_k = src_slice_origin_idx[Number<1>{}];

        if constexpr(ConvForwardSpecialization ==
                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            i_wi = idx_m;
            i_c  = idx_k;
            src_offset = i_wi * C + i_c;
            // printf("src_offset:%d, i_wi:%d, i_c:%d\n", src_offset, i_wi, i_c);
        }
        else if constexpr(ConvForwardSpecialization ==
                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
        {
            i_wo = idx_m % Wo;
            i_ho = (idx_m / Wo) % Ho;
            i_n  = (idx_m / Wo) / Ho;
            // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
            // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
            i_c  = idx_k;
            i_x  = 0;
            i_y  = 0;
            i_hi = i_ho * Sy;
            i_wi = i_wo * Sx;
            src_offset = i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C + i_c;
            i_gemm_k   = idx_k;
        }
        else
        {
            i_wo = idx_m % Wo;
            i_ho = (idx_m / Wo) % Ho;
            i_n  = (idx_m / Wo) / Ho;
            // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
            // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
            if(idx_k == 0)
            {
                i_c  = 0;
                i_x  = 0;
                i_y  = 0;
                i_hi = i_ho * Sy - Py;
                i_wi = i_wo * Sx - Px;
            }
            else
            {
                i_c  = idx_k % C;
                i_x  = (idx_k / C) % Fx;
                i_y  = (idx_k / C) / Fx;
                i_hi = i_ho * Sy + i_y * Dy - Py;
                i_wi = i_wo * Sx + i_x * Dx - Px;
            }
            src_offset = i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C + i_c;
            i_gemm_k   = idx_k;
            // printf("[%d] i_wo:%d, i_ho:%d, i_wi:%d, i_hi:%d, src_offset:%d\n",
            //        __LINE__, i_wo, i_ho, i_wi, i_hi, src_offset);
        }
    }

    void SetDstSliceOrigin(const DstDesc&, const Index&) {}

    template <typename SrcBuffer, typename DstBuffer>
    void Run(const SrcDesc& src_desc, const SrcBuffer& src_buf, const DstDesc& dst_desc, DstBuffer& dst_buf)
    {
        if constexpr(BypassTransfer)
        {
            float* p_src    = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
            dst_buf.p_data_ = p_src;
        }
        else
        {
            const ck::index_t m_per_block =
                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
            const ck::index_t k_per_block =
                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];

            const float* p_src = reinterpret_cast<const float*>(src_buf.p_data_) + src_offset;
            float* p_dst       = reinterpret_cast<float*>(dst_buf.p_data_);

            // printf("src offset:%d, k_per_block:%d, m_per_block:%d\n", src_offset, k_per_block, m_per_block);

            if constexpr(ConvForwardSpecialization ==
                         ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
            {
                ck::index_t i_m_itr = m_per_block;
                // standard 8-4-2-1 pattern
                while(i_m_itr >= 8)
                {
                    avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 2 * k_per_block, p_src + 2 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 3 * k_per_block, p_src + 3 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 4 * k_per_block, p_src + 4 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 5 * k_per_block, p_src + 5 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 6 * k_per_block, p_src + 6 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 7 * k_per_block, p_src + 7 * C, k_per_block);
                    i_m_itr -= 8; p_dst += 8 * k_per_block; p_src += 8 * C;
                }
                if(i_m_itr & 4)
                {
                    avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 2 * k_per_block, p_src + 2 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 3 * k_per_block, p_src + 3 * C, k_per_block);
                    p_dst += 4 * k_per_block; p_src += 4 * C;
                }
                if(i_m_itr & 2)
                {
                    avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block);
                    p_dst += 2 * k_per_block; p_src += 2 * C;
                }
                if(i_m_itr & 1)
                {
                    avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
                }
            }
            else if constexpr(ConvForwardSpecialization ==
                              ConvolutionForwardSpecialization_t::Filter1x1Pad0)
            {
                ck::index_t i_m_itr  = m_per_block;
                ck::index_t i_wo_itr = i_wo;
                ck::index_t i_ho_itr = i_ho;
                while(i_m_itr > 0)
                {
                    avx2_util::memcpy32_avx2(p_dst, p_src, k_per_block);
                    p_dst += k_per_block;
                    i_wo_itr++;
                    p_src += input_offset_acc_wi;
                    if(i_wo_itr >= Wo)
                    {
                        i_wo_itr = 0;
                        i_ho_itr++;
                        p_src += input_offset_ovf_wi_acc_hi;
                    }
                    if(i_ho_itr >= Ho)
                    {
                        i_ho_itr = 0;
                        // i_n++;
                        p_src += input_offset_ovf_hi_acc_n;
                    }
                    i_m_itr--;
                }
            }
            else
            {
                // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
                // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
                if constexpr(GemmKSpecialization ==
                             ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
                {
                    // c % k_per_block == 0, so every time k_per_block here is the same
                    ck::index_t i_m_itr  = m_per_block;
                    ck::index_t i_wo_itr = i_wo;
                    ck::index_t i_ho_itr = i_ho;
                    ck::index_t i_wi_itr = i_wi;
                    ck::index_t i_hi_itr = i_hi;

                    // printf("[%d] i_m_itr:%d, i_wo_itr:%d, i_ho_itr:%d, i_wi_itr:%d, i_hi_itr:%d,
                    //        src_offset:%d, input_offset_acc_wi:%d, input_offset_ovf_wi_acc_hi:%d,
                    //        input_offset_ovf_hi_acc_n:%d, %p(%p)\n", ...);

                    while(i_m_itr > 0)
                    {
                        // printf("[%d] i_m_itr:%d, i_wo_itr:%d, i_ho_itr:%d, i_wi_itr:%d,
                        //        i_hi_itr:%d, src_offset:%d -> %p\n", ...);
                        if((*reinterpret_cast<uint32_t*>(&i_hi_itr) < Hi) &&
                           (*reinterpret_cast<uint32_t*>(&i_wi_itr) < Wi))
                            avx2_util::memcpy32_avx2(p_dst, p_src, k_per_block);
                        else
                            avx2_util::memset32_avx2(p_dst, 0, k_per_block);
                        p_dst += k_per_block;
                        i_wo_itr++;
                        i_wi_itr += Sx;
                        p_src += input_offset_acc_wi;
                        if(i_wo_itr >= Wo)
                        {
                            i_wo_itr = 0;
                            i_wi_itr -= Wo * Sx;
                            i_ho_itr++;
                            i_hi_itr += Sy;
                            p_src += input_offset_ovf_wi_acc_hi;
                        }
                        if(i_ho_itr >= Ho)
                        {
                            i_ho_itr = 0;
                            i_hi_itr -= Ho * Sy;
                            // i_n++;
                            p_src += input_offset_ovf_hi_acc_n;
                        }
                        i_m_itr--;
                    }
                    // printf("[%d] \n", __LINE__);
                }
                else
                {
                    ck::index_t i_m_itr  = m_per_block;
                    ck::index_t i_wo_itr = i_wo;
                    ck::index_t i_ho_itr = i_ho;
                    ck::index_t i_wi_itr = i_wi;
                    ck::index_t i_hi_itr = i_hi;
                    // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
                    // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
                    while(i_m_itr > 0)
                    {
                        /*** go along Gemm K ***/
                        const float* p_src_k = p_src;
                        float* p_dst_k       = p_dst;

                        ck::index_t i_wi_itr_k = i_wi_itr;
                        ck::index_t i_hi_itr_k = i_hi_itr;
                        ck::index_t i_c_itr_k  = i_c;
                        ck::index_t i_y_itr_k  = i_y;
                        ck::index_t i_x_itr_k  = i_x;
                        ck::index_t i_k_itr    = k_per_block;

                        while(i_k_itr > 0)
                        {
                            ck::index_t current_k_block = ck::math::min(C - i_c_itr_k, k_per_block);
                            if((*reinterpret_cast<uint32_t*>(&i_hi_itr_k) < Hi) &&
                               (*reinterpret_cast<uint32_t*>(&i_wi_itr_k) < Wi))
                                avx2_util::memcpy32_avx2(p_dst_k, p_src_k, current_k_block);
                            else
                                avx2_util::memset32_avx2(p_dst_k, 0, current_k_block);

                            p_dst_k += current_k_block;
                            p_src_k += current_k_block;
                            i_c_itr_k += current_k_block;
                            if(i_c_itr_k >= C)
                            {
                                i_c_itr_k = 0;
                                i_x_itr_k++;
                                i_wi_itr_k += Dx;
                                p_src_k += input_offset_ovf_c_acc_x;
                            }
                            if(i_x_itr_k >= Fx)
                            {
                                i_x_itr_k = 0;
                                i_y_itr_k++;
                                i_hi_itr_k += Dy;
                                p_src_k += input_offset_ovf_x_acc_y;
                            }
                            i_k_itr -= current_k_block;
                        }
                        /*** go along Gemm K ***/

                        p_dst += k_per_block;
                        i_wo_itr++;
                        i_wi_itr += Sx;
                        p_src += input_offset_acc_wi;
                        if(i_wo_itr >= Wo)
                        {
                            i_wo_itr = 0;
                            i_wi_itr -= Wo * Sx;
                            i_ho_itr++;
                            i_hi_itr += Sy;
                            p_src += input_offset_ovf_wi_acc_hi;
                        }
                        if(i_ho_itr >= Ho)
                        {
                            i_ho_itr = 0;
                            i_hi_itr -= Ho * Sy;
                            // i_n++;
                            p_src += input_offset_ovf_hi_acc_n;
                        }
                        i_m_itr--;
                    }
                }
            }
        }
    }

    void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx)
    {
        ck::index_t move_k = src_slice_origin_step_idx[Number<1>{}];

        if constexpr(ConvForwardSpecialization ==
                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            // printf(" => move_k:%d, src offset:%d\n", move_k, src_offset);
            i_c += move_k;
            src_offset += move_k;
        }
        else if constexpr(ConvForwardSpecialization ==
                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
        {
            i_c += move_k;
            src_offset += move_k;
        }
        else
        {
            if constexpr(GemmKSpecialization ==
                         ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
            {
                // c % k_per_block == 0, so every time k_per_block here is the same
                // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
                // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
                // printf("222222 C:%d, src_offset:%d, i_c:%d, i_x:%d\n", C, src_offset, i_c, i_x); fflush(stdout);
                // TODO: branch seems weird
                i_c += move_k;
                src_offset += move_k;
                // printf("3333[%d] src_offset:%d\n", __LINE__, src_offset);
                if(i_c >= C)
                {
                    i_c = 0;
                    i_x++;
                    i_wi += Dx;
                    src_offset += Dx * C - C;
                    // printf("3333[%d] src_offset:%d\n", __LINE__, src_offset);
                }
                if(i_x >= Fx)
                {
                    i_x = 0;
                    i_y++;
                    i_wi = i_wi - Fx * Dx;
                    i_hi += Dy;
                    src_offset += Dy * Wi * C - Fx * Dx * C;
                    // printf("3333[%d] src_offset:%d\n", __LINE__, src_offset);
                }
                // printf("inp move:%d, i_c:%d, i_hi:%d, i_wi:%d src_offset:%d\n", move_k, i_c, i_hi, i_wi, src_offset); fflush(stdout);
            }
            else
            {
                i_gemm_k += move_k;
                i_c  = i_gemm_k % C;
                i_x  = (i_gemm_k / C) % Fx;
                i_y  = (i_gemm_k / C) / Fx;
                i_hi = i_ho * Sy + i_y * Dy - Py;
                i_wi = i_wo * Sx + i_x * Dx - Px;
                src_offset = i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C + i_c;
            }
        }
    }

    void MoveDstSliceWindow(const DstDesc&, const Index&) {}

    private:
    const ElementwiseOperation element_op_;

    ck::index_t i_n, i_c, i_hi, i_wi, i_ho, i_wo, i_y, i_x, i_gemm_k;

    ck::index_t N;
    // ck::index_t K;
    ck::index_t C, Hi, Wi, Ho, Wo, Sy, Sx, Dy, Dx, Py, Px, Fy, Fx;

    intptr_t input_offset_acc_wi;
    intptr_t input_offset_ovf_wi_acc_hi;
    intptr_t input_offset_ovf_hi_acc_n;
    // intptr_t input_offset_acc_c;
    intptr_t input_offset_ovf_c_acc_x;
    intptr_t input_offset_ovf_x_acc_y;

    intptr_t src_offset; // keep this as pointer type in case we have negative offset
};

template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename ElementwiseOperation,
          bool BypassTransfer,
          ConvolutionForwardSpecialization_t ConvForwardSpecialization,
          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
{
    static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
    using Index = MultiIndex<nDim>;
    // using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
    // using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));

    constexpr ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC(
        const SrcDesc& src_desc,
        const Index& src_slice_origin,
        const DstDesc& dst_desc,
        const Index& dst_slice_origin,
        const ElementwiseOperation& element_op)
        : element_op_(element_op)
    {
        GemmN1 = src_desc.GetTransforms()[Number<3>{}].GetUpperLengths()[Number<1>{}];
        GemmN  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
        GemmK  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
    }

    void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
    {
        ck::index_t idx_n0 = src_slice_origin_idx[Number<0>{}];
        ck::index_t idx_k  = src_slice_origin_idx[Number<1>{}];
        ck::index_t idx_n1 = src_slice_origin_idx[Number<2>{}];

        i_gemm_n = idx_n0 * GemmN1 + idx_n1;
        // i_gemm_k = idx_k;
        src_offset = idx_n0 * GemmK * GemmN1 + idx_k + idx_n1 * GemmN1; // Note we transpose here
        // printf("xxxx i_gemm_n:%d, i_gemm_k:%d, src_offset:%d\n", i_gemm_n, i_gemm_k, src_offset);
    }

    void SetDstSliceOrigin(const DstDesc&, const Index&) {}

    template <typename SrcBuffer, typename DstBuffer>
    void Run(const SrcDesc&, const SrcBuffer& src_buf, const DstDesc& dst_desc, DstBuffer& dst_buf)
    {
        if constexpr(BypassTransfer)
        {
            // TODO: weight NHWC not support this
        }
        else
        {
            const ck::index_t n_per_block =
                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}] *
                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
            const ck::index_t k_per_block =
                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];

            // printf(" >>>> %d, %d, %d -> %d(%dx%d), %d\n", GemmN, GemmK, GemmN1, n_per_block,
            //        dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}],
            //        dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}],
            //        k_per_block);

            const float* p_src = reinterpret_cast<const float*>(src_buf.p_data_) + src_offset;
            float* p_dst       = reinterpret_cast<float*>(dst_buf.p_data_);

            // n * k -> n0 * k * n1, n1 = 8, n0 = n/8
            for(index_t i_n_itr = 0; i_n_itr < n_per_block; i_n_itr += 8)
            {
                ck::index_t current_n_8 = ck::math::min(GemmN - (i_n_itr + i_gemm_n), 8);
                ck::index_t i_k_itr     = k_per_block;
                if(current_n_8 == 8)
                {
                    const float* p_src_k = p_src;
                    float* p_dst_k       = p_dst;
                    while(i_k_itr >= 8)
                    {
                        avx2_util::transpose8x8_avx2(p_dst_k, 8, p_src_k, GemmK);
                        p_dst_k += 8 * 8; p_src_k += 8; i_k_itr -= 8;
                    }
                    if(i_k_itr & 4)
                    {
                        p_dst_k[0 * 8 + 0] = p_src_k[0 * GemmK + 0]; p_dst_k[0 * 8 + 1] = p_src_k[1 * GemmK + 0];
                        p_dst_k[0 * 8 + 2] = p_src_k[2 * GemmK + 0]; p_dst_k[0 * 8 + 3] = p_src_k[3 * GemmK + 0];
                        p_dst_k[0 * 8 + 4] = p_src_k[4 * GemmK + 0]; p_dst_k[0 * 8 + 5] = p_src_k[5 * GemmK + 0];
                        p_dst_k[0 * 8 + 6] = p_src_k[6 * GemmK + 0]; p_dst_k[0 * 8 + 7] = p_src_k[7 * GemmK + 0];
                        p_dst_k[1 * 8 + 0] = p_src_k[0 * GemmK + 1]; p_dst_k[1 * 8 + 1] = p_src_k[1 * GemmK + 1];
                        p_dst_k[1 * 8 + 2] = p_src_k[2 * GemmK + 1]; p_dst_k[1 * 8 + 3] = p_src_k[3 * GemmK + 1];
                        p_dst_k[1 * 8 + 4] = p_src_k[4 * GemmK + 1]; p_dst_k[1 * 8 + 5] = p_src_k[5 * GemmK + 1];
                        p_dst_k[1 * 8 + 6] = p_src_k[6 * GemmK + 1]; p_dst_k[1 * 8 + 7] = p_src_k[7 * GemmK + 1];
                        p_dst_k[2 * 8 + 0] = p_src_k[0 * GemmK + 2]; p_dst_k[2 * 8 + 1] = p_src_k[1 * GemmK + 2];
                        p_dst_k[2 * 8 + 2] = p_src_k[2 * GemmK + 2]; p_dst_k[2 * 8 + 3] = p_src_k[3 * GemmK + 2];
                        p_dst_k[2 * 8 + 4] = p_src_k[4 * GemmK + 2]; p_dst_k[2 * 8 + 5] = p_src_k[5 * GemmK + 2];
                        p_dst_k[2 * 8 + 6] = p_src_k[6 * GemmK + 2]; p_dst_k[2 * 8 + 7] = p_src_k[7 * GemmK + 2];
                        p_dst_k[3 * 8 + 0] = p_src_k[0 * GemmK + 3]; p_dst_k[3 * 8 + 1] = p_src_k[1 * GemmK + 3];
                        p_dst_k[3 * 8 + 2] = p_src_k[2 * GemmK + 3]; p_dst_k[3 * 8 + 3] = p_src_k[3 * GemmK + 3];
                        p_dst_k[3 * 8 + 4] = p_src_k[4 * GemmK + 3]; p_dst_k[3 * 8 + 5] = p_src_k[5 * GemmK + 3];
                        p_dst_k[3 * 8 + 6] = p_src_k[6 * GemmK + 3]; p_dst_k[3 * 8 + 7] = p_src_k[7 * GemmK + 3];
                        p_dst_k += 4 * 8; p_src_k += 4;
                    }
                    if(i_k_itr & 2)
                    {
                        p_dst_k[0 * 8 + 0] = p_src_k[0 * GemmK + 0]; p_dst_k[0 * 8 + 1] = p_src_k[1 * GemmK + 0];
                        p_dst_k[0 * 8 + 2] = p_src_k[2 * GemmK + 0]; p_dst_k[0 * 8 + 3] = p_src_k[3 * GemmK + 0];
                        p_dst_k[0 * 8 + 4] = p_src_k[4 * GemmK + 0]; p_dst_k[0 * 8 + 5] = p_src_k[5 * GemmK + 0];
                        p_dst_k[0 * 8 + 6] = p_src_k[6 * GemmK + 0]; p_dst_k[0 * 8 + 7] = p_src_k[7 * GemmK + 0];
                        p_dst_k[1 * 8 + 0] = p_src_k[0 * GemmK + 1]; p_dst_k[1 * 8 + 1] = p_src_k[1 * GemmK + 1];
                        p_dst_k[1 * 8 + 2] = p_src_k[2 * GemmK + 1]; p_dst_k[1 * 8 + 3] = p_src_k[3 * GemmK + 1];
                        p_dst_k[1 * 8 + 4] = p_src_k[4 * GemmK + 1]; p_dst_k[1 * 8 + 5] = p_src_k[5 * GemmK + 1];
                        p_dst_k[1 * 8 + 6] = p_src_k[6 * GemmK + 1]; p_dst_k[1 * 8 + 7] = p_src_k[7 * GemmK + 1];
                        p_dst_k += 2 * 8; p_src_k += 2;
                    }
                    if(i_k_itr & 1)
                    {
                        p_dst_k[0 * 8 + 0] = p_src_k[0 * GemmK + 0]; p_dst_k[0 * 8 + 1] = p_src_k[1 * GemmK + 0];
                        p_dst_k[0 * 8 + 2] = p_src_k[2 * GemmK + 0]; p_dst_k[0 * 8 + 3] = p_src_k[3 * GemmK + 0];
                        p_dst_k[0 * 8 + 4] = p_src_k[4 * GemmK + 0]; p_dst_k[0 * 8 + 5] = p_src_k[5 * GemmK + 0];
                        p_dst_k[0 * 8 + 6] = p_src_k[6 * GemmK + 0]; p_dst_k[0 * 8 + 7] = p_src_k[7 * GemmK + 0];
                    }
                }
                else
                {
                    const float* p_src_k = p_src;
                    float* p_dst_k       = p_dst;
                    for(index_t i_sub_n = 0; i_sub_n < 8; i_sub_n++)
                    {
                        for(index_t i_sub_k = 0; i_sub_k < k_per_block; i_sub_k++)
                        {
                            ck::index_t i_current_n_itr = i_n_itr + i_sub_n + i_gemm_n;
                            float v = i_current_n_itr < GemmN ? p_src_k[i_sub_n * GemmK + i_sub_k] : .0f;
                            p_dst_k[i_sub_k * 8 + i_sub_n] = v;
                        }
                    }
                }
                p_dst += 8 * k_per_block;
                p_src += 8 * GemmK;
            }
        }
    }

    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
    void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx)
    {
        ck::index_t move_k  = src_slice_origin_step_idx[Number<1>{}];
        ck::index_t move_n0 = src_slice_origin_step_idx[Number<0>{}];
        // i_gemm_k += move_k;
        // printf("wei move:%d\n", move_k); fflush(stdout);
        src_offset += move_k + move_n0 * GemmK * GemmN1;
    }

    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
    void MoveDstSliceWindow(const DstDesc&, const Index&) {}

    private:
    const ElementwiseOperation element_op_;

    ck::index_t i_gemm_n;
    // ck::index_t i_gemm_k;
    // ck::index_t GemmN0;
    ck::index_t GemmN1;
    ck::index_t GemmN;
    ck::index_t GemmK;
    intptr_t src_offset;
};

template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename ElementwiseOperation,
          bool BypassTransfer,
          ConvolutionForwardSpecialization_t ConvForwardSpecialization,
          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
{
    static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
    using Index = MultiIndex<nDim>;

    constexpr ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN(
        const SrcDesc& src_desc,
        const Index&,
        const DstDesc& dst_desc,
        const Index&,
        const ElementwiseOperation& element_op)
        : element_op_(element_op)
    {
        DstGemmM = dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
        DstGemmN = dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
        src_offset = 0;
        dst_offset = 0;
    }

    void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
    {
        if constexpr(BypassTransfer)
        {
            auto i_src_gemm_m = src_slice_origin_idx[Number<0>{}];
            auto i_src_gemm_n = src_slice_origin_idx[Number<1>{}];
            src_offset = i_src_gemm_m * DstGemmN + i_src_gemm_n;
        }
    }

    void SetDstSliceOrigin(const DstDesc&, const Index& dst_slice_origin_idx)
    {
        i_dst_gemm_m = dst_slice_origin_idx[Number<0>{}];
        i_dst_gemm_n = dst_slice_origin_idx[Number<1>{}];
        dst_offset   = i_dst_gemm_m * DstGemmN + i_dst_gemm_n;
    }

    template <typename SrcBuffer, typename DstBuffer>
    void Run(const SrcDesc& src_desc, SrcBuffer& src_buf, const DstDesc& dst_desc, DstBuffer& dst_buf)
    {
        if constexpr(BypassTransfer)
        {
            src_buf.p_data_ = reinterpret_cast<float*>(dst_buf.p_data_) + src_offset;
        }
        else
        {
            const ck::index_t m_per_block =
                src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
            // must be multiple of 8
            const ck::index_t n_per_block =
                src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
            const ck::index_t current_n = ck::math::min(DstGemmN - i_dst_gemm_n, n_per_block);

            const float* p_src = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
            float* p_dst       = reinterpret_cast<float*>(dst_buf.p_data_) + dst_offset;

            ck::index_t i_m_itr = m_per_block;

            // printf("xxxx %d, current_n:%d, DstGemmN:%d, n_per_block:%d\n",__LINE__, current_n,
            //        DstGemmN, n_per_block); fflush(stdout);

            // standard 8-4-2-1 pattern
            while(i_m_itr >= 8)
            {
                avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 2 * DstGemmN, p_src + 2 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 3 * DstGemmN, p_src + 3 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 4 * DstGemmN, p_src + 4 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 5 * DstGemmN, p_src + 5 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 6 * DstGemmN, p_src + 6 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 7 * DstGemmN, p_src + 7 * n_per_block, current_n);
                i_m_itr -= 8; p_dst += 8 * DstGemmN; p_src += 8 * n_per_block;
            }
            if(i_m_itr & 4)
            {
                avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 2 * DstGemmN, p_src + 2 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 3 * DstGemmN, p_src + 3 * n_per_block, current_n);
                p_dst += 4 * DstGemmN; p_src += 4 * n_per_block;
            }
            if(i_m_itr & 2)
            {
                avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n);
                p_dst += 2 * DstGemmN; p_src += 2 * n_per_block;
            }
            if(i_m_itr & 1)
            {
                avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
            }
            // printf("xxxx %d\n",__LINE__); fflush(stdout);
        }
    }

    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
    void MoveSrcSliceWindow(const SrcDesc&, const Index&) {}

    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
    void MoveDstSliceWindow(const DstDesc&, const Index&) {}

    private:
    const ElementwiseOperation element_op_;

    ck::index_t i_dst_gemm_m;
    ck::index_t i_dst_gemm_n;
    ck::index_t DstGemmM;
    ck::index_t DstGemmN;
    intptr_t src_offset;
    intptr_t dst_offset;
};

} // namespace cpu
} // namespace ck
#endif
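The weight transfer above repacks an n x k tile into n0 x k x n1 with n1 = 8, using transpose8x8_avx2 for full 8-column groups and zero-padding columns that run past GemmN. A scalar reference model of that layout can be handy when validating the AVX2 path; everything below is a standalone sketch whose names are not from the repository, and it assumes n_per_block is a multiple of 8 as the fast path does.

// Sketch: scalar model of the n*k -> n0*k*n1 (n1 = 8) weight repacking,
// including zero padding when the tile runs past the real GemmN.
#include <vector>

std::vector<float> pack_b_n8(const float* src, int gemm_n, int gemm_k,
                             int n_start, int n_per_block, int k_start, int k_per_block)
{
    std::vector<float> dst(static_cast<size_t>(n_per_block) * k_per_block, 0.0f);
    for(int n0 = 0; n0 < n_per_block / 8; ++n0)
        for(int k = 0; k < k_per_block; ++k)
            for(int n1 = 0; n1 < 8; ++n1)
            {
                int n   = n_start + n0 * 8 + n1;           // column in the full weight matrix
                float v = (n < gemm_n) ? src[n * gemm_k + (k_start + k)] : 0.0f;
                dst[(n0 * k_per_block + k) * 8 + n1] = v;  // n0-major, then k, then n1
            }
    return dst;
}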
library/include/ck/library/host_tensor/device.hpp  View file @ afc7d431
...
@@ -121,7 +121,11 @@ template <typename... Args, typename F>
 float launch_and_time_cpu_kernel(F kernel, int nrepeat, Args... args)
 {
     WallTimer timer;

-    kernel(args...);
+    int nwarmup = 3;
+    for(int i = 0; i < nwarmup; i++)
+        kernel(args...);

     timer.Start();
     for(int i = 0; i < nrepeat; i++)
...
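The timing helper now runs a few untimed warm-up calls before starting the wall clock, so first-touch page faults and cold caches do not inflate the reported time. A minimal version of the same pattern with std::chrono (a sketch, not the repository's WallTimer):

// Sketch: warm up, then report mean milliseconds per invocation.
#include <chrono>

template <typename F>
double time_cpu_kernel_ms(F&& kernel, int nrepeat, int nwarmup = 3)
{
    for(int i = 0; i < nwarmup; ++i)
        kernel(); // untimed warm-up
    auto t0 = std::chrono::steady_clock::now();
    for(int i = 0; i < nrepeat; ++i)
        kernel();
    auto t1 = std::chrono::steady_clock::now();
    return std::chrono::duration<double, std::milli>(t1 - t0).count() / nrepeat;
}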
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
View file @
afc7d431
...
@@ -19,7 +19,7 @@ using InLayout = ck::tensor_layout::gemm::RowMajor; /
 using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXC
 static constexpr bool NonTemporalStore = false;

-using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
+using PT = ck::tensor_operation::cpu::element_wise::PassThrough;

 using ThreadwiseGemmAvx2_MxN_4x24_Dispatch =
     ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InType,
                                                   WeiType,
...
@@ -37,53 +37,37 @@ static constexpr auto ConvFwd1x1P0 =
 static constexpr auto ConvFwd1x1S1P0 =
     ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;

+static constexpr auto DefaultGemmKLoop =
+    ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;
+static constexpr auto GemmKLoopOverC =
+    ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;
+
+static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
+static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
+
+// clang-format off
+#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf) \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, GemmKLoopOverC, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, GemmKLoopOverC, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>
+// clang-format on
+
+using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple<
+    // clang-format off
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 120,  64, 4, 24, true, true, false),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 144, 128, 4, 24, true, true, false),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, false),
+    // DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 192, 128, 4, 24, true, true, false),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 288, 128, 4, 24, true, true, false)
+    >;
+    // clang-format on

-using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple<
-    //#################################################################|InDataType|WeiDataType|OutDataType|AccDataType|InElementwiseOp|WeiElementwiseOp|OutElementwiseOp|ConvForwardSp|NumDimSpatial|MPerBlock|NPerBlock|KPerBlock|ThreadwiseGemm_Dispatch
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, float, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2,  256, 128,  64, ThreadwiseGemmAvx2_MxN_4x24_Dispatch>,
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, float, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2,  512, 256, 128, ThreadwiseGemmAvx2_MxN_4x24_Dispatch>,
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, float, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 1024, 144, 128, ThreadwiseGemmAvx2_MxN_4x24_Dispatch>>;

 void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(
-    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances)
+    std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
 {
     ck::tensor_operation::device::add_device_operation_instances(
         instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances{});
...
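The macro above enumerates, for one tile configuration, every combination of convolution specialization (ConvFwdDefault, ConvFwd1x1P0), gemm-K loop specialization (GemmKLoopOverC, DefaultGemmKLoop) and loop order (LoopOver_MNK, LoopOver_MKN), and the resulting tuple of concrete types is handed to add_device_operation_instances. Below is a simplified sketch of that registration idea, not the actual CK helper; OpBase and add_instances_sketch are made up for illustration.

#include <memory>
#include <tuple>
#include <vector>

struct OpBase
{
    virtual ~OpBase() = default;
};

// For every element type of the tuple, default-construct one object and store
// it in the vector through a base-class pointer.
template <typename Base, typename... Ops>
void add_instances_sketch(std::vector<std::unique_ptr<Base>>& instances, std::tuple<Ops...>)
{
    (instances.push_back(std::make_unique<Ops>()), ...);
}

struct OpA : OpBase {};
struct OpB : OpBase {};

int main()
{
    std::vector<std::unique_ptr<OpBase>> instances;
    add_instances_sketch(instances, std::tuple<OpA, OpB>{});
    return instances.size() == 2 ? 0 : 1;
}

The real registration works against DeviceConvFwdPtr rather than a unique_ptr of a toy base class, but the mechanism of turning a type list into a runtime list of dispatchable instances is the same.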
test/convnd_fwd_cpu/conv2d_fwd_cpu.cpp
View file @
afc7d431
#include "config.hpp"
#include "config.hpp"
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_tensor_generator.hpp"
#include "tensor_layout.hpp"
#include "tensor_layout.hpp"
#include "device_tensor.hpp"
#include "device_tensor.hpp"
#include "device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp"
#include "device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "element_wise_operation_cpu.hpp"
#include "reference_conv_fwd.hpp"
#include "reference_conv_fwd.hpp"
#include "element_wise_operation_cpu.hpp"
#include "element_wise_operation_cpu.hpp"
#include "dynamic_buffer_cpu.hpp"
#include "dynamic_buffer_cpu.hpp"
#include <omp.h>
#define AVX2_DATA_ALIGNMENT 32
#define AVX2_DATA_ALIGNMENT 32
namespace
ck
{
using
F32
=
float
;
namespace
tensor_operation
{
using
F16
=
ck
::
half_t
;
namespace
cpu
{
namespace
device
{
namespace
ck
{
namespace
device_conv2d_fwd_avx2_instance
{
namespace
tensor_operation
{
namespace
cpu
{
void
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk
(
namespace
device
{
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
);
namespace
device_conv2d_fwd_avx2_instance
{
}
// namespace device_conv2d_fwd_avx2_instance
using
PassThrough
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
}
// namespace device
}
// namespace cpu
void
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk
(
}
// namespace tensor_operation
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
);
}
// namespace ck
}
// namespace device_conv2d_fwd_avx2_instance
using
InElementOp
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
}
// namespace device
using
WeiElementOp
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
}
// namespace cpu
using
OutElementOp
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
}
// namespace tensor_operation
}
// namespace ck
template
<
typename
T
>
static
bool
check_out
(
const
Tensor
<
T
>&
ref
,
const
Tensor
<
T
>&
result
)
using
InElementOp
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
{
using
WeiElementOp
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
float
max_diff
=
1e-6
;
using
OutElementOp
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
for
(
int
i
=
0
;
i
<
ref
.
mData
.
size
();
++
i
)
template
<
typename
T
>
{
static
bool
check_out
(
const
Tensor
<
T
>&
ref
,
const
Tensor
<
T
>&
result
)
float
diff
=
std
::
abs
(
double
(
ref
.
mData
[
i
])
-
double
(
result
.
mData
[
i
]));
{
if
(
max_diff
<
diff
)
int
error_count
=
0
;
{
float
max_diff
=
1e-6
;
return
false
;
}
for
(
int
i
=
0
;
i
<
ref
.
mData
.
size
();
++
i
)
}
{
float
diff
=
std
::
abs
(
double
(
ref
.
mData
[
i
])
-
double
(
result
.
mData
[
i
]));
return
true
;
if
(
max_diff
<
diff
)
}
{
error_count
++
;
int
main
(
int
argc
,
char
*
argv
[])
printf
(
"idx:%3d, ref:%f, res:%f (diff:%f)
\n
"
,
{
i
,
int
data_type
=
0
;
double
(
ref
.
mData
[
i
]),
int
init_method
=
0
;
double
(
result
.
mData
[
i
]),
diff
);
// Conv shape
}
ck
::
index_t
N
=
128
;
}
ck
::
index_t
K
=
256
;
ck
::
index_t
C
=
192
;
return
error_count
==
0
;
ck
::
index_t
Y
=
3
;
}
ck
::
index_t
X
=
3
;
ck
::
index_t
Hi
=
71
;
float
calculate_gflops
()
{}
ck
::
index_t
Wi
=
71
;
ck
::
index_t
conv_stride_h
=
2
;
int
main
(
int
argc
,
char
*
argv
[])
ck
::
index_t
conv_stride_w
=
2
;
{
ck
::
index_t
conv_dilation_h
=
1
;
int
data_type
=
0
;
ck
::
index_t
conv_dilation_w
=
1
;
int
init_method
=
0
;
ck
::
index_t
in_left_pad_h
=
1
;
ck
::
index_t
in_left_pad_w
=
1
;
// Conv shape
ck
::
index_t
in_right_pad_h
=
1
;
ck
::
index_t
N
=
2
;
ck
::
index_t
in_right_pad_w
=
1
;
ck
::
index_t
K
=
256
;
ck
::
index_t
C
=
192
;
if
(
argc
==
1
)
ck
::
index_t
Y
=
3
;
{
ck
::
index_t
X
=
3
;
data_type
=
1
;
ck
::
index_t
Hi
=
71
;
init_method
=
1
;
ck
::
index_t
Wi
=
71
;
}
ck
::
index_t
conv_stride_h
=
1
;
else
if
(
argc
==
3
)
ck
::
index_t
conv_stride_w
=
1
;
{
ck
::
index_t
conv_dilation_h
=
1
;
data_type
=
std
::
stoi
(
argv
[
1
]);
ck
::
index_t
conv_dilation_w
=
1
;
init_method
=
std
::
stoi
(
argv
[
2
]);
ck
::
index_t
in_left_pad_h
=
1
;
}
ck
::
index_t
in_left_pad_w
=
1
;
else
if
(
argc
==
18
)
ck
::
index_t
in_right_pad_h
=
1
;
{
ck
::
index_t
in_right_pad_w
=
1
;
data_type
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
if
(
argc
==
1
)
{
N
=
std
::
stoi
(
argv
[
3
]);
data_type
=
0
;
K
=
std
::
stoi
(
argv
[
4
]);
init_method
=
1
;
C
=
std
::
stoi
(
argv
[
5
]);
}
Y
=
std
::
stoi
(
argv
[
6
]);
else
if
(
argc
==
3
)
X
=
std
::
stoi
(
argv
[
7
]);
{
Hi
=
std
::
stoi
(
argv
[
8
]);
data_type
=
std
::
stoi
(
argv
[
1
]);
Wi
=
std
::
stoi
(
argv
[
9
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
conv_stride_h
=
std
::
stoi
(
argv
[
10
]);
}
conv_stride_w
=
std
::
stoi
(
argv
[
11
]);
else
if
(
argc
==
18
)
conv_dilation_h
=
std
::
stoi
(
argv
[
12
]);
{
conv_dilation_w
=
std
::
stoi
(
argv
[
13
]);
data_type
=
std
::
stoi
(
argv
[
1
]);
in_left_pad_h
=
std
::
stoi
(
argv
[
14
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
in_left_pad_w
=
std
::
stoi
(
argv
[
15
]);
in_right_pad_h
=
std
::
stoi
(
argv
[
16
]);
N
=
std
::
stoi
(
argv
[
3
]);
in_right_pad_w
=
std
::
stoi
(
argv
[
17
]);
K
=
std
::
stoi
(
argv
[
4
]);
}
C
=
std
::
stoi
(
argv
[
5
]);
else
Y
=
std
::
stoi
(
argv
[
6
]);
{
X
=
std
::
stoi
(
argv
[
7
]);
printf
(
"arg1: data type (0=fp32, 1=fp16)
\n
"
);
Hi
=
std
::
stoi
(
argv
[
8
]);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
Wi
=
std
::
stoi
(
argv
[
9
]);
printf
(
"arg3 to 17: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
conv_stride_h
=
std
::
stoi
(
argv
[
10
]);
"RightPx
\n
"
);
conv_stride_w
=
std
::
stoi
(
argv
[
11
]);
exit
(
1
);
conv_dilation_h
=
std
::
stoi
(
argv
[
12
]);
}
conv_dilation_w
=
std
::
stoi
(
argv
[
13
]);
in_left_pad_h
=
std
::
stoi
(
argv
[
14
]);
auto
Run
=
[
&
](
auto
input_type
,
auto
wei_type
,
auto
out_type
,
auto
acc_type
)
{
in_left_pad_w
=
std
::
stoi
(
argv
[
15
]);
using
InDataType
=
decltype
(
input_type
);
in_right_pad_h
=
std
::
stoi
(
argv
[
16
]);
using
WeiDataType
=
decltype
(
wei_type
);
in_right_pad_w
=
std
::
stoi
(
argv
[
17
]);
using
OutDataType
=
decltype
(
out_type
);
}
using
AccDataType
=
decltype
(
acc_type
);
else
{
using
ReferenceConvBwdInstance
=
printf
(
"arg1: data type (0=fp32, 1=fp16)
\n
"
);
ck
::
tensor_operation
::
host
::
ReferenceConvBwdData
<
InDataType
,
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
WeiDataType
,
printf
(
"arg3 to 17: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
OutDataType
,
"RightPx
\n
"
);
AccDataType
,
exit
(
1
);
InElementOp
,
}
WeiElementOp
,
OutElementOp
>
;
auto
Run
=
[
&
](
auto
input_type
,
auto
wei_type
,
auto
out_type
)
{
using
InDataType
=
decltype
(
input_type
);
const
ck
::
index_t
YEff
=
(
Y
-
1
)
*
conv_dilation_h
+
1
;
using
WeiDataType
=
decltype
(
wei_type
);
const
ck
::
index_t
XEff
=
(
X
-
1
)
*
conv_dilation_w
+
1
;
using
OutDataType
=
decltype
(
out_type
);
const
ck
::
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
YEff
)
/
conv_stride_h
+
1
;
using
ReferenceConvFwdInstance
=
ck
::
tensor_operation
::
host
::
ReferenceConvFwd
<
InDataType
,
const
ck
::
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
XEff
)
/
conv_stride_w
+
1
;
WeiDataType
,
OutDataType
,
const
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
{{
Hi
,
Wi
}};
InElementOp
,
const
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
{{
Y
,
X
}};
WeiElementOp
,
const
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
{{
Ho
,
Wo
}};
OutElementOp
>
;
const
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
{{
conv_stride_h
,
conv_stride_w
}};
const
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
{{
conv_dilation_h
,
conv_dilation_w
}};
const
ck
::
index_t
YEff
=
(
Y
-
1
)
*
conv_dilation_h
+
1
;
const
std
::
vector
<
ck
::
index_t
>
input_left_pads
{{
in_left_pad_h
,
in_left_pad_w
}};
const
ck
::
index_t
XEff
=
(
X
-
1
)
*
conv_dilation_w
+
1
;
const
std
::
vector
<
ck
::
index_t
>
input_right_pads
{{
in_right_pad_h
,
in_right_pad_w
}};
const
ck
::
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
YEff
)
/
conv_stride_h
+
1
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
N_
,
const
ck
::
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
XEff
)
/
conv_stride_w
+
1
;
std
::
size_t
C_
,
std
::
size_t
H_
,
const
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
{{
Hi
,
Wi
}};
std
::
size_t
W_
)
{
const
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
{{
Y
,
X
}};
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
N_
,
C_
,
H_
,
W_
}),
const
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
{{
Ho
,
Wo
}};
std
::
vector
<
std
::
size_t
>
({
C_
*
H_
*
W_
,
1
,
W_
*
C_
,
C_
}));
const
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
{{
conv_stride_h
,
conv_stride_w
}};
};
const
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
{{
conv_dilation_h
,
conv_dilation_w
}};
const
std
::
vector
<
ck
::
index_t
>
input_left_pads
{{
in_left_pad_h
,
in_left_pad_w
}};
Tensor
<
OutDataType
>
out_n_ho_wo_k
(
f_host_tensor_descriptor
(
N
,
K
,
Ho
,
Wo
));
const
std
::
vector
<
ck
::
index_t
>
input_right_pads
{{
in_right_pad_h
,
in_right_pad_w
}};
Tensor
<
WeiDataType
>
wei_k_y_x_c
(
f_host_tensor_descriptor
(
K
,
C
,
Y
,
X
));
Tensor
<
InDataType
>
in_n_hi_wi_c_host_result
(
f_host_tensor_descriptor
(
N
,
C
,
Hi
,
Wi
));
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
N_
,
Tensor
<
InDataType
>
in_n_hi_wi_c_device_result
(
f_host_tensor_descriptor
(
N
,
C
,
Hi
,
Wi
));
std
::
size_t
C_
,
std
::
size_t
H_
,
std
::
cout
<<
"in (N, C, Hi, Wi): "
<<
in_n_hi_wi_c_host_result
.
mDesc
<<
std
::
endl
;
std
::
size_t
W_
)
{
std
::
cout
<<
"wei(K, C, Y, X): "
<<
wei_k_y_x_c
.
mDesc
<<
std
::
endl
;
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
N_
,
C_
,
H_
,
W_
}),
std
::
cout
<<
"out(N, K, Ho, Wo): "
<<
out_n_ho_wo_k
.
mDesc
<<
std
::
endl
;
std
::
vector
<
std
::
size_t
>
({
C_
*
H_
*
W_
,
1
,
W_
*
C_
,
C_
}));
};
switch
(
init_method
)
{
Tensor
<
InDataType
>
in_n_c_hi_wi
(
f_host_tensor_descriptor
(
N
,
C
,
Hi
,
Wi
));
case
0
:
break
;
Tensor
<
WeiDataType
>
wei_k_c_y_x
(
f_host_tensor_descriptor
(
K
,
C
,
Y
,
X
));
case
1
:
Tensor
<
OutDataType
>
out_n_k_ho_wo_host_result
(
f_host_tensor_descriptor
(
N
,
K
,
Ho
,
Wo
));
out_n_ho_wo_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
OutDataType
>
{
-
5
,
5
});
Tensor
<
OutDataType
>
out_n_k_ho_wo_device_result
(
f_host_tensor_descriptor
(
N
,
K
,
Ho
,
Wo
));
wei_k_y_x_c
.
GenerateTensorValue
(
GeneratorTensor_2
<
WeiDataType
>
{
-
5
,
5
});
break
;
std
::
cout
<<
"in (N, C, Hi, Wi): "
<<
in_n_c_hi_wi
.
mDesc
<<
std
::
endl
;
case
2
:
std
::
cout
<<
"wei(K, C, Y, X): "
<<
wei_k_c_y_x
.
mDesc
<<
std
::
endl
;
out_n_ho_wo_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
OutDataType
>
{
0.0
,
1.0
});
std
::
cout
<<
"out(N, K, Ho, Wo): "
<<
out_n_k_ho_wo_host_result
.
mDesc
<<
std
::
endl
;
wei_k_y_x_c
.
GenerateTensorValue
(
GeneratorTensor_3
<
WeiDataType
>
{
-
0.5
,
0.5
});
std
::
cout
<<
"LPad(H, W):"
<<
in_left_pad_h
<<
","
<<
in_left_pad_w
break
;
<<
", RPad(H, W):"
<<
in_right_pad_h
<<
","
<<
in_right_pad_w
default:
<<
", Stride(H, W):"
<<
conv_stride_h
<<
", "
<<
conv_stride_w
out_n_ho_wo_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
OutDataType
>
{
1
});
<<
", Dilation(H, W):"
<<
conv_dilation_h
<<
", "
<<
conv_dilation_w
wei_k_y_x_c
.
GenerateTensorValue
(
GeneratorTensor_1
<
WeiDataType
>
{
1
});
<<
", Threads:"
<<
omp_get_max_threads
()
<<
std
::
endl
;
}
switch
(
init_method
)
DeviceAlignedMemCPU
in_device_buf
(
sizeof
(
InDataType
)
*
{
in_n_hi_wi_c_device_result
.
mDesc
.
GetElementSpace
(),
case
0
:
break
;
AVX2_DATA_ALIGNMENT
);
case
1
:
DeviceAlignedMemCPU
wei_device_buf
(
sizeof
(
WeiDataType
)
*
wei_k_y_x_c
.
mDesc
.
GetElementSpace
(),
AVX2_DATA_ALIGNMENT
);
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
5
,
5
});
DeviceAlignedMemCPU
out_device_buf
(
// in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
sizeof
(
OutDataType
)
*
out_n_ho_wo_k
.
mDesc
.
GetElementSpace
(),
AVX2_DATA_ALIGNMENT
);
wei_k_c_y_x
.
GenerateTensorValue
(
GeneratorTensor_2
<
WeiDataType
>
{
-
5
,
5
});
// wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
out_device_buf
.
ToDevice
(
out_n_ho_wo_k
.
mData
.
data
());
break
;
wei_device_buf
.
ToDevice
(
wei_k_y_x_c
.
mData
.
data
());
case
2
:
// reset input to zero
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_1
<
InDataType
>
{});
in_n_hi_wi_c_device_result
.
GenerateTensorValue
(
GeneratorTensor_1
<
InDataType
>
{
0
});
wei_k_c_y_x
.
GenerateTensorValue
(
GeneratorTensor_1
<
WeiDataType
>
{});
in_device_buf
.
ToDevice
(
in_n_hi_wi_c_device_result
.
mData
.
data
());
break
;
case
3
:
// get host result
{
#define PACK_32(v24, v16, v8, v0) \
auto
ref_conv
=
ReferenceConvFwdInstance
{};
(((v24 & 0xff) << 24) | ((v16 & 0xff) << 16) | ((v8 & 0xff) << 8) | ((v0 & 0xff) << 0))
auto
ref_invoker
=
ref_conv
.
MakeInvoker
();
for
(
auto
i_n
=
0
;
i_n
<
N
;
i_n
++
)
auto
ref_argument
=
ref_conv
.
MakeArgument
(
in_n_hi_wi_c_host_result
,
{
wei_k_y_x_c
,
for
(
auto
i_c
=
0
;
i_c
<
C
;
i_c
++
)
out_n_ho_wo_k
,
{
conv_filter_strides
,
for
(
auto
i_hi
=
0
;
i_hi
<
Hi
;
i_hi
++
)
conv_filter_dilations
,
{
input_left_pads
,
for
(
auto
i_wi
=
0
;
i_wi
<
Wi
;
i_wi
++
)
input_right_pads
,
{
InElementOp
{},
uint32_t
v
=
PACK_32
(
i_n
,
i_c
,
i_hi
,
i_wi
);
WeiElementOp
{},
in_n_c_hi_wi
(
i_n
,
i_c
,
i_hi
,
i_wi
)
=
*
reinterpret_cast
<
float
*>
(
&
v
);
OutElementOp
{});
}
ref_invoker
.
Run
(
ref_argument
);
}
}
}
}
using
PassThrough
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
using
DeviceConvFwdNoOpPtr
=
for
(
auto
i_k
=
0
;
i_k
<
K
;
i_k
++
)
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>
;
{
for
(
auto
i_c
=
0
;
i_c
<
C
;
i_c
++
)
// add device Conv instances
{
std
::
vector
<
DeviceConvFwdNoOpPtr
>
conv_ptrs
;
for
(
auto
i_y
=
0
;
i_y
<
Y
;
i_y
++
)
{
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cv_t
<
InDataType
>
,
float
>
&&
for
(
auto
i_x
=
0
;
i_x
<
X
;
i_x
++
)
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
float
>
&&
{
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
float
>
)
uint32_t
v
=
PACK_32
(
i_k
,
i_c
,
i_y
,
i_x
);
{
wei_k_c_y_x
(
i_k
,
i_c
,
i_y
,
i_x
)
=
*
reinterpret_cast
<
float
*>
(
&
v
);
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_avx2_instance
::
}
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk
(
conv_ptrs
);
}
}
}
}
if
(
conv_ptrs
.
size
()
<=
0
)
break
;
{
default:
throw
std
::
runtime_error
(
"wrong! no device Conv instance found"
);
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_3
<
InDataType
>
{
0
,
1
});
}
wei_k_c_y_x
.
GenerateTensorValue
(
GeneratorTensor_3
<
WeiDataType
>
{
-
1
,
1
});
}
// profile device Conv instances
bool
success
=
true
;
DeviceAlignedMemCPU
in_device_buf
(
sizeof
(
InDataType
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
(),
for
(
auto
&
conv_ptr
:
conv_ptrs
)
AVX2_DATA_ALIGNMENT
);
{
DeviceAlignedMemCPU
wei_device_buf
(
auto
argument_ptr
=
conv_ptr
->
MakeArgumentPointer
(
sizeof
(
WeiDataType
)
*
wei_k_c_y_x
.
mDesc
.
GetElementSpace
(),
AVX2_DATA_ALIGNMENT
);
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
DeviceAlignedMemCPU
out_device_buf
(
sizeof
(
OutDataType
)
*
static_cast
<
WeiDataType
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
out_n_k_ho_wo_host_result
.
mDesc
.
GetElementSpace
(),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
AVX2_DATA_ALIGNMENT
);
N
,
K
,
in_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
C
,
wei_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
input_spatial_lengths
,
filter_spatial_lengths
,
// get host result
output_spatial_lengths
,
{
conv_filter_strides
,
auto
ref_conv
=
ReferenceConvFwdInstance
{};
conv_filter_dilations
,
auto
ref_invoker
=
ref_conv
.
MakeInvoker
();
input_left_pads
,
input_right_pads
,
auto
ref_argument
=
ref_conv
.
MakeArgument
(
in_n_c_hi_wi
,
InElementOp
{},
wei_k_c_y_x
,
WeiElementOp
{},
out_n_k_ho_wo_host_result
,
OutElementOp
{});
conv_filter_strides
,
conv_filter_dilations
,
if
(
conv_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
input_left_pads
,
{
input_right_pads
,
auto
invoker_ptr
=
conv_ptr
->
MakeInvokerPointer
();
InElementOp
{},
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
1
);
WeiElementOp
{},
OutElementOp
{});
in_device_buf
.
FromDevice
(
in_n_hi_wi_c_device_result
.
mData
.
data
());
ref_invoker
.
Run
(
ref_argument
);
}
if
(
!
check_out
(
in_n_hi_wi_c_host_result
,
in_n_hi_wi_c_device_result
))
{
using
PassThrough
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
std
::
cout
<<
"Fail Info: "
<<
conv_ptr
->
GetTypeString
()
<<
std
::
endl
;
using
DeviceConvFwdNoOpPtr
=
ck
::
tensor_operation
::
cpu
::
device
::
success
=
false
;
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>
;
}
else
// add device Conv instances
{
std
::
vector
<
DeviceConvFwdNoOpPtr
>
conv_ptrs
;
std
::
cout
<<
"Pass Info: "
<<
conv_ptr
->
GetTypeString
()
<<
std
::
endl
;
}
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cv_t
<
InDataType
>
,
float
>
&&
}
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
float
>
&&
else
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
float
>
)
{
{
std
::
cout
<<
"Not support Info: "
<<
conv_ptr
->
GetTypeString
()
<<
std
::
endl
;
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_avx2_instance
::
}
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk
(
conv_ptrs
);
}
}
if
(
success
)
if
(
conv_ptrs
.
size
()
<=
0
)
{
{
std
::
cout
<<
"test conv2d fwd cpu : Pass"
<<
std
::
endl
;
throw
std
::
runtime_error
(
"wrong! no device Conv instance found"
);
return
0
;
}
}
else
// profile device Conv instances
{
bool
success
=
true
;
std
::
cout
<<
"test conv2d fwd cpu: Fail "
<<
std
::
endl
;
double
fastest_kernel_time
=
std
::
numeric_limits
<
double
>::
max
();
return
-
1
;
std
::
string
fastest_kernel_name
=
""
;
}
double
fastest_kernel_gflops
=
0
;
};
for
(
auto
&
conv_ptr
:
conv_ptrs
)
{
if
(
data_type
==
0
)
auto
argument_ptr
=
conv_ptr
->
MakeArgumentPointer
(
{
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
return
Run
(
F32
(),
F32
(),
F32
(),
F32
());
static_cast
<
WeiDataType
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
}
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
else
if
(
data_type
==
1
)
N
,
{
K
,
return
Run
(
F16
(),
F16
(),
F16
(),
F32
());
C
,
}
input_spatial_lengths
,
else
filter_spatial_lengths
,
{
output_spatial_lengths
,
return
1
;
conv_filter_strides
,
}
conv_filter_dilations
,
}
input_left_pads
,
input_right_pads
,
InElementOp
{},
WeiElementOp
{},
OutElementOp
{});
if
(
conv_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
auto
invoker_ptr
=
conv_ptr
->
MakeInvokerPointer
();
double
time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
10
);
double
total_flop
=
static_cast
<
double
>
(
2
)
*
N
*
C
*
Ho
*
Wo
*
K
*
Y
*
X
;
double
gflops
=
(
total_flop
*
1e-6
)
/
time
;
out_device_buf
.
FromDevice
(
out_n_k_ho_wo_device_result
.
mData
.
data
());
if
(
!
check_out
(
out_n_k_ho_wo_host_result
,
out_n_k_ho_wo_device_result
))
{
std
::
cout
<<
"Fail Info: "
<<
conv_ptr
->
GetTypeString
()
<<
std
::
endl
;
success
=
false
;
}
else
{
std
::
cout
<<
"Pass Info: "
<<
conv_ptr
->
GetTypeString
()
<<
", Time:"
<<
time
<<
"ms, Gflops:"
<<
gflops
<<
std
::
endl
;
if
(
time
<
fastest_kernel_time
)
{
fastest_kernel_time
=
time
;
fastest_kernel_name
=
conv_ptr
->
GetTypeString
();
fastest_kernel_gflops
=
gflops
;
}
}
}
else
{
std
::
cout
<<
"Not support Info: "
<<
conv_ptr
->
GetTypeString
()
<<
std
::
endl
;
}
}
if
(
fastest_kernel_time
!=
std
::
numeric_limits
<
double
>::
max
())
{
std
::
cout
<<
" fastest:"
<<
fastest_kernel_name
<<
", time:"
<<
fastest_kernel_time
<<
"ms, Gflops:"
<<
fastest_kernel_gflops
<<
std
::
endl
;
}
return
0
;
// if(success)
// {
// std::cout << "test conv2d fwd cpu : Pass" << std::endl;
// return 0;
// }
// else
// {
// std::cout << "test conv2d fwd cpu: Fail " << std::endl;
// return -1;
// }
};
if
(
data_type
==
0
)
{
return
Run
(
F32
(),
F32
(),
F32
());
}
else
{
return
1
;
}
}
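The output-size and throughput arithmetic in the test follows the usual dilated-convolution formula (effective filter length (Y - 1) * dilation + 1, then a strided window count) and reports 2 * N * C * Ho * Wo * K * Y * X flops divided by the measured milliseconds. A small standalone check of that arithmetic is below; conv_out_len and the placeholder timing value are illustrative only.

#include <cstdio>

// Output length along one spatial dimension, as computed in the test above.
static int conv_out_len(int in_len, int filter, int stride, int dilation, int pad_l, int pad_r)
{
    const int eff = (filter - 1) * dilation + 1;
    return (in_len + pad_l + pad_r - eff) / stride + 1;
}

int main()
{
    // Default shape of the test: 71x71 input, 3x3 filter, stride 1, dilation 1, pad 1.
    const int N = 2, K = 256, C = 192, Y = 3, X = 3, Hi = 71, Wi = 71;
    const int Ho = conv_out_len(Hi, Y, 1, 1, 1, 1); // 71
    const int Wo = conv_out_len(Wi, X, 1, 1, 1, 1); // 71

    const double total_flop = 2.0 * N * C * Ho * Wo * K * Y * X;
    const double time_ms    = 10.0;                           // placeholder measurement
    const double gflops     = (total_flop * 1e-6) / time_ms;  // MFLOP / ms == GFLOP / s

    std::printf("Ho=%d Wo=%d, %.1f Gflops at %.1f ms\n", Ho, Wo, gflops, time_ms);
    return 0;
}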
test/cpu_threadwise_transfer/cpu_threadwise_transfer.cpp
View file @
afc7d431
...
@@ -226,6 +226,8 @@ int main(int argc, char** argv)
     static constexpr ck::index_t nDim =
         ck::remove_reference_t<decltype(input_desc)>::GetNumOfDimension();

+    input_desc.Print();
+
     auto threadwise_transfer = threadwise_transfer_t{input_desc,
                                                      ck::make_zero_multi_index<nDim>(),
                                                      input_cblock_desc,
...
test/cpu_ukernel/cpu_gemm_uk.cpp
View file @
afc7d431
...
@@ -313,14 +313,15 @@ void test_ukernel(ukenrel_t uk,
     float* private_c = mat_c + tid * m * n;

     ck::cpu::ThreadwiseGemmParam param;
     param.p_a   = mat_a;
     param.p_b   = mat_b;
     param.p_c   = private_c;
     param.Kr    = k;
     param.lda   = (std::is_same<Row, ALayout>::value ? k : m) * sizeof(FloatA);
     param.ldb   = (std::is_same<Row, BLayout>::value ? n : k * 8) * sizeof(FloatB);
     param.ldc   = n * sizeof(float);
     param.alpha = alpha;
+    param.accmulate_c = 0;

     memset(private_c, 0, m * n * sizeof(float));
...
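The new accmulate_c field (spelling as in the source) is set to 0 here right before the private C buffer is cleared, which reads as "overwrite C rather than accumulate into it"; that interpretation is an assumption, not something this hunk states. A reference-style sketch of an MxN micro-kernel honoring such a flag is below; ref_gemm_uk and its arguments are illustrative and are not the CK ukernel API.

#include <cstddef>

// Illustrative scalar stand-in for a GEMM micro-kernel:
// C[MxN] (+)= alpha * A[MxK] * B[KxN], all row-major, with leading dimensions
// given in elements (the real ukernel takes them in bytes, e.g. param.lda above).
static void ref_gemm_uk(const float* a, const float* b, float* c,
                        std::size_t m, std::size_t n, std::size_t k,
                        std::size_t lda, std::size_t ldb, std::size_t ldc,
                        float alpha, bool accumulate_c)
{
    for(std::size_t im = 0; im < m; ++im)
    {
        for(std::size_t in = 0; in < n; ++in)
        {
            float acc = 0.f;
            for(std::size_t ik = 0; ik < k; ++ik)
                acc += a[im * lda + ik] * b[ik * ldb + in];

            if(accumulate_c)
                c[im * ldc + in] += alpha * acc;
            else
                c[im * ldc + in] = alpha * acc;
        }
    }
}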