Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
34520806
Commit
34520806
authored
Oct 16, 2022
by
Paul
Browse files
Format
parent
2ca29096
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
25 additions
and
26 deletions
+25
-26
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
+7
-4
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_includes.hpp
...gpu/kernels/include/migraphx/kernels/ck_gemm_includes.hpp
+18
-22
No files found.
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
View file @
34520806
...
...
@@ -40,11 +40,14 @@ __device__ void ck_gemm(const A& a, const B& b, const C& c)
constexpr
auto
a_grid_desc_ak0_m_ak1
=
gemm
.
MakeAGridDescriptor_AK0_M_AK1
(
to_ck_tensor
<
A
>
());
constexpr
auto
b_grid_desc_bk0_n_bk1
=
gemm
.
MakeBGridDescriptor_BK0_N_BK1
(
to_ck_tensor
<
B
>
());
constexpr
auto
c_grid_desc_m_n
=
gemm
.
MakeCGridDescriptor_M_N
(
to_ck_tensor
<
C
>
());
constexpr
auto
block_2_ctile_map
=
gemm
.
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n
);
constexpr
auto
c_grid_desc_m_n
=
gemm
.
MakeCGridDescriptor_M_N
(
to_ck_tensor
<
C
>
());
constexpr
auto
block_2_ctile_map
=
gemm
.
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n
);
using
GridwiseGemm
=
typename
G
::
template
GridwiseGemm
<
decltype
(
a_grid_desc_ak0_m_ak1
),
decltype
(
b_grid_desc_bk0_n_bk1
),
decltype
(
c_grid_desc_m_n
)>;
// static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, c_grid_desc_m_n, block_2_ctile_map));
using
GridwiseGemm
=
typename
G
::
template
GridwiseGemm
<
decltype
(
a_grid_desc_ak0_m_ak1
),
decltype
(
b_grid_desc_bk0_n_bk1
),
decltype
(
c_grid_desc_m_n
)>;
// static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1,
// c_grid_desc_m_n, block_2_ctile_map));
constexpr
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
);
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_includes.hpp
View file @
34520806
...
...
@@ -149,23 +149,21 @@ template <typename ALayout,
ck
::
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
ck
::
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
ck
::
LoopScheduler
LoopSched
=
ck
::
make_default_loop_scheduler
()
>
struct
CKDeviceGemm
ck
::
LoopScheduler
LoopSched
=
ck
::
make_default_loop_scheduler
()>
struct
CKDeviceGemm
{
static
constexpr
auto
I0
=
ck
::
Number
<
0
>
{};
static
constexpr
auto
I1
=
ck
::
Number
<
1
>
{};
static
constexpr
auto
I2
=
ck
::
Number
<
2
>
{};
static
constexpr
auto
I3
=
ck
::
Number
<
3
>
{};
template
<
class
Descriptor
>
static
constexpr
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
Descriptor
&
a_grid_desc_mraw_kraw
)
template
<
class
Descriptor
>
static
constexpr
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
Descriptor
&
a_grid_desc_mraw_kraw
)
{
const
auto
MRaw
=
a_grid_desc_mraw_kraw
.
GetLength
(
I0
);
const
auto
KRaw
=
a_grid_desc_mraw_kraw
.
GetLength
(
I1
);
const
auto
M
=
ck
::
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
K
=
ck
::
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
M
=
ck
::
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
K
=
ck
::
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
KPad
=
K
-
KRaw
;
...
...
@@ -253,14 +251,13 @@ struct CKDeviceGemm
}
}
template
<
class
Descriptor
>
static
constexpr
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
Descriptor
&
b_grid_desc_nraw_kraw
)
template
<
class
Descriptor
>
static
constexpr
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
Descriptor
&
b_grid_desc_nraw_kraw
)
{
const
auto
NRaw
=
b_grid_desc_nraw_kraw
.
GetLength
(
I0
);
const
auto
KRaw
=
b_grid_desc_nraw_kraw
.
GetLength
(
I1
);
const
auto
N
=
ck
::
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
const
auto
K
=
ck
::
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
N
=
ck
::
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
const
auto
K
=
ck
::
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
NPad
=
N
-
NRaw
;
const
auto
KPad
=
K
-
KRaw
;
...
...
@@ -348,14 +345,13 @@ struct CKDeviceGemm
}
}
template
<
class
Descriptor
>
static
constexpr
auto
MakeCGridDescriptor_M_N
(
const
Descriptor
&
c_grid_desc_mraw_nraw
)
template
<
class
Descriptor
>
static
constexpr
auto
MakeCGridDescriptor_M_N
(
const
Descriptor
&
c_grid_desc_mraw_nraw
)
{
const
auto
MRaw
=
c_grid_desc_mraw_nraw
.
GetLength
(
I0
);
const
auto
NRaw
=
c_grid_desc_mraw_nraw
.
GetLength
(
I1
);
const
auto
M
=
ck
::
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
N
=
ck
::
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
const
auto
M
=
ck
::
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
N
=
ck
::
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
NPad
=
N
-
NRaw
;
...
...
@@ -407,8 +403,8 @@ struct CKDeviceGemm
// using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1());
// using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N());
// return block_id to C matrix tile idx (m0, n0) mapping
template
<
class
CGridDesc_M_N
>
// return block_id to C matrix tile idx (m0, n0) mapping
template
<
class
CGridDesc_M_N
>
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
...
...
@@ -416,7 +412,7 @@ struct CKDeviceGemm
c_grid_desc_m_n
);
}
template
<
class
AGridDesc_AK0_M_AK1
,
class
BGridDesc_BK0_N_BK1
,
class
CGridDesc_M_N
>
template
<
class
AGridDesc_AK0_M_AK1
,
class
BGridDesc_BK0_N_BK1
,
class
CGridDesc_M_N
>
using
GridwiseGemm
=
ck
::
GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
<
ADataType
,
// TODO: distinguish A/B datatype
GemmAccDataType
,
...
...
@@ -461,7 +457,7 @@ struct CKDeviceGemm
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
>
;
AElementwiseOperation
a_element_op
{};
BElementwiseOperation
b_element_op
{};
CElementwiseOperation
c_element_op
{};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment