gaoqiong / composable_kernel / Commits / d6d37ea9

Commit d6d37ea9, authored May 16, 2022 by carlushuang

    refactor Run to use slice length as block size. Fix a bug in general input copy

parent 2e414b7c
Changes: 2
Showing 2 changed files with 41 additions and 43 deletions

+23 -19  include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp
+18 -24  include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp
@@ -123,10 +123,17 @@ struct GridwiseGemmAvx2_MxN
         }
     }

-    static auto GetCBlockDescriptor(const ck::index_t m_per_blk, const ck::index_t n_per_blk)
+    static auto GetCBlockDescriptor(const ck::index_t m_per_blk,
+                                    const ck::index_t n_per_blk,
+                                    const CGridDesc& c_grid_desc)
     {
-        return make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, n_per_blk));
+        if constexpr(UseCLocalBuffer)
+        {
+            return make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, n_per_blk));
+        }
+        else
+            return c_grid_desc;
     }

     static auto GetASliceLength(const ck::index_t m_per_blk, const ck::index_t k_per_blk)
     {
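Why fold the selection into GetCBlockDescriptor? Before this commit, call sites picked the descriptor with a runtime ternary, `UseCLocalBuffer ? GetCBlockDescriptor(mc_size, nc_size) : c_grid_desc` (see the hunks at -373 and -521 below). A ternary requires both operands to convert to one common type, while the packed block descriptor and the grid descriptor are distinct types; an `if constexpr` inside an `auto`-returning function of a class template discards the untaken branch at compile time, so each specialization deduces its own return type. A minimal standalone sketch of the pattern, with illustrative types that are not CK's:

#include <iostream>
#include <string>

template <bool UseLocalBuffer>
struct Chooser
{
    // The discarded branch is never instantiated, so the two returns may
    // have completely different types (here: int vs std::string).
    static auto GetDescriptor(int m, int n, const std::string& grid_desc)
    {
        if constexpr(UseLocalBuffer)
            return m * n;     // stands in for a freshly built packed descriptor
        else
            return grid_desc; // stands in for the existing grid descriptor
    }
};

int main()
{
    auto a = Chooser<true>::GetDescriptor(4, 8, "grid");  // deduces int
    auto b = Chooser<false>::GetDescriptor(4, 8, "grid"); // deduces std::string
    std::cout << a << " " << b << "\n";
    // A runtime ternary cannot express this: in `cond ? 32 : grid_desc`
    // both operands would have to share a common type.
    return 0;
}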
@@ -269,7 +276,7 @@ struct GridwiseGemmAvx2_MxN
         FloatC,                                                  // FloatC,
         decltype(GetABlockDescriptor(m_per_block, k_per_block)), // ABlockDesc,
         decltype(GetBBlockDescriptor(k_per_block, n_per_block)), // BBlockDesc,
-        decltype(GetCBlockDescriptor(m_per_block, n_per_block)), // CBlockDesc,
+        decltype(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc)), // CBlockDesc,
         KPerBlock,                                               // KPerBlock,
         ThreadwiseGemm_Dispatch,                                 // ThreadwiseGemm_Dispatch,
         ThreadMNAccessOrder>{};                                  // ThreadMNAccessOrder // how we acces
@@ -325,7 +332,7 @@ struct GridwiseGemmAvx2_MxN
             BElementwiseOperation{});

-        auto c_threadwise_copy = CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block),
+        auto c_threadwise_copy =
+            CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc),
                             ck::make_zero_multi_index<2>(),
                             c_grid_desc,
                             ck::make_zero_multi_index<2>(),
@@ -373,8 +380,7 @@ struct GridwiseGemmAvx2_MxN
             a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, GetAIndex(i_mc, 0));
             b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc, GetBIndex(0, i_nc));

-            auto c_block_desc =
-                UseCLocalBuffer ? GetCBlockDescriptor(mc_size, nc_size) : c_grid_desc;
+            auto c_block_desc = GetCBlockDescriptor(mc_size, nc_size, c_grid_desc);

             if constexpr(!UseCLocalBuffer)
             {
                 c_threadwise_copy.SetSrcSliceOrigin(c_block_desc, GetCIndex(i_mc, i_nc));
@@ -456,7 +462,7 @@ struct GridwiseGemmAvx2_MxN
             BElementwiseOperation{});

-        auto c_threadwise_copy = CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block),
+        auto c_threadwise_copy =
+            CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc),
                             ck::make_zero_multi_index<2>(),
                             c_grid_desc,
                             ck::make_zero_multi_index<2>(),
@@ -521,9 +527,7 @@ struct GridwiseGemmAvx2_MxN
                     b_block_buf,
                     GetBSliceLength(kc_size, nc_size));

-                auto c_block_desc =
-                    UseCLocalBuffer ? GetCBlockDescriptor(mc_size, nc_size) : c_grid_desc;
+                auto c_block_desc = GetCBlockDescriptor(mc_size, nc_size, c_grid_desc);

                 if constexpr(!UseCLocalBuffer)
                 {
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp
@@ -368,15 +368,12 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
     {
         if constexpr(BypassTransfer)
        {
-            float* p_src    = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
-            dst_buf.p_data_ = p_src;
+            dst_buf.p_data_ = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
        }
        else
        {
-            const ck::index_t m_per_block =
-                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
-            const ck::index_t k_per_block =
-                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+            const ck::index_t m_per_block = slice_length[Number<0>{}];
+            const ck::index_t k_per_block = slice_length[Number<1>{}];

            const float* p_src = reinterpret_cast<const float*>(src_buf.p_data_) + src_offset;
            float* p_dst       = reinterpret_cast<float*>(dst_buf.p_data_);
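This is the "use slice length as block size" half of the commit: the transfer specializations previously derived m_per_block / k_per_block from the destination descriptor's upper lengths, and now take them from the slice_length argument. The distinction matters at the edges of the problem: the descriptor describes the full block buffer, while slice_length describes the slice actually requested, which can be smaller on the last block along a dimension. A minimal sketch of the difference, with illustrative names rather than CK's API:

#include <algorithm>
#include <cstdio>
#include <vector>

// Copy one tile of a row-major M x K matrix into a block buffer. The loop
// bounds come from the requested slice lengths, not from the block buffer's
// full dimensions, so a partial edge tile copies only elements that exist.
void copy_tile(const float* src, int src_stride,
               float* dst, int dst_stride,
               int slice_m, int slice_k) // actual slice lengths
{
    for(int m = 0; m < slice_m; ++m)
        for(int k = 0; k < slice_k; ++k)
            dst[m * dst_stride + k] = src[m * src_stride + k];
}

int main()
{
    const int M = 10, K = 6;   // whole matrix
    const int m_per_block = 8; // block buffer is 8 x 6
    std::vector<float> src(M * K, 1.0f), dst(m_per_block * K, 0.0f);

    // The second block along M covers rows 8..9: only 2 rows remain, so the
    // slice length is 2 even though the block descriptor still says 8.
    int i_m     = 8;
    int slice_m = std::min(M - i_m, m_per_block); // = 2
    copy_tile(src.data() + i_m * K, K, dst.data(), K, slice_m, K);

    std::printf("copied %d x %d tile\n", slice_m, K);
    return 0;
}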
@@ -540,19 +537,23 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
             ck::index_t i_k_itr = k_per_block;
             while(i_k_itr > 0)
             {
-                ck::index_t current_k_block = ck::math::min(C - i_c_itr_k, k_per_block);
+                ck::index_t current_k_block_along_c = ck::math::min(C - i_c_itr_k, i_k_itr);
+                // printf("current_k_block_along_c:%d, i_c_itr_k:%d, k_per_block:%d\n",
+                //     current_k_block_along_c, i_c_itr_k, k_per_block); fflush(stdout);

                 if((*reinterpret_cast<uint32_t*>(&i_hi_itr_k) < Hi) &&
                    (*reinterpret_cast<uint32_t*>(&i_wi_itr_k) < Wi))
-                    avx2_util::memcpy32_avx2(p_dst_k, p_src_k, current_k_block, element_op_);
+                    avx2_util::memcpy32_avx2(
+                        p_dst_k, p_src_k, current_k_block_along_c, element_op_);
                 else
-                    avx2_util::memset32_avx2(p_dst_k, 0, current_k_block);
+                    avx2_util::memset32_avx2(p_dst_k, 0, current_k_block_along_c);

-                p_dst_k += current_k_block;
-                p_src_k += current_k_block;
+                p_dst_k += current_k_block_along_c;
+                p_src_k += current_k_block_along_c;

-                i_c_itr_k += current_k_block;
+                i_c_itr_k += current_k_block_along_c;
                 if(i_c_itr_k >= C)
                 {
                     i_c_itr_k = 0;
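The "bug in general input copy" from the commit title is the clamp in this loop: the per-iteration chunk was limited by the distance to the channel boundary (C - i_c_itr_k) and by the full k_per_block, but never by i_k_itr, the count still outstanding, so the final iteration could copy past the end of the requested slice whenever the remainder was smaller than both. Clamping against i_k_itr (and renaming the variable to current_k_block_along_c) fixes that. A stripped-down model of the loop, with CK's pointer bookkeeping removed (the names mirror the diff, but the function itself is illustrative):

#include <algorithm>
#include <cstdio>

// Model of the inner copy loop: walk k_per_block elements of GemmK, which
// wraps around the channel dimension C every C elements.
int elements_copied(int k_per_block, int C, bool fixed)
{
    int copied    = 0;
    int i_c_itr_k = 0;
    int i_k_itr   = k_per_block;
    while(i_k_itr > 0)
    {
        int chunk = fixed ? std::min(C - i_c_itr_k, i_k_itr)      // fixed clamp
                          : std::min(C - i_c_itr_k, k_per_block); // buggy clamp
        copied += chunk;      // stands in for memcpy32_avx2(...)
        i_c_itr_k += chunk;
        if(i_c_itr_k >= C)
            i_c_itr_k = 0;    // channel wrapped: move to the next pixel
        i_k_itr -= chunk;
    }
    return copied;
}

int main()
{
    // k_per_block = 8, C = 5: after the first wrap only 3 elements remain,
    // but the buggy clamp copies min(5, 8) = 5 of them, overrunning by 2.
    std::printf("buggy: %d, fixed: %d (want 8)\n",
                elements_copied(8, 5, false), elements_copied(8, 5, true));
    return 0;
}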
@@ -569,7 +570,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
                     p_src_k += input_offset_ovf_x_acc_y;
                 }

-                i_k_itr -= current_k_block;
+                i_k_itr -= current_k_block_along_c;
             }

             /*** go along Gemm K ***/
@@ -765,11 +766,8 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
         }
         else
         {
-            const ck::index_t n_per_block =
-                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}] *
-                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
-            const ck::index_t k_per_block =
-                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+            const ck::index_t n_per_block = slice_length[Number<0>{}] * slice_length[Number<2>{}];
+            const ck::index_t k_per_block = slice_length[Number<1>{}];

             // printf(" >>>> %d, %d, %d -> %d(%dx%d), %d\n", GemmN, GemmK, GemmN1, n_per_block,
             //     dst_desc.GetTransforms()[Number<0>{}]
@@ -1002,7 +1000,6 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
         if constexpr(!std::is_same<ElementwiseOperation,
                                    ck::tensor_operation::cpu::element_wise::PassThrough>::value)
         {
             // if (true) {
             const ck::index_t m_per_block = slice_length[Number<0>{}];
             const ck::index_t n_per_block = slice_length[Number<1>{}];
@@ -1073,11 +1070,8 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
         }
         else
         {
-            const ck::index_t m_per_block =
-                src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
-            // must be multiple of 8
-            const ck::index_t n_per_block =
-                src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+            const ck::index_t m_per_block = slice_length[Number<0>{}];
+            const ck::index_t n_per_block = slice_length[Number<1>{}];

             const ck::index_t current_n = ck::math::min(DstGemmN - i_dst_gemm_n, n_per_block);