Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
72f3eb67
"docs/vscode:/vscode.git/clone" did not exist on "68a35543d5ab91722babf1d26105a5c4eda46a41"
Commit
72f3eb67
authored
Sep 01, 2021
by
ltqin
Browse files
add split k function
parent
1d4f5453
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
2 deletions
+29
-2
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
+24
-0
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp
+5
-2
No files found.
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
View file @
72f3eb67
...
...
@@ -195,6 +195,30 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K0
%
KPerBlock
==
0
);
}
// Picks a K-batch count from the C (M x N) and B (K0 x N x K1) grid descriptors.
//
// The count starts at K0 / KPerBlock and is then divided by successive integers
// (2, 3, 4, ...) that divide it evenly, for as long as
// (k_batch * number of MPerBlock x NPerBlock tiles) is above MAX_GRID and the
// candidate divisor has not passed k_batch. The result is at least 1.
//
// c_m_n_grid_desc      - supplies M (length I0) and N (length I1).
// b_k0_n_k1_grid_desc  - supplies K0 (length I0).
// Returns the chosen batch count as index_t.
//
// NOTE(review): every division is applied to the already-reduced value, so the
// result can end up below the largest divisor that would still fit under
// MAX_GRID (e.g. 64 blocks along K with 1024 tiles yields 1, not 2). Confirm
// this is the intended heuristic.
__host__ __device__ static constexpr index_t
CalculateKBatch(const CMNGridDesc& c_m_n_grid_desc, const BK0NK1GridDesc& b_k0_n_k1_grid_desc)
{
    // Threshold compared against k_batch * tile count; presumably a cap on the
    // total launch grid size - TODO confirm where 2048 comes from.
    constexpr auto MAX_GRID = 2048;

    const auto M  = c_m_n_grid_desc.GetLength(I0);
    const auto N  = c_m_n_grid_desc.GetLength(I1);
    const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0);

    // Number of MPerBlock x NPerBlock output tiles (integer division).
    const index_t mn_tile_count = (M / MPerBlock) * (N / NPerBlock);

    // Initial candidate: one batch per KPerBlock-sized slice of K0.
    auto k_batch = K0 / KPerBlock;
    assert(K0 % KPerBlock == 0);

    // Try divisors 2, 3, 4, ... in order; shrink k_batch whenever one divides it.
    for(index_t divisor = 2; k_batch * mn_tile_count > MAX_GRID && k_batch >= divisor; ++divisor)
    {
        if(k_batch % divisor == 0)
        {
            k_batch = k_batch / divisor;
        }
    }

    // Never report fewer than one batch (covers K0 < KPerBlock).
    return std::max(1, k_batch);
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CMNGridDesc
&
c_m_n_grid_desc
)
{
...
...
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp
View file @
72f3eb67
...
...
@@ -122,7 +122,10 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
std
::
cout
<<
"c_m_n_grid_desc{ "
<<
c_m_n_grid_desc
.
GetLength
(
I0
)
<<
", "
<<
c_m_n_grid_desc
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
auto
kbatch
=
GridwiseGemm
::
CalculateKBatch
(
c_m_n_grid_desc
,
b_k0_n_k1_grid_desc
);
{
std
::
cout
<<
"k batch number is: "
<<
kbatch
<<
std
::
endl
;
}
if
(
!
GridwiseGemm
::
CheckValidity
(
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
))
{
throw
std
::
runtime_error
(
...
...
@@ -138,7 +141,7 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
using
CBlockClusterAdaptor
=
decltype
(
c_block_cluster_adaptor
);
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
c_m_n_grid_desc
);
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
c_m_n_grid_desc
)
*
kbatch
;
const
auto
kernel
=
kernel_gemm_xdlops_v2r4
<
GridwiseGemm
,
FloatAB
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment