Commit afc7d431 (gaoqiong/composable_kernel)
Authored Apr 24, 2022 by carlushuang
avx2 gemm now works for single thread
Parent: 07af8343
Changes: 13 files changed, 2308 additions (+) and 962 deletions (−).
include/ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp  (+160 −85)
include/ck/tensor_operation/cpu/device/convolution_forward_specialization_cpu.hpp  (+13 −0)
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp  (+174 −36)
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp  (+248 −424)
include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp  (+108 −63)
include/ck/tensor_operation/cpu/thread/threadwise_gemm_param.hpp  (+3 −3)
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2.hpp  (+109 −0)
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp  (+1084 −0)
library/include/ck/library/host_tensor/device.hpp  (+5 −1)
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp  (+30 −46)
test/convnd_fwd_cpu/conv2d_fwd_cpu.cpp  (+363 −296)
test/cpu_threadwise_transfer/cpu_threadwise_transfer.cpp  (+2 −0)
test/cpu_ukernel/cpu_gemm_uk.cpp  (+9 −8)
include/ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp  (+160 −85)

@@ -13,21 +13,10 @@ namespace cpu {
The template parameter list of BlockwiseGemmAvx2_MxN is trimmed: AccDataType is added, CBlockDesc is renamed to CDesc, and the ABlockSliceLengths / BBlockSliceLengths / CBlockSliceLengths, AThreadSliceLength / BThreadSliceLength and AThreadLoopOverDim / BThreadLoopOverDim parameters are removed. KPerBlock and the remaining parameters are unchanged.

@@ -47,24 +36,14 @@ struct BlockwiseGemmAvx2_MxN
nDimC and CCoord are now derived from CDesc instead of CBlockDesc, an "#if 0" constructor taking per-thread origins is dropped, and the generic GetLeadingElement(const TensorDesc&) helper is kept.

@@ -84,79 +63,175 @@ struct BlockwiseGemmAvx2_MxN
Layout-aware accessors are added; each reads block extents from the first transform of the descriptor (GetTransforms()[Number<0>{}].GetUpperLengths()):

    GetALeadingElement(a)            : upper length [1]
    GetBLeadingElement(b)            : RowMajor (K * N): [1];  otherwise (N/8 * K * 8): [1] * [2]
    GetCLeadingElement(c)            : [1]
    GetMPerBlock(a)                  : RowMajor (M * K): [0];  otherwise (K * M): [1]
    GetKPerBlock(a)                  : RowMajor (M * K): [1];  otherwise (K * M): [0]
    GetNPerBlock(b)                  : RowMajor (K * N): [1];  otherwise (N/8 * K * 8): [0] * [2]
    GetABlockStartOffset(a, i_m, _)  : RowMajor: i_m * [1];  otherwise: i_m
    GetBBlockStartOffset(b, _, i_n)  : RowMajor: i_n;        otherwise: i_n * [1]
    GetCBlockStartOffset(c, i_m, i_n): i_m * [1] + i_n

Run() is rewritten. The old version built compile-time m_n_block_length / m_n_thread_length from the slice-length template parameters, reordered them with ThreadMNAccessOrder, set param.Kr = KPerBlock, and iterated with static_ford using per-tile tensor coordinates. The new version takes CDesc / CBuffer, ignores the origin indices, accepts a bool is_accumulate_c flag, and walks the block with two runtime loops:

    template <typename ABlockBuffer, typename BBlockBuffer, typename CBuffer>
    void Run(const ABlockDesc& a_block_desc, const ABlockBuffer& a_block_buf, const IndexA& /* a_origin */,
             const BBlockDesc& b_block_desc, const BBlockBuffer& b_block_buf, const IndexB& /* b_origin */,
             const CDesc& c_desc, CBuffer& c_buf, const IndexC& /* c_origin */,
             bool is_accumulate_c = true) const
    {
        auto lda = GetALeadingElement(a_block_desc) * sizeof(FloatA);
        auto ldb = GetBLeadingElement(b_block_desc) * sizeof(FloatB);
        auto ldc = GetCLeadingElement(c_desc) * sizeof(FloatC);

        const auto k_per_block  = GetKPerBlock(a_block_desc);
        const auto m_per_block  = GetMPerBlock(a_block_desc);
        const auto n_per_block  = GetNPerBlock(b_block_desc);
        const auto m_per_thread = ThreadwiseGemm_Dispatch::ThreadMaxMr;
        const auto n_per_thread = ThreadwiseGemm_Dispatch::ThreadMaxNr;

        ck::cpu::ThreadwiseGemmParam param;
        param.Kr          = k_per_block;
        param.lda         = lda;
        param.ldb         = ldb;
        param.ldc         = ldc;
        param.alpha       = 1.0f; // TODO
        param.accmulate_c = is_accumulate_c ? 1 : 0;

        for(ck::index_t i_m = 0; i_m < m_per_block; i_m += m_per_thread)
        {
            auto current_mr = ck::math::min(m_per_block - i_m, m_per_thread);
            param.p_a       = &a_block_buf.p_data_[GetABlockStartOffset(a_block_desc, i_m, 0)];

            for(ck::index_t i_n = 0; i_n < n_per_block; i_n += n_per_thread)
            {
                auto current_nr = ck::math::min(n_per_block - i_n, n_per_thread);
                param.p_b       = &b_block_buf.p_data_[GetBBlockStartOffset(b_block_desc, 0, i_n)];
                param.p_c       = &c_buf.p_data_[GetCBlockStartOffset(c_desc, i_m, i_n)];

                ThreadwiseGemm_Dispatch::Run(&param, current_mr, current_nr);
            }
        }
    }

(Debug printf statements along this path are left commented out.)
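For illustration, a minimal standalone sketch (not part of the commit; the block and tile extents are made-up values) of how the runtime tile loop above covers a block whose size is not a multiple of the micro-kernel tile, clamping the residual tile with min():

    #include <algorithm>
    #include <cstdio>

    int main()
    {
        const int m_per_block = 10, n_per_block = 20;  // assumed block extents
        const int m_per_thread = 4, n_per_thread = 24; // 4x24 micro-kernel tile
        for(int i_m = 0; i_m < m_per_block; i_m += m_per_thread)
            for(int i_n = 0; i_n < n_per_block; i_n += n_per_thread)
            {
                int mr = std::min(m_per_block - i_m, m_per_thread); // rows left: 4, 4, 2
                int nr = std::min(n_per_block - i_n, n_per_thread); // cols left: 20
                std::printf("tile at (%d,%d) -> %dx%d kernel\n", i_m, i_n, mr, nr);
            }
        return 0;
    }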
include/ck/tensor_operation/cpu/device/convolution_forward_specialization_cpu.hpp  (+13 −0)

@@ -14,6 +14,19 @@ enum ConvolutionForwardSpecialization_t
Two new specialization enums are added after OddC:

    enum ConvolutionForwardGemmKSpecialization_t
    {
        DefaultGemmKLoop,
        NHWC_GemmKLoopOverC, // not merge c*y*x, and c % k_per_block == 0
    };

    enum ConvolutionForwardBlockLoopOverSpecialization_t
    {
        DefaultBlockLoopOver,
        LoopOver_MNK,
        LoopOver_MKN,
    };

The closing braces of namespace device / cpu / tensor_operation are unchanged context.
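A minimal sketch (the enum mirrors the one added above; the gating function and its KPerBlock parameter are illustrative, not the commit's device code) of how a compile-time GemmK specialization can reject unsupported channel counts, matching the C % KPerBlock check introduced in the device header later in this commit:

    #include <cstdio>

    enum ConvolutionForwardGemmKSpecialization_t
    {
        DefaultGemmKLoop,
        NHWC_GemmKLoopOverC, // requires C % KPerBlock == 0
    };

    template <ConvolutionForwardGemmKSpecialization_t Spec, int KPerBlock>
    bool is_supported(int conv_c)
    {
        if constexpr(Spec == NHWC_GemmKLoopOverC)
            return conv_c % KPerBlock == 0; // GEMM-K loops directly over C, so C must tile evenly
        else
            return true; // default path merges c*y*x and has no divisibility requirement
    }

    int main()
    {
        std::printf("%d %d\n",
                    is_supported<NHWC_GemmKLoopOverC, 128>(192),  // 0: 192 % 128 != 0
                    is_supported<NHWC_GemmKLoopOverC, 64>(192));  // 1
        return 0;
    }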
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp  (+174 −36)

@@ -13,6 +13,8 @@
Two includes are added after "gridwise_gemm_avx2.hpp":

    #include "threadwise_gemm_avx2.hpp"
    #include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"

@@ -23,20 +25,21 @@ namespace device {
The template parameter list of DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K gains AccDataType, ConvolutionForwardGemmKSpecialization_t GemmKSpecialization and ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization. The trailing `typename ThreadwiseGemm_Dispatch` parameter is replaced by ck::index_t MPerThread, ck::index_t NPerThread, bool UseALocalBuffer, bool UseBLocalBuffer and bool UseCLocalBuffer (the commented-out IsGemmMPadded / IsGemmNPadded / IsGemmKPadded placeholders are removed). MPerBlock keeps its note that "block" means data sized to fit in cache (L1/L2/L3).

@@ -60,18 +63,89 @@
New compile-time helpers are added:

    static constexpr bool NonTemporalStore = false;

    static constexpr auto GetBlockMNKAccessOrder()
    {
        if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
                     BlockLoopOverSpecialization == LoopOver_MNK)
            return ck::Sequence<0, 1, 2>{};
        else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
            return ck::Sequence<0, 2, 1>{};
    }
    using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());

GetThreadwiseGemm_Dispatch() selects ThreadwiseGemmAvx2_MxN_4x24_Dispatch when MPerThread == 4 && NPerThread == 24 and ThreadwiseGemmAvx2_MxN_6x16_Dispatch when MPerThread == 6 && NPerThread == 16 (RowMajor A, ColumnMajor B, NonTemporalStore); the remaining else branch only holds a commented-out static_assert. `using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());` replaces the former template parameter. Block descriptors are also defined here:

    GetInputBlockDescriptor()  : packed (MPerBlock, KPerBlock)
    GetWeightBlockDescriptor() : packed (integer_divide_ceil(NPerBlock, MatrixBMinVectorSize), KPerBlock, MatrixBMinVectorSize)
    GetOutputBlockDescriptor() : packed (MPerBlock, NPerBlock)

GetWeightTensorDescriptor(gemm_k, gemm_n) now pads the N extent: gemm_n_padded = math::integer_least_multiple(gemm_n, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize); the (N, K) descriptor is first transformed with make_right_pad_transform(gemm_n, gemm_n_padded - gemm_n) plus a pass-through on K, and the padded N is then unmerged into (padded_N / MatrixBMinVectorSize, MatrixBMinVectorSize) to form the N0 x K x N1 weight layout. The old code unmerged the raw N without padding.

@@ -409,6 +483,13 @@
    static index_t GetGemmN(ck::index_t K)
    {
        // return ck::math::integer_least_multiple(K, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
        return K;
    }

@@ -423,7 +504,7 @@
    const index_t GemmN = GetGemmN(K);   // was: const index_t GemmN = K;

@@ -474,13 +555,44 @@
The hard-coded `static constexpr bool UseCLocalBuffer = true;` is removed (it is now a template parameter), and the threadwise copies fed to the gridwise GEMM are defined:

    AThreadwiseCopy = ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC<
        InDataType, InDataType, AGridDesc, decltype(GetInputBlockDescriptor()),
        InElementwiseOperation, false, ConvForwardSpecialization, GemmKSpecialization>
    BThreadwiseCopy = ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC<
        WeiDataType, WeiDataType, BGridDesc, decltype(GetWeightBlockDescriptor()),
        WeiElementwiseOperation, false, ConvForwardSpecialization, GemmKSpecialization>
    CThreadwiseCopy = ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN<
        OutDataType, OutDataType, CGridDesc, decltype(GetOutputBlockDescriptor()),
        OutElementwiseOperation, !UseCLocalBuffer, ConvForwardSpecialization, GemmKSpecialization>

@@ -491,8 +603,13 @@
The GridwiseGemmAvx2_MxN instantiation now receives AThreadwiseCopy, BThreadwiseCopy, CThreadwiseCopy and BlockMNKAccessOrder (instead of a hard-coded ck::Sequence<0, 1, 2>), keeps ck::Sequence<0, 1> as ThreadMNAccessOrder, and passes UseALocalBuffer / UseBLocalBuffer / UseCLocalBuffer through.

@@ -580,6 +697,8 @@
Before building the kernel, the output grid buffer is cleared: memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());

@@ -591,7 +710,10 @@
Timing becomes optional: `float ave_time = 0; if(nrepeat != 1) ave_time = launch_and_time_cpu_kernel(kernel, nrepeat, arg.p_a_grid_, arg.p_b_grid_, ...);` replaces the unconditional call.

@@ -605,7 +727,7 @@
Bug fix in the final benchmark pass: the memset of arg.p_c_grid_ now uses arg.c_grid_desc_.GetElementSpaceSize() instead of arg.a_grid_desc_.GetElementSpaceSize() before launch_cpu_kernel.

@@ -659,6 +781,13 @@
IsSupportedArgument additionally rejects unsupported channel counts before the GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_) check:

    if constexpr(GemmKSpecialization == ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
    {
        if(!(arg.Conv_C_ % KPerBlock == 0))
            return false;
    }

@@ -749,15 +878,24 @@
GetTypeString() is rewritten: a string_local_buffer lambda maps a local-buffer flag to "L" or "G", and the name becomes "DeviceConv" + NumDimSpatial + "DFwdAvx2_NHWC_KYXC" + "_FS"/"_KS"/"_BS" (forward, GemmK and block-loop specializations cast to int) + "_BT" MPerBlock x NPerBlock x KPerBlock + "_TT" MPerThread x NPerThread + "_A"/"_B"/"_C" local-buffer flags, replacing the old "...Output_N_Ho_Wo_K<MPerBlock, NPerBlock, KPerBlock>" form.
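A standalone sketch (the helper name mirrors ck's math::integer_least_multiple; the numbers are made up) of the weight-N padding the new GetWeightTensorDescriptor performs: the GEMM-N extent is rounded up to the micro-kernel's B vector width and then split into (N0, N1):

    #include <cstdio>

    int integer_least_multiple(int x, int m) { return ((x + m - 1) / m) * m; }

    int main()
    {
        const int vec = 8;          // MatrixBMinVectorSize: one 8-wide AVX2 float lane
        int gemm_n = 250;           // example output-channel count
        int padded = integer_least_multiple(gemm_n, vec); // 256
        std::printf("n=%d padded=%d -> N0=%d, N1=%d\n", gemm_n, padded, padded / vec, vec);
        return 0;
    }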
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp  (+248 −424)
Diff collapsed in the page view; contents not shown.
include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp  (+108 −63)

@@ -7,7 +7,7 @@
The include is updated for the renamed header: "threadwise_param.hpp" becomes "threadwise_gemm_param.hpp".

@@ -294,6 +294,9 @@ struct ThreadwiseGemmAvx2_MxN_6x16  and  @@ -307,6 +310,7 @@
In the inline-assembly 6x16 kernel, C accumulation becomes conditional. After the row pointers are set up, the new accmulate_c field is tested and the vaddps block is skipped when it is zero:

    "mov 60(%[m_param]), %%edi       \n"   // accmulate_c
    "test %%edi, %%edi               \n"
    "je L_GemmAvx2_MxN_6x16_Store_C%=\n"

and a matching label "L_GemmAvx2_MxN_6x16_Store_C%=:" is inserted before the store section (".if m_NTStore == 0" ... vmovups ...).

@@ -424,6 +428,7 @@  and  @@ -436,6 +441,20 @@
The intrinsics path of the 6x16 kernel gets the same behaviour: the _mm256_loadu_ps loads of the C tile are wrapped in `if(param->accmulate_c) { ... }`, and an `else` branch zeroes every live accumulator (ymm0 ... ymm11, gated by Mr / Nr) with _mm256_xor_ps before the `while(Kr > 4)` main loop.

@@ -532,6 +551,7 @@
One line is added in the NonTemporalStore (_mm256_stream_ps) store path.

@@ -830,19 +850,23 @@ struct ThreadwiseGemmAvx2_MxN_4x24
The 4x24 assembly kernel previously had its C-accumulation vaddps lines commented out. They are now enabled and guarded by the same accmulate_c test ("mov 60(%[m_param]), %%edi" / "test %%edi, %%edi" / "je L_GemmAvx2_MxN_4x24_Store_C%="), with the label "L_GemmAvx2_MxN_4x24_Store_C%=:" added before the store section.

@@ -960,6 +984,7 @@  and  @@ -972,6 +997,20 @@
The 4x24 intrinsics path mirrors the 6x16 change: C loads are wrapped in `if(param->accmulate_c)`, with an `else` branch clearing ymm0 ... ymm11 via _mm256_xor_ps.

@@ -1221,33 +1260,36 @@ struct ThreadwiseGemmAvx2_MxN_6x16_Dispatch
The 6x2 dispatch_table is reordered so rows ascend in mr (1 ... 6) and columns ascend in nr (8, 16): row 0 becomes { ThreadwiseGemm_1x8_t::Run, ThreadwiseGemm_1x16_t::Run } and row 5 becomes { ThreadwiseGemm_6x8_t::Run, ThreadwiseGemm_6x16_t::Run }. Run() additionally computes

    index_t im = mr - 1;
    index_t in = (nr >> 3) - 1;
    assert(im >= 0 && im <= 5 && in >= 0 && in <= 1);

before the table lookup.

@@ -1371,30 +1413,33 @@ struct ThreadwiseGemmAvx2_MxN_4x24_Dispatch
The 4x3 dispatch_table is reordered the same way (row 0: { 1x8, 1x16, 1x24 } ... row 3: { 4x8, 4x16, 4x24 }), and Run() now indexes it with the derived indices:

    index_t im = mr - 1;
    index_t in = (nr >> 3) - 1;
    assert(im >= 0 && im <= 3 && in >= 0 && in <= 2);
    return dispatch_table[im][in](param);
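A reduced sketch (stand-in stub kernels instead of the real micro-kernels; same index math as the reordered 4x24 dispatch above) of how a residual tile size (mr, nr) maps to a table slot:

    #include <cassert>
    #include <cstdio>

    using KernelFn = void (*)(int mr, int nr);

    static void kernel_stub(int mr, int nr) { std::printf("run %dx%d kernel\n", mr, nr); }

    // rows: mr = 1..4; columns: nr = 8, 16, 24 (multiples of one 8-float ymm)
    static KernelFn dispatch_table[4][3] = {
        {kernel_stub, kernel_stub, kernel_stub},
        {kernel_stub, kernel_stub, kernel_stub},
        {kernel_stub, kernel_stub, kernel_stub},
        {kernel_stub, kernel_stub, kernel_stub},
    };

    void run(int mr, int nr)
    {
        int im = mr - 1;        // 1-based row count -> 0-based row index
        int in = (nr >> 3) - 1; // nr is a multiple of 8 -> 0-based column index
        assert(im >= 0 && im <= 3 && in >= 0 && in <= 2);
        dispatch_table[im][in](mr, nr);
    }

    int main() { run(4, 24); run(1, 8); return 0; }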
include/ck/tensor_operation/cpu/thread/threadwise_param.hpp → include/ck/tensor_operation/cpu/thread/threadwise_gemm_param.hpp  (+3 −3)

The header is renamed and its include guard updated from CK_THREADWISE_PARAM_HPP to CK_THREADWISE_GEMM_PARAM_HPP.

@@ -17,7 +17,7 @@ struct ThreadwiseGemmParam
After the byte-stride fields (uint64_t ldb, uint64_t ldc; both "in unit of byte") and float alpha, the padding member `uint32_t _pack0;` is replaced by

    int accmulate_c; // if 1, need load C and add into current fma. if 0, direct store out c result

The struct remains __attribute__((packed)).
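A sketch of why the assembly kernels read this flag at byte offset 60 ("mov 60(%[m_param]), %%edi"). Only ldb, ldc, alpha and accmulate_c are visible in this hunk; the leading fields and their order (p_a, p_b, p_c, Kr, lda) are an assumption inferred from how the blockwise code fills the struct, so treat the exact layout as illustrative:

    #include <cstddef>
    #include <cstdint>

    struct ThreadwiseGemmParamSketch
    {
        const void* p_a;         //  0   (assumed leading fields)
        const void* p_b;         //  8
        void*       p_c;         // 16
        uint64_t    Kr;          // 24
        uint64_t    lda;         // 32
        uint64_t    ldb;         // 40  in unit of byte
        uint64_t    ldc;         // 48  in unit of byte
        float       alpha;       // 56
        int         accmulate_c; // 60  -> "mov 60(%[m_param]), %%edi"
    } __attribute__((packed));

    static_assert(offsetof(ThreadwiseGemmParamSketch, accmulate_c) == 60,
                  "asm kernels expect the flag at byte 60");

    int main() { return 0; }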
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2.hpp  (+109 −0)

@@ -53,6 +53,42 @@ struct ThreadwiseTensorSliceTransferAvx2
After the existing static_assert that SliceLengths along VectorDim divide evenly by ScalarPerVector, a debug block is added that pulls the convolution geometry out of the source descriptor's transforms and prints it: N, Hi, Wi, C from GetTransforms()[Number<0>{}].GetUpperLengths(), Ho / Wo from transform 9's low_lengths_, Fy / Fx from transform 10's low_lengths_, Dy / Sy and Dx / Sx from the coefficients_ of transforms 6 and 7, and Py / Px from the left_pad_length_ of transforms 2 and 3, followed by a printf of all fourteen values. The SetSrcSliceOrigin declaration that follows is unchanged.

@@ -87,6 +123,10 @@
The hidden dimension counts are printed ("src hidden:" SrcDesc::GetNumOfHiddenDimension(), "dst hidden:" DstDesc::GetNumOfHiddenDimension()), and the old static_for<0, num_access, 1> copy loop driven by the space-filling curve is disabled with #if 0.

@@ -148,6 +188,75 @@
A runtime replacement for that loop is added after the #endif. Unit steps along the innermost dimension are built for both sides:

    const auto src_slice_idx_zeros = typename uniform_sequence_gen<nDim, 0>::type{};
    const auto src_slice_step      = make_tensor_coordinate_step(
        src_desc, to_multi_index(src_slice_idx_zeros.Modify(Number<nDim - 1>{}, Number<1>{})));
    // ... and the analogous dst_slice_step from dst_desc

Then, for(auto idx_id = 0; idx_id < num_access; idx_id++), the loop checks coordinate_has_valid_offset_assuming_visible_index_is_valid for the source, loads one ScalarPerVector-wide vector with src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid), applies element_op_ to the whole vector (the per-scalar static_for variant is left commented out), checks validity of the destination, writes with dst_buf.template Update<DstInMemOp, dst_vector_t>(dst_coord_.GetOffset(), is_dst_valid, ...), and, except on the last access, advances both coordinates with move_tensor_coordinate using the unit steps. Debug printf / print_multi_index calls trace each source and destination index. The existing SrcResetCoordinateAfterRun handling that follows is unchanged.
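A much-reduced sketch (plain arrays instead of ck tensor descriptors and coordinates; all names are illustrative) of the shape of that new copy loop: advance one vector per access along the innermost dimension and apply the element-wise operation to whole vectors:

    #include <cstdio>

    int main()
    {
        const int scalar_per_vector = 8;             // one AVX2 float vector per access
        const int num_access = 4;                    // slice length / scalar_per_vector
        float src[32], dst[32];
        for(int i = 0; i < 32; ++i) src[i] = float(i);

        auto element_op = [](float x) { return x; }; // PassThrough

        for(int idx_id = 0; idx_id < num_access; ++idx_id)
        {
            int offset = idx_id * scalar_per_vector; // unit step on the innermost dim
            for(int j = 0; j < scalar_per_vector; ++j)
                dst[offset + j] = element_op(src[offset + j]);
        }
        std::printf("dst[31]=%f\n", dst[31]);
        return 0;
    }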
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp  (new file, mode 100644, +1084)
Diff collapsed in the page view; contents not shown.
library/include/ck/library/host_tensor/device.hpp  (+5 −1)

@@ -121,6 +121,10 @@
launch_and_time_cpu_kernel(F kernel, int nrepeat, Args... args) now warms up before timing: `int nwarmup = 3;` and a loop calling kernel(args...) nwarmup times are inserted ahead of timer.Start().
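A self-contained sketch (std::chrono instead of the library's WallTimer, and a dummy workload) of the warm-up-then-average timing pattern this change adopts:

    #include <chrono>
    #include <cstdio>

    template <typename F, typename... Args>
    float launch_and_time(F kernel, int nrepeat, Args... args)
    {
        const int nwarmup = 3;                 // warm caches / fault pages before timing
        for(int i = 0; i < nwarmup; ++i)
            kernel(args...);

        auto t0 = std::chrono::steady_clock::now();
        for(int i = 0; i < nrepeat; ++i)
            kernel(args...);
        auto t1 = std::chrono::steady_clock::now();

        float ms = std::chrono::duration<float, std::milli>(t1 - t0).count();
        return ms / nrepeat;                   // average time per invocation
    }

    int main()
    {
        auto dummy = [](int n) { volatile long s = 0; for(int i = 0; i < n; ++i) s += i; };
        std::printf("avg %.3f ms\n", launch_and_time(dummy, 10, 1000000));
        return 0;
    }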
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp  (+30 −46)

@@ -19,7 +19,7 @@
WeiLayout (ColumnMajor, KYXC) and NonTemporalStore are unchanged; the elementwise alias is shortened from `using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;` to `using PT = ...`. The ThreadwiseGemmAvx2_MxN_4x24_Dispatch alias is unchanged.

@@ -37,53 +37,37 @@
ConvFwd1x1S1P0 is kept, and aliases for the new specializations are added:

    static constexpr auto DefaultGemmKLoop = ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;
    static constexpr auto GemmKLoopOverC   = ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;
    static constexpr auto LoopOver_MNK     = ck::tensor_operation::cpu::device::LoopOver_MNK;
    static constexpr auto LoopOver_MKN     = ck::tensor_operation::cpu::device::LoopOver_MKN;

A macro DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf) is introduced; it expands to eight DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, ...> instantiations covering {ConvFwdDefault, ConvFwd1x1P0} x {GemmKLoopOverC, DefaultGemmKLoop} x {LoopOver_MNK, LoopOver_MKN} for the given tile sizes and local-buffer flags.

The previous explicit instance list (256x128x64, 512x256x128 and 1024x144x128 tiles with ThreadwiseGemmAvx2_MxN_4x24_Dispatch spelled out) is replaced by macro invocations:

    using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple<
        // clang-format off
        DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 120,  64, 4, 24, true, true, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 144, 128, 4, 24, true, true, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, false),
        // DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 192, 128, 4, 24, true, true, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 288, 128, 4, 24, true, true, false)
        // clang-format on
        >;

add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk now takes std::vector<DeviceConvFwdPtr<PT, PT, PT>>& (previously spelled with PassThrough); the body still forwards to ck::tensor_operation::device::add_device_operation_instances(instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances{}).
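A reduced sketch (made-up Instance template and enum values, not the library's types) of the macro pattern the instance file switches to: one macro call stamps out every specialization combination for a given tile configuration:

    #include <tuple>

    enum Spec { Default, Filter1x1Pad0 };
    enum KLoop { DefaultK, KOverC };

    template <Spec S, KLoop K, int MPerBlock, int NPerBlock, int KPerBlock>
    struct Instance {};

    // expands one tile config into the four spec / k-loop combinations
    #define MAKE_INSTANCES(m, n, k)                 \
        Instance<Default,       DefaultK, m, n, k>, \
        Instance<Default,       KOverC,   m, n, k>, \
        Instance<Filter1x1Pad0, DefaultK, m, n, k>, \
        Instance<Filter1x1Pad0, KOverC,   m, n, k>

    using instances = std::tuple<MAKE_INSTANCES(256, 120, 64), MAKE_INSTANCES(512, 144, 128)>;

    int main() { static_assert(std::tuple_size<instances>::value == 8, ""); return 0; }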
test/convnd_fwd_cpu/conv2d_fwd_cpu.cpp
View file @
afc7d431
...
@@ -9,8 +9,11 @@
...
@@ -9,8 +9,11 @@
#include "reference_conv_fwd.hpp"
#include "reference_conv_fwd.hpp"
#include "element_wise_operation_cpu.hpp"
#include "element_wise_operation_cpu.hpp"
#include "dynamic_buffer_cpu.hpp"
#include "dynamic_buffer_cpu.hpp"
#include <omp.h>
#define AVX2_DATA_ALIGNMENT 32
#define AVX2_DATA_ALIGNMENT 32
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -18,6 +21,8 @@ namespace cpu {
...
@@ -18,6 +21,8 @@ namespace cpu {
namespace
device
{
namespace
device
{
namespace
device_conv2d_fwd_avx2_instance
{
namespace
device_conv2d_fwd_avx2_instance
{
using
PassThrough
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
void
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk
(
void
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
);
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
);
...
@@ -34,6 +39,7 @@ using OutElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
...
@@ -34,6 +39,7 @@ using OutElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
template
<
typename
T
>
template
<
typename
T
>
static
bool
check_out
(
const
Tensor
<
T
>&
ref
,
const
Tensor
<
T
>&
result
)
static
bool
check_out
(
const
Tensor
<
T
>&
ref
,
const
Tensor
<
T
>&
result
)
{
{
int
error_count
=
0
;
float
max_diff
=
1e-6
;
float
max_diff
=
1e-6
;
for
(
int
i
=
0
;
i
<
ref
.
mData
.
size
();
++
i
)
for
(
int
i
=
0
;
i
<
ref
.
mData
.
size
();
++
i
)
...
@@ -41,28 +47,35 @@ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
...
@@ -41,28 +47,35 @@ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
float
diff
=
std
::
abs
(
double
(
ref
.
mData
[
i
])
-
double
(
result
.
mData
[
i
]));
float
diff
=
std
::
abs
(
double
(
ref
.
mData
[
i
])
-
double
(
result
.
mData
[
i
]));
if
(
max_diff
<
diff
)
if
(
max_diff
<
diff
)
{
{
return
false
;
error_count
++
;
printf
(
"idx:%3d, ref:%f, res:%f (diff:%f)
\n
"
,
i
,
double
(
ref
.
mData
[
i
]),
double
(
result
.
mData
[
i
]),
diff
);
}
}
}
}
return
true
;
return
error_count
==
0
;
}
}
float
calculate_gflops
()
{}
 int main(int argc, char* argv[])
 {
     int data_type   = 0;
     int init_method = 0;
     // Conv shape
-    ck::index_t N               = 128;
+    ck::index_t N               = 2;
     ck::index_t K               = 256;
     ck::index_t C               = 192;
     ck::index_t Y               = 3;
     ck::index_t X               = 3;
     ck::index_t Hi              = 71;
     ck::index_t Wi              = 71;
-    ck::index_t conv_stride_h   = 2;
+    ck::index_t conv_stride_h   = 1;
-    ck::index_t conv_stride_w   = 2;
+    ck::index_t conv_stride_w   = 1;
     ck::index_t conv_dilation_h = 1;
     ck::index_t conv_dilation_w = 1;
     ck::index_t in_left_pad_h   = 1;
...
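For reference (not part of the diff): the output spatial extents used below follow the standard convolution output-size relation, so the defaults above (Hi = Wi = 71, 3x3 filter, stride 1, dilation 1, pad 1 on each side) give Ho = Wo = 71. A sketch of that computation, assuming the right-hand pads are named in_right_pad_h / in_right_pad_w as elsewhere in this test (the test itself computes Ho/Wo outside the hunks shown):

// Sketch: standard convolution output size.
// (71 + 1 + 1 - 1 * (3 - 1) - 1) / 1 + 1 = 71
ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - conv_dilation_h * (Y - 1) - 1) / conv_stride_h + 1;
ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - conv_dilation_w * (X - 1) - 1) / conv_stride_w + 1;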
@@ -72,7 +85,7 @@ int main(int argc, char* argv[])
     if(argc == 1)
     {
-        data_type   = 1;
+        data_type   = 0;
         init_method = 1;
     }
     else if(argc == 3)
...
@@ -110,17 +123,14 @@ int main(int argc, char* argv[])
         exit(1);
     }
-    auto Run = [&](auto input_type, auto wei_type, auto out_type, auto acc_type) {
+    auto Run = [&](auto input_type, auto wei_type, auto out_type) {
         using InDataType  = decltype(input_type);
         using WeiDataType = decltype(wei_type);
         using OutDataType = decltype(out_type);
-        using AccDataType = decltype(acc_type);
-        using ReferenceConvBwdInstance =
-            ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
-                                                             WeiDataType,
-                                                             OutDataType,
-                                                             AccDataType,
-                                                             InElementOp,
-                                                             WeiElementOp,
-                                                             OutElementOp>;
+        using ReferenceConvFwdInstance =
+            ck::tensor_operation::host::ReferenceConvFwd<InDataType,
+                                                         WeiDataType,
+                                                         OutDataType,
+                                                         InElementOp,
+                                                         WeiElementOp,
+                                                         OutElementOp>;
...
@@ -147,53 +157,93 @@ int main(int argc, char* argv[])
                 std::vector<std::size_t>({C_ * H_ * W_, 1, W_ * C_, C_}));
         };
-        Tensor<OutDataType> out_n_ho_wo_k(f_host_tensor_descriptor(N, K, Ho, Wo));
+        Tensor<InDataType> in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi));
-        Tensor<WeiDataType> wei_k_y_x_c(f_host_tensor_descriptor(K, C, Y, X));
+        Tensor<WeiDataType> wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X));
-        Tensor<InDataType> in_n_hi_wi_c_host_result(f_host_tensor_descriptor(N, C, Hi, Wi));
+        Tensor<OutDataType> out_n_k_ho_wo_host_result(f_host_tensor_descriptor(N, K, Ho, Wo));
-        Tensor<InDataType> in_n_hi_wi_c_device_result(f_host_tensor_descriptor(N, C, Hi, Wi));
+        Tensor<OutDataType> out_n_k_ho_wo_device_result(f_host_tensor_descriptor(N, K, Ho, Wo));
-        std::cout << "in (N, C, Hi, Wi): " << in_n_hi_wi_c_host_result.mDesc << std::endl;
+        std::cout << "in (N, C, Hi, Wi): " << in_n_c_hi_wi.mDesc << std::endl;
-        std::cout << "wei(K, C, Y, X): " << wei_k_y_x_c.mDesc << std::endl;
+        std::cout << "wei(K, C, Y, X): " << wei_k_c_y_x.mDesc << std::endl;
-        std::cout << "out(N, K, Ho, Wo): " << out_n_ho_wo_k.mDesc << std::endl;
+        std::cout << "out(N, K, Ho, Wo): " << out_n_k_ho_wo_host_result.mDesc << std::endl;
+        std::cout << "LPad(H, W):" << in_left_pad_h << "," << in_left_pad_w
+                  << ", RPad(H, W):" << in_right_pad_h << "," << in_right_pad_w
+                  << ", Stride(H, W):" << conv_stride_h << ", " << conv_stride_w
+                  << ", Dilation(H, W):" << conv_dilation_h << ", " << conv_dilation_w
+                  << ", Threads:" << omp_get_max_threads() << std::endl;
         switch(init_method)
         {
         case 0: break;
         case 1:
-            out_n_ho_wo_k.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
-            wei_k_y_x_c.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
+            in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
+            // in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
+            wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
+            // wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
             break;
         case 2:
-            out_n_ho_wo_k.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
-            wei_k_y_x_c.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
+            in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
+            wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
             break;
+        case 3:
+#define PACK_32(v24, v16, v8, v0) \
+    (((v24 & 0xff) << 24) | ((v16 & 0xff) << 16) | ((v8 & 0xff) << 8) | ((v0 & 0xff) << 0))
+            for(auto i_n = 0; i_n < N; i_n++)
+            {
+                for(auto i_c = 0; i_c < C; i_c++)
+                {
+                    for(auto i_hi = 0; i_hi < Hi; i_hi++)
+                    {
+                        for(auto i_wi = 0; i_wi < Wi; i_wi++)
+                        {
+                            uint32_t v = PACK_32(i_n, i_c, i_hi, i_wi);
+                            in_n_c_hi_wi(i_n, i_c, i_hi, i_wi) = *reinterpret_cast<float*>(&v);
+                        }
+                    }
+                }
+            }
+            for(auto i_k = 0; i_k < K; i_k++)
+            {
+                for(auto i_c = 0; i_c < C; i_c++)
+                {
+                    for(auto i_y = 0; i_y < Y; i_y++)
+                    {
+                        for(auto i_x = 0; i_x < X; i_x++)
+                        {
+                            uint32_t v = PACK_32(i_k, i_c, i_y, i_x);
+                            wei_k_c_y_x(i_k, i_c, i_y, i_x) = *reinterpret_cast<float*>(&v);
+                        }
+                    }
+                }
+            }
+            break;
         default:
-            out_n_ho_wo_k.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
-            wei_k_y_x_c.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
+            in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0, 1});
+            wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-1, 1});
         }
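The case-3 initializer above writes packed (n, c, h, w) and (k, c, y, x) coordinates bit-for-bit into each float element, so individual values remain traceable through the layout-shuffling copies as long as they are only moved, never computed with. A small companion sketch (not in the commit; <cstdint> and <cstring> assumed) for decoding such an element while debugging:

// Sketch only: recover the four 8-bit indices produced by PACK_32 from a float
// that was filled via reinterpret_cast, e.g. an element read back from a buffer.
static void unpack_32(float f, int& v24, int& v16, int& v8, int& v0)
{
    uint32_t v;
    std::memcpy(&v, &f, sizeof(v)); // bit-exact view of the float's payload
    v24 = (v >> 24) & 0xff;
    v16 = (v >> 16) & 0xff;
    v8  = (v >> 8) & 0xff;
    v0  = (v >> 0) & 0xff;
}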
         DeviceAlignedMemCPU in_device_buf(sizeof(InDataType) *
-                                              in_n_hi_wi_c_device_result.mDesc.GetElementSpace(),
+                                              in_n_c_hi_wi.mDesc.GetElementSpace(),
                                           AVX2_DATA_ALIGNMENT);
         DeviceAlignedMemCPU wei_device_buf(
-            sizeof(WeiDataType) * wei_k_y_x_c.mDesc.GetElementSpace(), AVX2_DATA_ALIGNMENT);
+            sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace(), AVX2_DATA_ALIGNMENT);
         DeviceAlignedMemCPU out_device_buf(
             sizeof(OutDataType) *
-                out_n_ho_wo_k.mDesc.GetElementSpace(), AVX2_DATA_ALIGNMENT);
+                out_n_k_ho_wo_host_result.mDesc.GetElementSpace(), AVX2_DATA_ALIGNMENT);
-        out_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
+        in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
-        wei_device_buf.ToDevice(wei_k_y_x_c.mData.data());
+        wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
-        // reset input to zero
-        in_n_hi_wi_c_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0});
-        in_device_buf.ToDevice(in_n_hi_wi_c_device_result.mData.data());
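All three buffers are requested with AVX2_DATA_ALIGNMENT (32 bytes), which matches the 256-bit width of a YMM register, so whole-vector aligned loads and stores stay legal. DeviceAlignedMemCPU's implementation is not shown in this diff; a minimal stand-in with the same guarantee, assuming C++17 (the helper name below is illustrative only):

#include <cstdlib>
// Sketch only: 32-byte-aligned float allocation, comparable in spirit to
// DeviceAlignedMemCPU. std::aligned_alloc requires the size to be a multiple
// of the alignment, hence the round-up; release later with std::free().
static float* alloc_avx2_aligned(std::size_t element_count)
{
    std::size_t bytes = element_count * sizeof(float);
    bytes             = (bytes + 31u) & ~static_cast<std::size_t>(31u);
    return static_cast<float*>(std::aligned_alloc(32, bytes));
}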
         // get host result
         {
             auto ref_conv    = ReferenceConvFwdInstance{};
             auto ref_invoker = ref_conv.MakeInvoker();
-            auto ref_argument = ref_conv.MakeArgument(in_n_hi_wi_c_host_result,
-                                                      wei_k_y_x_c,
-                                                      out_n_ho_wo_k,
+            auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi,
+                                                      wei_k_c_y_x,
+                                                      out_n_k_ho_wo_host_result,
                                                       conv_filter_strides,
                                                       conv_filter_dilations,
                                                       input_left_pads,
...
@@ -205,8 +255,8 @@ int main(int argc, char* argv[])
     }
     using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
     using DeviceConvFwdNoOpPtr =
-        ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>;
+        ck::tensor_operation::cpu::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>;
     // add device Conv instances
     std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
...
@@ -226,6 +276,9 @@ int main(int argc, char* argv[])
         // profile device Conv instances
         bool success = true;
+        double fastest_kernel_time      = std::numeric_limits<double>::max();
+        std::string fastest_kernel_name = "";
+        double fastest_kernel_gflops    = 0;
         for(auto& conv_ptr : conv_ptrs)
         {
             auto argument_ptr = conv_ptr->MakeArgumentPointer(
...
@@ -249,18 +302,30 @@ int main(int argc, char* argv[])
             if(conv_ptr->IsSupportedArgument(argument_ptr.get()))
             {
                 auto invoker_ptr = conv_ptr->MakeInvokerPointer();
-                invoker_ptr->Run(argument_ptr.get(), 1);
+                double time       = invoker_ptr->Run(argument_ptr.get(), 10);
+                double total_flop = static_cast<double>(2) * N * C * Ho * Wo * K * Y * X;
+                double gflops     = (total_flop * 1e-6) / time;
-                in_device_buf.FromDevice(in_n_hi_wi_c_device_result.mData.data());
+                out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data());
-                if(!check_out(in_n_hi_wi_c_host_result, in_n_hi_wi_c_device_result))
+                if(!check_out(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result))
                 {
                     std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl;
                     success = false;
                 }
                 else
                 {
-                    std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl;
+                    std::cout << "Pass Info: " << conv_ptr->GetTypeString() << ", Time:" << time
+                              << "ms, Gflops:" << gflops << std::endl;
+                    if(time < fastest_kernel_time)
+                    {
+                        fastest_kernel_time   = time;
+                        fastest_kernel_name   = conv_ptr->GetTypeString();
+                        fastest_kernel_gflops = gflops;
+                    }
                 }
             }
             else
...
@@ -269,25 +334,27 @@ int main(int argc, char* argv[])
             }
         }
-        if(success)
+        if(fastest_kernel_time != std::numeric_limits<double>::max())
         {
-            std::cout << "test conv2d fwd cpu : Pass" << std::endl;
-            return 0;
+            std::cout << " fastest:" << fastest_kernel_name << ", time:" << fastest_kernel_time
+                      << "ms, Gflops:" << fastest_kernel_gflops << std::endl;
         }
-        else
-        {
-            std::cout << "test conv2d fwd cpu: Fail " << std::endl;
-            return -1;
-        }
+        return 0;
+        // if(success)
+        // {
+        //     std::cout << "test conv2d fwd cpu : Pass" << std::endl;
+        //     return 0;
+        // }
+        // else
+        // {
+        //     std::cout << "test conv2d fwd cpu: Fail " << std::endl;
+        //     return -1;
+        // }
     };
     if(data_type == 0)
     {
-        return Run(F32(), F32(), F32(), F32());
+        return Run(F32(), F32(), F32());
     }
-    else if(data_type == 1)
-    {
-        return Run(F16(), F16(), F16(), F32());
-    }
     else
     {
...
test/cpu_threadwise_transfer/cpu_threadwise_transfer.cpp
...
@@ -226,6 +226,8 @@ int main(int argc, char** argv)
     static constexpr ck::index_t nDim =
         ck::remove_reference_t<decltype(input_desc)>::GetNumOfDimension();
+    input_desc.Print();
+
     auto threadwise_transfer = threadwise_transfer_t{input_desc,
                                                      ck::make_zero_multi_index<nDim>(),
                                                      input_cblock_desc,
...
test/cpu_ukernel/cpu_gemm_uk.cpp
...
@@ -321,6 +321,7 @@ void test_ukernel(ukenrel_t uk,
     param.ldb   = (std::is_same<Row, BLayout>::value ? n : k * 8) * sizeof(FloatB);
     param.ldc   = n * sizeof(float);
     param.alpha = alpha;
+    param.accmulate_c = 0;
     memset(private_c, 0, m * n * sizeof(float));
...
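The leading dimensions above are expressed in bytes rather than elements, and the non-row-major branch uses a stride of k * 8 elements, which is consistent with B being pre-packed into panels 8 floats wide (one AVX2 vector per step along k). A small sketch of the same address arithmetic, under that packing assumption (the variable names below are illustrative, not from the library):

// Sketch only: the two leading-dimension choices above, in bytes.
size_t ldb_row_major_bytes = n * sizeof(float);     // row-major B: the next row is n floats away
size_t ldb_packed_bytes    = k * 8 * sizeof(float); // packed B: the next 8-wide panel is k*8 floats away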