Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b8442b51
"symphony/git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "238679917e522273f25b71989c6f486111b0b8b7"
Commit
b8442b51
authored
Sep 01, 2021
by
ltqin
Browse files
add a matrix unmerge
parent
5efcb64b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
32 additions
and
9 deletions
+32
-9
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
+26
-7
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp
+6
-2
No files found.
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
View file @
b8442b51
...
@@ -199,13 +199,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
...
@@ -199,13 +199,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
CalculateKBatch
(
const
CMNGridDesc
&
c_m_n_grid_desc
,
const
BK0NK1GridDesc
&
b_k0_n_k1_grid_desc
)
CalculateKBatch
(
const
CMNGridDesc
&
c_m_n_grid_desc
,
const
BK0NK1GridDesc
&
b_k0_n_k1_grid_desc
)
{
{
constexpr
auto
MAX_GRID
=
2048
;
constexpr
auto
MAX_GRID
=
2048
;
const
index_t
grid_size
=
CalculateGridSize
(
c_m_n_grid_desc
);
const
index_t
grid_size
_mn
=
Calculate
MN
GridSize
(
c_m_n_grid_desc
);
const
auto
K0
=
b_k0_n_k1_grid_desc
.
GetLength
(
I0
);
const
auto
K0
=
b_k0_n_k1_grid_desc
.
GetLength
(
I0
);
auto
batch
=
K0
/
KPerBlock
;
auto
batch
=
K0
/
KPerBlock
;
assert
(
K0
%
KPerBlock
==
0
);
assert
(
K0
%
KPerBlock
==
0
);
index_t
div
=
1
;
index_t
div
=
1
;
while
(
batch
*
grid_size
>
MAX_GRID
&&
batch
>
div
)
while
(
batch
*
grid_size
_mn
>
MAX_GRID
&&
batch
>
div
)
{
{
div
++
;
div
++
;
if
(
batch
%
div
==
0
)
if
(
batch
%
div
==
0
)
...
@@ -217,16 +217,31 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
...
@@ -217,16 +217,31 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
}
}
__host__
__device__
static
constexpr
index_t
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CMNGridDesc
&
c_m_n_grid_desc
)
Calculate
MN
GridSize
(
const
CMNGridDesc
&
c_m_n_grid_desc
)
{
{
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
const
index_t
grid_size
_mn
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
return
grid_size
;
return
grid_size
_mn
;
}
}
__host__
__device__
static
constexpr
auto
MakeABK0MK1GridDescriptor
(
const
AK0MK1GridDesc
&
a_k0_m_k1_grid_desc
,
const
index_t
kbatch
)
{
const
auto
K0
=
a_k0_m_k1_grid_desc
.
GetLength
(
I0
);
const
auto
M
=
a_k0_m_k1_grid_desc
.
GetLength
(
I1
);
const
auto
a_b_k0_m_k1_grid_desc
=
transform_tensor_descriptor
(
a_k0_m_k1_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
kbatch
,
K0
/
kbatch
)),
make_pass_through_transform
(
M
),
make_pass_through_transform
(
K1Value
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
return
a_b_k0_m_k1_grid_desc
;
}
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
const
CMNGridDesc
&
c_m_n_grid_desc
)
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
const
CMNGridDesc
&
c_m_n_grid_desc
)
{
{
...
@@ -298,6 +313,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
...
@@ -298,6 +313,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_c_grid
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
.
GetElementSpaceSize
());
p_c_grid
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
.
GetElementSpaceSize
());
const
auto
kbatch
=
CalculateKBatch
(
CMNGridDesc
{},
b_k0_n_k1_grid_desc
);
if
(
get_block_1d_id
()
==
0
)
printf
(
"*****kbatch : %d"
,
kbatch
);
const
auto
K0
=
a_k0_m_k1_grid_desc
.
GetLength
(
I0
);
const
auto
K0
=
a_k0_m_k1_grid_desc
.
GetLength
(
I0
);
// divide block work by [M, N]
// divide block work by [M, N]
...
...
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp
View file @
b8442b51
...
@@ -122,7 +122,10 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
...
@@ -122,7 +122,10 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
std
::
cout
<<
"c_m_n_grid_desc{ "
<<
c_m_n_grid_desc
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
"c_m_n_grid_desc{ "
<<
c_m_n_grid_desc
.
GetLength
(
I0
)
<<
", "
<<
c_m_n_grid_desc
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
<<
c_m_n_grid_desc
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
}
auto
kbatch
=
GridwiseGemm
::
CalculateKBatch
(
c_m_n_grid_desc
,
b_k0_n_k1_grid_desc
);
const
auto
kbatch
=
GridwiseGemm
::
CalculateKBatch
(
c_m_n_grid_desc
,
b_k0_n_k1_grid_desc
);
const
auto
a_b_k0_m_k1_grid_desc
=
GridwiseGemm
::
MakeABK0MK1GridDescriptor
(
a_k0_m_k1_grid_desc
,
kbatch
);
{
{
std
::
cout
<<
"k batch number is: "
<<
kbatch
<<
std
::
endl
;
std
::
cout
<<
"k batch number is: "
<<
kbatch
<<
std
::
endl
;
}
}
...
@@ -135,13 +138,14 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
...
@@ -135,13 +138,14 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
const
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
=
const
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
=
GridwiseGemm
::
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
c_m_n_grid_desc
);
GridwiseGemm
::
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
c_m_n_grid_desc
);
using
ABK0MK1GridDesc
=
decltype
(
a_b_k0_m_k1_grid_desc
);
using
CM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
);
using
CM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
);
const
auto
c_block_cluster_adaptor
=
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_m_n_grid_desc
);
const
auto
c_block_cluster_adaptor
=
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_m_n_grid_desc
);
using
CBlockClusterAdaptor
=
decltype
(
c_block_cluster_adaptor
);
using
CBlockClusterAdaptor
=
decltype
(
c_block_cluster_adaptor
);
const
index_t
grid_size_mn
=
GridwiseGemm
::
CalculateGridSize
(
c_m_n_grid_desc
);
const
index_t
grid_size_mn
=
GridwiseGemm
::
Calculate
MN
GridSize
(
c_m_n_grid_desc
);
const
index_t
grid_size
=
grid_size_mn
*
kbatch
;
const
index_t
grid_size
=
grid_size_mn
*
kbatch
;
{
{
std
::
cout
<<
"mxn gridSize : "
<<
grid_size_mn
<<
" finally grid_size : "
<<
grid_size
std
::
cout
<<
"mxn gridSize : "
<<
grid_size_mn
<<
" finally grid_size : "
<<
grid_size
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment