Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
2e414b7c
Commit
2e414b7c
authored
May 16, 2022
by
carlushuang
Browse files
refactor length/index setting in gridwise gemm
parent
b134b7d6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
74 additions
and
55 deletions
+74
-55
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp
+73
-54
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp
.../threadwise_tensor_slice_transfer_avx2_specialization.hpp
+1
-1
No files found.
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp
View file @
2e414b7c
...
@@ -128,7 +128,7 @@ struct GridwiseGemmAvx2_MxN
...
@@ -128,7 +128,7 @@ struct GridwiseGemmAvx2_MxN
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
m_per_blk
,
n_per_blk
));
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
m_per_blk
,
n_per_blk
));
}
}
static
auto
GetA
MultiIndex
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
k_per_blk
)
static
auto
GetA
SliceLength
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
k_per_blk
)
{
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
...
@@ -146,7 +146,7 @@ struct GridwiseGemmAvx2_MxN
...
@@ -146,7 +146,7 @@ struct GridwiseGemmAvx2_MxN
}
}
}
}
static
auto
GetB
MultiIndex
(
const
ck
::
index_t
k_per_blk
,
const
ck
::
index_t
n_per_blk
)
static
auto
GetB
SliceLength
(
const
ck
::
index_t
k_per_blk
,
const
ck
::
index_t
n_per_blk
)
{
{
// n_per_blk should be 8x
// n_per_blk should be 8x
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
...
@@ -168,11 +168,49 @@ struct GridwiseGemmAvx2_MxN
...
@@ -168,11 +168,49 @@ struct GridwiseGemmAvx2_MxN
}
}
}
}
static
auto
GetC
MultiIndex
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
n_per_blk
)
static
auto
GetC
SliceLength
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
n_per_blk
)
{
{
return
ck
::
make_multi_index
(
m_per_blk
,
n_per_blk
);
return
ck
::
make_multi_index
(
m_per_blk
,
n_per_blk
);
}
}
// Builds the multi-index for position (i_m, i_k) of matrix A, with the two
// coordinates ordered according to ThreadwiseGemm_Dispatch::MatrixALayout.
static auto GetAIndex(const ck::index_t i_m, const ck::index_t i_k)
{
    using ALayout = typename ThreadwiseGemm_Dispatch::MatrixALayout;
    constexpr bool a_is_row_major =
        std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value;

    if constexpr(!a_is_row_major)
    {
        // A : K, M
        return ck::make_multi_index(i_k, i_m);
    }
    else
    {
        // A : M, K
        return ck::make_multi_index(i_m, i_k);
    }
}
// Builds the multi-index for position (i_k, i_n) of matrix B.
//
// The rank of the result depends on ThreadwiseGemm_Dispatch::MatrixBLayout:
//   - RowMajor : (i_k, i_n)
//   - otherwise: (i_n / MinVec, i_k, i_n % MinVec), where MinVec is
//                ThreadwiseGemm_Dispatch::MatrixBMinVectorSize.
//
// Note carried over from the original: i_n should be a multiple of 8.
static auto GetBIndex(const ck::index_t i_k, const ck::index_t i_n)
{
    using BLayout = typename ThreadwiseGemm_Dispatch::MatrixBLayout;
    constexpr bool b_is_row_major =
        std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value;

    // `if constexpr` (not a runtime branch) is required here: the two
    // branches return multi-indices of different rank.
    if constexpr(!b_is_row_major)
    {
        // B : N/8, K, N8
        return ck::make_multi_index(i_n / ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
                                    i_k,
                                    i_n % ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
    }
    else
    {
        // B : K, N
        return ck::make_multi_index(i_k, i_n);
    }
}
// Builds the multi-index for position (i_m, i_n) of matrix C.
// The coordinates are always passed through in (M, N) order.
static auto GetCIndex(const ck::index_t i_m, const ck::index_t i_n)
{
    // C : M, N
    return ck::make_multi_index(i_m, i_n);
}
static
constexpr
bool
CheckValidity
(
const
AGridDesc
&
a_grid_desc
,
static
constexpr
bool
CheckValidity
(
const
AGridDesc
&
a_grid_desc
,
const
BGridDesc
&
b_grid_desc
,
const
BGridDesc
&
b_grid_desc
,
const
CGridDesc
&
c_grid_desc
)
const
CGridDesc
&
c_grid_desc
)
...
@@ -260,8 +298,8 @@ struct GridwiseGemmAvx2_MxN
...
@@ -260,8 +298,8 @@ struct GridwiseGemmAvx2_MxN
//
//
if
constexpr
(
std
::
is_same
<
BlockMNKAccessOrder
,
ck
::
Sequence
<
0
,
1
,
2
>>::
value
)
if
constexpr
(
std
::
is_same
<
BlockMNKAccessOrder
,
ck
::
Sequence
<
0
,
1
,
2
>>::
value
)
{
{
auto
a_move_k_step
=
ck
::
make_multi_i
ndex
(
0
,
k_per_block
);
auto
a_move_k_step
=
GetAI
ndex
(
0
,
k_per_block
);
auto
b_move_k_step
=
ck
::
make_multi_i
ndex
(
0
,
k_per_block
,
0
);
auto
b_move_k_step
=
GetBI
ndex
(
k_per_block
,
0
);
const
ck
::
index_t
grid_m
=
math
::
integer_divide_ceil
(
GemmM
,
m_per_block
);
const
ck
::
index_t
grid_m
=
math
::
integer_divide_ceil
(
GemmM
,
m_per_block
);
const
ck
::
index_t
grid_n
=
math
::
integer_divide_ceil
(
GemmN
,
n_per_block
);
const
ck
::
index_t
grid_n
=
math
::
integer_divide_ceil
(
GemmN
,
n_per_block
);
...
@@ -332,31 +370,19 @@ struct GridwiseGemmAvx2_MxN
...
@@ -332,31 +370,19 @@ struct GridwiseGemmAvx2_MxN
nc_size
=
math
::
integer_least_multiple
(
nc_size
=
math
::
integer_least_multiple
(
nc_size
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
);
nc_size
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
);
a_threadwise_copy
.
SetSrcSliceOrigin
(
a_grid_desc
,
ck
::
make_multi_index
(
i_mc
,
0
));
a_threadwise_copy
.
SetSrcSliceOrigin
(
a_grid_desc
,
GetAIndex
(
i_mc
,
0
));
b_threadwise_copy
.
SetSrcSliceOrigin
(
b_threadwise_copy
.
SetSrcSliceOrigin
(
b_grid_desc
,
GetBIndex
(
0
,
i_nc
));
b_grid_desc
,
ck
::
make_multi_index
(
math
::
integer_divide_ceil
(
i_nc
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
0
,
0
));
auto
c_block_desc
=
auto
c_block_desc
=
UseCLocalBuffer
?
GetCBlockDescriptor
(
mc_size
,
nc_size
)
:
c_grid_desc
;
UseCLocalBuffer
?
GetCBlockDescriptor
(
mc_size
,
nc_size
)
:
c_grid_desc
;
if
constexpr
(
UseCLocalBuffer
)
if
constexpr
(
!
UseCLocalBuffer
)
{
// c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
// ck::make_multi_index(i_mc, i_nc));
}
else
{
{
c_threadwise_copy
.
SetSrcSliceOrigin
(
c_block_desc
,
c_threadwise_copy
.
SetSrcSliceOrigin
(
c_block_desc
,
GetCIndex
(
i_mc
,
i_nc
));
ck
::
make_multi_index
(
i_mc
,
i_nc
));
c_threadwise_copy
.
RunRead
(
c_grid_desc
,
c_threadwise_copy
.
RunRead
(
c_block_desc
,
c_block_buf
,
c_grid_desc
,
c_grid_buf
,
c_grid_buf
,
GetCMultiIndex
(
mc_size
,
nc_size
));
c_block_desc
,
c_block_buf
,
GetCSliceLength
(
mc_size
,
nc_size
));
}
}
for
(
ck
::
index_t
i_kc
=
0
;
i_kc
<
GemmK
;
i_kc
+=
k_per_block
)
for
(
ck
::
index_t
i_kc
=
0
;
i_kc
<
GemmK
;
i_kc
+=
k_per_block
)
...
@@ -370,12 +396,12 @@ struct GridwiseGemmAvx2_MxN
...
@@ -370,12 +396,12 @@ struct GridwiseGemmAvx2_MxN
a_grid_buf
,
a_grid_buf
,
a_block_desc
,
a_block_desc
,
a_block_buf
,
a_block_buf
,
GetA
MultiIndex
(
mc_size
,
kc_size
));
GetA
SliceLength
(
mc_size
,
kc_size
));
b_threadwise_copy
.
RunRead
(
b_grid_desc
,
b_threadwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
b_grid_buf
,
b_block_desc
,
b_block_desc
,
b_block_buf
,
b_block_buf
,
GetB
MultiIndex
(
kc_size
,
nc_size
));
GetB
SliceLength
(
kc_size
,
nc_size
));
blockwise_gemm
.
Run
(
a_block_desc
,
blockwise_gemm
.
Run
(
a_block_desc
,
a_block_buf
,
a_block_buf
,
...
@@ -395,25 +421,19 @@ struct GridwiseGemmAvx2_MxN
...
@@ -395,25 +421,19 @@ struct GridwiseGemmAvx2_MxN
}
}
}
}
// if constexpr(UseCLocalBuffer)
c_threadwise_copy
.
SetDstSliceOrigin
(
c_grid_desc
,
GetCIndex
(
i_mc
,
i_nc
));
c_threadwise_copy
.
SetDstSliceOrigin
(
c_grid_desc
,
ck
::
make_multi_index
(
i_mc
,
i_nc
));
c_threadwise_copy
.
RunWrite
(
c_block_desc
,
c_threadwise_copy
.
RunWrite
(
c_block_desc
,
c_block_buf
,
c_block_buf
,
c_grid_desc
,
c_grid_desc
,
c_grid_buf
,
c_grid_buf
,
GetC
MultiIndex
(
mc_size
,
nc_size
));
GetC
SliceLength
(
mc_size
,
nc_size
));
}
}
}
}
}
}
else
if
constexpr
(
std
::
is_same
<
BlockMNKAccessOrder
,
ck
::
Sequence
<
0
,
2
,
1
>>::
value
)
else
if
constexpr
(
std
::
is_same
<
BlockMNKAccessOrder
,
ck
::
Sequence
<
0
,
2
,
1
>>::
value
)
{
{
auto
a_move_k_step
=
ck
::
make_multi_index
(
0
,
k_per_block
);
auto
a_move_k_step
=
GetAIndex
(
0
,
k_per_block
);
auto
b_move_k_step
=
ck
::
make_multi_index
(
auto
b_move_k_step
=
GetBIndex
(
0
,
n_per_block
);
math
::
integer_divide_ceil
(
n_per_block
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
0
,
0
);
const
ck
::
index_t
grid_m
=
math
::
integer_divide_ceil
(
GemmM
,
m_per_block
);
const
ck
::
index_t
grid_m
=
math
::
integer_divide_ceil
(
GemmM
,
m_per_block
);
const
ck
::
index_t
grid_m_per_thread
=
math
::
integer_divide_ceil
(
grid_m
,
total_threads
);
const
ck
::
index_t
grid_m_per_thread
=
math
::
integer_divide_ceil
(
grid_m
,
total_threads
);
...
@@ -472,7 +492,7 @@ struct GridwiseGemmAvx2_MxN
...
@@ -472,7 +492,7 @@ struct GridwiseGemmAvx2_MxN
if
(
i_mc
>=
GemmM
)
if
(
i_mc
>=
GemmM
)
break
;
break
;
ck
::
index_t
mc_size
=
ck
::
math
::
min
(
GemmM
-
i_mc
,
m_per_block
);
ck
::
index_t
mc_size
=
ck
::
math
::
min
(
GemmM
-
i_mc
,
m_per_block
);
a_threadwise_copy
.
SetSrcSliceOrigin
(
a_grid_desc
,
ck
::
make_multi_i
ndex
(
i_mc
,
0
));
a_threadwise_copy
.
SetSrcSliceOrigin
(
a_grid_desc
,
GetAI
ndex
(
i_mc
,
0
));
for
(
ck
::
index_t
i_kc
=
0
;
i_kc
<
GemmK
;
i_kc
+=
k_per_block
)
for
(
ck
::
index_t
i_kc
=
0
;
i_kc
<
GemmK
;
i_kc
+=
k_per_block
)
{
{
ck
::
index_t
kc_size
=
ck
::
math
::
min
(
GemmK
-
i_kc
,
k_per_block
);
ck
::
index_t
kc_size
=
ck
::
math
::
min
(
GemmK
-
i_kc
,
k_per_block
);
...
@@ -482,10 +502,9 @@ struct GridwiseGemmAvx2_MxN
...
@@ -482,10 +502,9 @@ struct GridwiseGemmAvx2_MxN
a_grid_buf
,
a_grid_buf
,
a_block_desc
,
a_block_desc
,
a_block_buf
,
a_block_buf
,
GetA
MultiIndex
(
mc_size
,
kc_size
));
GetA
SliceLength
(
mc_size
,
kc_size
));
b_threadwise_copy
.
SetSrcSliceOrigin
(
b_grid_desc
,
b_threadwise_copy
.
SetSrcSliceOrigin
(
b_grid_desc
,
GetBIndex
(
i_kc
,
0
));
ck
::
make_multi_index
(
0
,
i_kc
,
0
));
// TODO: if use local C buffer, then this nc loop need to loop only once
// TODO: if use local C buffer, then this nc loop need to loop only once
for
(
ck
::
index_t
i_nc
=
0
;
i_nc
<
GemmN
;
i_nc
+=
n_per_block
)
for
(
ck
::
index_t
i_nc
=
0
;
i_nc
<
GemmN
;
i_nc
+=
n_per_block
)
...
@@ -500,7 +519,7 @@ struct GridwiseGemmAvx2_MxN
...
@@ -500,7 +519,7 @@ struct GridwiseGemmAvx2_MxN
b_grid_buf
,
b_grid_buf
,
b_block_desc
,
b_block_desc
,
b_block_buf
,
b_block_buf
,
GetB
MultiIndex
(
kc_size
,
nc_size
));
GetB
SliceLength
(
kc_size
,
nc_size
));
auto
c_block_desc
=
UseCLocalBuffer
auto
c_block_desc
=
UseCLocalBuffer
?
GetCBlockDescriptor
(
mc_size
,
nc_size
)
?
GetCBlockDescriptor
(
mc_size
,
nc_size
)
...
@@ -508,13 +527,13 @@ struct GridwiseGemmAvx2_MxN
...
@@ -508,13 +527,13 @@ struct GridwiseGemmAvx2_MxN
if
constexpr
(
!
UseCLocalBuffer
)
if
constexpr
(
!
UseCLocalBuffer
)
{
{
c_threadwise_copy
.
SetSrcSliceOrigin
(
c_threadwise_copy
.
SetSrcSliceOrigin
(
c_block_desc
,
c_block_desc
,
ck
::
make_multi_index
(
i_mc
,
i_nc
));
GetCIndex
(
i_mc
,
i_nc
));
c_threadwise_copy
.
RunRead
(
c_block_desc
,
c_threadwise_copy
.
RunRead
(
c_grid_desc
,
c_block_buf
,
c_grid_desc
,
c_grid_buf
,
c_grid_buf
,
GetCMultiIndex
(
mc_size
,
nc_size
));
c_block_desc
,
c_block_buf
,
GetCSliceLength
(
mc_size
,
nc_size
));
}
}
blockwise_gemm
.
Run
(
a_block_desc
,
blockwise_gemm
.
Run
(
a_block_desc
,
...
@@ -535,14 +554,14 @@ struct GridwiseGemmAvx2_MxN
...
@@ -535,14 +554,14 @@ struct GridwiseGemmAvx2_MxN
if
constexpr
(
UseCLocalBuffer
)
if
constexpr
(
UseCLocalBuffer
)
{
{
c_threadwise_copy
.
SetDstSliceOrigin
(
c_threadwise_copy
.
SetDstSliceOrigin
(
c_grid_desc
,
c_grid_desc
,
ck
::
make_multi_i
ndex
(
i_mc
,
i_nc
));
GetCI
ndex
(
i_mc
,
i_nc
));
c_threadwise_copy
.
RunWrite
(
c_block_desc
,
c_threadwise_copy
.
RunWrite
(
c_block_desc
,
c_block_buf
,
c_block_buf
,
c_grid_desc
,
c_grid_desc
,
c_grid_buf
,
c_grid_buf
,
GetC
MultiIndex
(
mc_size
,
nc_size
));
GetC
SliceLength
(
mc_size
,
nc_size
));
}
}
else
else
{
{
...
@@ -550,14 +569,14 @@ struct GridwiseGemmAvx2_MxN
...
@@ -550,14 +569,14 @@ struct GridwiseGemmAvx2_MxN
// elementwise op from global to global
// elementwise op from global to global
if
((
i_kc
+
k_per_block
)
>=
GemmK
)
if
((
i_kc
+
k_per_block
)
>=
GemmK
)
{
{
c_threadwise_copy
.
SetDstSliceOrigin
(
c_threadwise_copy
.
SetDstSliceOrigin
(
c_grid_desc
,
c_grid_desc
,
ck
::
make_multi_i
ndex
(
i_mc
,
i_nc
));
GetCI
ndex
(
i_mc
,
i_nc
));
c_threadwise_copy
.
RunWrite
(
c_block_desc
,
c_threadwise_copy
.
RunWrite
(
c_block_desc
,
c_block_buf
,
c_block_buf
,
c_grid_desc
,
c_grid_desc
,
c_grid_buf
,
c_grid_buf
,
GetC
MultiIndex
(
mc_size
,
nc_size
));
GetC
SliceLength
(
mc_size
,
nc_size
));
}
}
}
}
}
}
...
...
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp
View file @
2e414b7c
...
@@ -985,7 +985,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
...
@@ -985,7 +985,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
{
{
if
constexpr
(
BypassTransfer
)
if
constexpr
(
BypassTransfer
)
{
{
src
_buf
.
p_data_
=
reinterpret_cast
<
float
*>
(
dst
_buf
.
p_data_
)
+
src_offset
;
dst
_buf
.
p_data_
=
reinterpret_cast
<
float
*>
(
src
_buf
.
p_data_
)
+
src_offset
;
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment