Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d6d37ea9
"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "52663284cfe1eb250859f9fc8e84df5e7f093cce"
Commit
d6d37ea9
authored
May 16, 2022
by
carlushuang
Browse files
refactor Run to use slice length as block size. Fix a bug in general input copy
parent
2e414b7c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
41 additions
and
43 deletions
+41
-43
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp
+23
-19
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp
.../threadwise_tensor_slice_transfer_avx2_specialization.hpp
+18
-24
No files found.
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp
View file @
d6d37ea9
...
@@ -123,9 +123,16 @@ struct GridwiseGemmAvx2_MxN
...
@@ -123,9 +123,16 @@ struct GridwiseGemmAvx2_MxN
}
}
}
}
static
auto
GetCBlockDescriptor
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
n_per_blk
)
static
auto
GetCBlockDescriptor
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
n_per_blk
,
const
CGridDesc
&
c_grid_desc
)
{
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
m_per_blk
,
n_per_blk
));
if
constexpr
(
UseCLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
m_per_blk
,
n_per_blk
));
}
else
return
c_grid_desc
;
}
}
static
auto
GetASliceLength
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
k_per_blk
)
static
auto
GetASliceLength
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
k_per_blk
)
...
@@ -264,16 +271,16 @@ struct GridwiseGemmAvx2_MxN
...
@@ -264,16 +271,16 @@ struct GridwiseGemmAvx2_MxN
reinterpret_cast
<
FloatC
*>
(
p_c_grid
),
c_grid_desc
.
GetElementSpaceSize
());
reinterpret_cast
<
FloatC
*>
(
p_c_grid
),
c_grid_desc
.
GetElementSpaceSize
());
auto
blockwise_gemm
=
BlockwiseGemmAvx2_MxN
<
auto
blockwise_gemm
=
BlockwiseGemmAvx2_MxN
<
FloatA
,
// FloatA,
FloatA
,
// FloatA,
FloatB
,
// FloatB,
FloatB
,
// FloatB,
FloatC
,
// FloatC,
FloatC
,
// FloatC,
decltype
(
GetABlockDescriptor
(
m_per_block
,
k_per_block
)),
// ABlockDesc,
decltype
(
GetABlockDescriptor
(
m_per_block
,
k_per_block
)),
// ABlockDesc,
decltype
(
GetBBlockDescriptor
(
k_per_block
,
n_per_block
)),
// BBlockDesc,
decltype
(
GetBBlockDescriptor
(
k_per_block
,
n_per_block
)),
// BBlockDesc,
decltype
(
GetCBlockDescriptor
(
m_per_block
,
n_per_block
)),
// CBlockDesc,
decltype
(
GetCBlockDescriptor
(
m_per_block
,
n_per_block
,
c_grid_desc
)),
// CBlockDesc,
KPerBlock
,
// KPerBlock,
KPerBlock
,
// KPerBlock,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
ThreadMNAccessOrder
>
{};
// ThreadMNAccessOrder // how we acces
ThreadMNAccessOrder
>
{};
// ThreadMNAccessOrder // how we acces
// gemm MN to utilize micro kernel>{};
// gemm MN to utilize micro kernel>{};
int
total_threads
=
omp_get_max_threads
();
int
total_threads
=
omp_get_max_threads
();
...
@@ -325,7 +332,7 @@ struct GridwiseGemmAvx2_MxN
...
@@ -325,7 +332,7 @@ struct GridwiseGemmAvx2_MxN
BElementwiseOperation
{});
BElementwiseOperation
{});
auto
c_threadwise_copy
=
auto
c_threadwise_copy
=
CThreadwiseCopy
(
GetCBlockDescriptor
(
m_per_block
,
n_per_block
),
CThreadwiseCopy
(
GetCBlockDescriptor
(
m_per_block
,
n_per_block
,
c_grid_desc
),
ck
::
make_zero_multi_index
<
2
>
(),
ck
::
make_zero_multi_index
<
2
>
(),
c_grid_desc
,
c_grid_desc
,
ck
::
make_zero_multi_index
<
2
>
(),
ck
::
make_zero_multi_index
<
2
>
(),
...
@@ -373,8 +380,7 @@ struct GridwiseGemmAvx2_MxN
...
@@ -373,8 +380,7 @@ struct GridwiseGemmAvx2_MxN
a_threadwise_copy
.
SetSrcSliceOrigin
(
a_grid_desc
,
GetAIndex
(
i_mc
,
0
));
a_threadwise_copy
.
SetSrcSliceOrigin
(
a_grid_desc
,
GetAIndex
(
i_mc
,
0
));
b_threadwise_copy
.
SetSrcSliceOrigin
(
b_grid_desc
,
GetBIndex
(
0
,
i_nc
));
b_threadwise_copy
.
SetSrcSliceOrigin
(
b_grid_desc
,
GetBIndex
(
0
,
i_nc
));
auto
c_block_desc
=
auto
c_block_desc
=
GetCBlockDescriptor
(
mc_size
,
nc_size
,
c_grid_desc
);
UseCLocalBuffer
?
GetCBlockDescriptor
(
mc_size
,
nc_size
)
:
c_grid_desc
;
if
constexpr
(
!
UseCLocalBuffer
)
if
constexpr
(
!
UseCLocalBuffer
)
{
{
c_threadwise_copy
.
SetSrcSliceOrigin
(
c_block_desc
,
GetCIndex
(
i_mc
,
i_nc
));
c_threadwise_copy
.
SetSrcSliceOrigin
(
c_block_desc
,
GetCIndex
(
i_mc
,
i_nc
));
...
@@ -456,7 +462,7 @@ struct GridwiseGemmAvx2_MxN
...
@@ -456,7 +462,7 @@ struct GridwiseGemmAvx2_MxN
BElementwiseOperation
{});
BElementwiseOperation
{});
auto
c_threadwise_copy
=
auto
c_threadwise_copy
=
CThreadwiseCopy
(
GetCBlockDescriptor
(
m_per_block
,
n_per_block
),
CThreadwiseCopy
(
GetCBlockDescriptor
(
m_per_block
,
n_per_block
,
c_grid_desc
),
ck
::
make_zero_multi_index
<
2
>
(),
ck
::
make_zero_multi_index
<
2
>
(),
c_grid_desc
,
c_grid_desc
,
ck
::
make_zero_multi_index
<
2
>
(),
ck
::
make_zero_multi_index
<
2
>
(),
...
@@ -521,9 +527,7 @@ struct GridwiseGemmAvx2_MxN
...
@@ -521,9 +527,7 @@ struct GridwiseGemmAvx2_MxN
b_block_buf
,
b_block_buf
,
GetBSliceLength
(
kc_size
,
nc_size
));
GetBSliceLength
(
kc_size
,
nc_size
));
auto
c_block_desc
=
UseCLocalBuffer
auto
c_block_desc
=
GetCBlockDescriptor
(
mc_size
,
nc_size
,
c_grid_desc
);
?
GetCBlockDescriptor
(
mc_size
,
nc_size
)
:
c_grid_desc
;
if
constexpr
(
!
UseCLocalBuffer
)
if
constexpr
(
!
UseCLocalBuffer
)
{
{
...
...
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp
View file @
d6d37ea9
...
@@ -368,15 +368,12 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
...
@@ -368,15 +368,12 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
{
{
if
constexpr
(
BypassTransfer
)
if
constexpr
(
BypassTransfer
)
{
{
float
*
p_src
=
reinterpret_cast
<
float
*>
(
src_buf
.
p_data_
)
+
src_offset
;
dst_buf
.
p_data_
=
reinterpret_cast
<
float
*>
(
src_buf
.
p_data_
)
+
src_offset
;
dst_buf
.
p_data_
=
p_src
;
}
}
else
else
{
{
const
ck
::
index_t
m_per_block
=
const
ck
::
index_t
m_per_block
=
slice_length
[
Number
<
0
>
{}];
dst_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
0
>
{}];
const
ck
::
index_t
k_per_block
=
slice_length
[
Number
<
1
>
{}];
const
ck
::
index_t
k_per_block
=
dst_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
const
float
*
p_src
=
reinterpret_cast
<
const
float
*>
(
src_buf
.
p_data_
)
+
src_offset
;
const
float
*
p_src
=
reinterpret_cast
<
const
float
*>
(
src_buf
.
p_data_
)
+
src_offset
;
float
*
p_dst
=
reinterpret_cast
<
float
*>
(
dst_buf
.
p_data_
);
float
*
p_dst
=
reinterpret_cast
<
float
*>
(
dst_buf
.
p_data_
);
...
@@ -540,19 +537,23 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
...
@@ -540,19 +537,23 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
ck
::
index_t
i_k_itr
=
k_per_block
;
ck
::
index_t
i_k_itr
=
k_per_block
;
while
(
i_k_itr
>
0
)
while
(
i_k_itr
>
0
)
{
{
ck
::
index_t
current_k_block
=
ck
::
math
::
min
(
C
-
i_c_itr_k
,
k_per_block
);
ck
::
index_t
current_k_block_along_c
=
ck
::
math
::
min
(
C
-
i_c_itr_k
,
i_k_itr
);
// printf("current_k_block_along_c:%d, i_c_itr_k:%d, k_per_block:%d\n",
// current_k_block_along_c, i_c_itr_k,k_per_block); fflush(stdout);
if
((
*
reinterpret_cast
<
uint32_t
*>
(
&
i_hi_itr_k
)
<
Hi
)
&&
if
((
*
reinterpret_cast
<
uint32_t
*>
(
&
i_hi_itr_k
)
<
Hi
)
&&
(
*
reinterpret_cast
<
uint32_t
*>
(
&
i_wi_itr_k
)
<
Wi
))
(
*
reinterpret_cast
<
uint32_t
*>
(
&
i_wi_itr_k
)
<
Wi
))
avx2_util
::
memcpy32_avx2
(
avx2_util
::
memcpy32_avx2
(
p_dst_k
,
p_src_k
,
current_k_block
,
element_op_
);
p_dst_k
,
p_src_k
,
current_k_block
_along_c
,
element_op_
);
else
else
avx2_util
::
memset32_avx2
(
p_dst_k
,
0
,
current_k_block
);
avx2_util
::
memset32_avx2
(
p_dst_k
,
0
,
current_k_block
_along_c
);
p_dst_k
+=
current_k_block
;
p_dst_k
+=
current_k_block
_along_c
;
p_src_k
+=
current_k_block
;
p_src_k
+=
current_k_block
_along_c
;
i_c_itr_k
+=
current_k_block
;
i_c_itr_k
+=
current_k_block
_along_c
;
if
(
i_c_itr_k
>=
C
)
if
(
i_c_itr_k
>=
C
)
{
{
i_c_itr_k
=
0
;
i_c_itr_k
=
0
;
...
@@ -569,7 +570,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
...
@@ -569,7 +570,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
p_src_k
+=
input_offset_ovf_x_acc_y
;
p_src_k
+=
input_offset_ovf_x_acc_y
;
}
}
i_k_itr
-=
current_k_block
;
i_k_itr
-=
current_k_block
_along_c
;
}
}
/*** go along Gemm K ***/
/*** go along Gemm K ***/
...
@@ -765,11 +766,8 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
...
@@ -765,11 +766,8 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
}
}
else
else
{
{
const
ck
::
index_t
n_per_block
=
const
ck
::
index_t
n_per_block
=
slice_length
[
Number
<
0
>
{}]
*
slice_length
[
Number
<
2
>
{}];
dst_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
0
>
{}]
*
const
ck
::
index_t
k_per_block
=
slice_length
[
Number
<
1
>
{}];
dst_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
2
>
{}];
const
ck
::
index_t
k_per_block
=
dst_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
// printf(" >>>> %d, %d, %d -> %d(%dx%d), %d\n", GemmN, GemmK, GemmN1, n_per_block,
// printf(" >>>> %d, %d, %d -> %d(%dx%d), %d\n", GemmN, GemmK, GemmN1, n_per_block,
// dst_desc.GetTransforms()[Number<0>{}]
// dst_desc.GetTransforms()[Number<0>{}]
...
@@ -1002,7 +1000,6 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
...
@@ -1002,7 +1000,6 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
if
constexpr
(
!
std
::
is_same
<
ElementwiseOperation
,
if
constexpr
(
!
std
::
is_same
<
ElementwiseOperation
,
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
>::
value
)
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
>::
value
)
{
{
// if (true) {
const
ck
::
index_t
m_per_block
=
slice_length
[
Number
<
0
>
{}];
const
ck
::
index_t
m_per_block
=
slice_length
[
Number
<
0
>
{}];
const
ck
::
index_t
n_per_block
=
slice_length
[
Number
<
1
>
{}];
const
ck
::
index_t
n_per_block
=
slice_length
[
Number
<
1
>
{}];
...
@@ -1073,11 +1070,8 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
...
@@ -1073,11 +1070,8 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
}
}
else
else
{
{
const
ck
::
index_t
m_per_block
=
const
ck
::
index_t
m_per_block
=
slice_length
[
Number
<
0
>
{}];
src_desc
.
GetTransforms
()[
Number
<
0
>
{}]
const
ck
::
index_t
n_per_block
=
slice_length
[
Number
<
1
>
{}];
.
GetUpperLengths
()[
Number
<
0
>
{}];
// must be multiple of 8
const
ck
::
index_t
n_per_block
=
src_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
const
ck
::
index_t
current_n
=
ck
::
math
::
min
(
DstGemmN
-
i_dst_gemm_n
,
n_per_block
);
const
ck
::
index_t
current_n
=
ck
::
math
::
min
(
DstGemmN
-
i_dst_gemm_n
,
n_per_block
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment