gaoqiong / MIGraphX

Commit c51b3d29, authored Oct 12, 2022 by Paul
Parent: f8e5a547

Some more simplifications
Showing 4 changed files with 104 additions and 378 deletions (+104 -378):
  src/targets/gpu/jit/ck_gemm.cpp                                            +3   -11
  src/targets/gpu/kernels/include/migraphx/kernels/ck.hpp                    +60  -0
  src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp               +27  -46
  src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_includes.hpp      +14  -321
src/targets/gpu/jit/ck_gemm.cpp

@@ -60,22 +60,14 @@ static const char* const ck_gemm_kernel = R"__migraphx__(
 namespace migraphx {

-using gemm_t = CKDeviceGemm<${instance}, ${m}, ${k}, ${n}, ${sa}, ${sb}, ${sc}>;
-constexpr __device__ gemm_t ckdg{};
-using GridwiseGemm = decltype(ckdg.gridwisegemm);
+using gemm_t = CKDeviceGemm<${instance}>;

 extern "C" {

 __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
 {
-    make_tensors()(a_p, b_p, c_p)([&](auto a_t, auto b_t, auto c_t) {
-        constexpr ck::index_t shared_block_size =
-            GridwiseGemm::GetSharedMemoryNumberOfByte();
-        __shared__ char p_shared_block[shared_block_size];
-        make_tensors()(p_shared_block)([&](auto p_t) {
-            ck_gemm<gemm_t>(a_t, b_t, c_t, p_t);
-        });
-    });
+    make_tensors()(a_p, b_p, c_p)([&](auto a, auto b, auto c) {
+        ck_gemm<gemm_t>(a, b, c);
+    });
 }
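Note: the net effect of this hunk is that the generated kernel wrapper no longer sizes and declares the LDS workspace itself; that responsibility moves into ck_gemm (see the header diff below). A minimal, self-contained sketch of the same pattern, with hypothetical names (TinyGemm, shared_memory_bytes, run are stand-ins, not the MIGraphX or CK interfaces):

#include <hip/hip_runtime.h>
#include <cstdio>

// Hypothetical stand-in for a CK gridwise GEMM: it only reports how much LDS
// it needs and consumes a pointer to it.
struct TinyGemm
{
    static constexpr unsigned shared_memory_bytes() { return 256; }

    __device__ static void run(const float* a, const float* b, float* c, char* lds)
    {
        (void)lds; // a real implementation would stage tiles through lds
        if(threadIdx.x == 0)
            c[0] = a[0] * b[0];
    }
};

// The device function owns the LDS allocation (as ck_gemm now does)...
template <class Gemm>
__device__ void gemm_device(const float* a, const float* b, float* c)
{
    constexpr unsigned lds_bytes = Gemm::shared_memory_bytes();
    __shared__ char lds[lds_bytes];
    Gemm::run(a, b, c, lds);
}

// ...so the kernel wrapper reduces to a single forwarding call.
template <class Gemm>
__global__ void gemm_kernel(const float* a, const float* b, float* c)
{
    gemm_device<Gemm>(a, b, c);
}

int main()
{
    float ha = 2.0f, hb = 3.0f, hc = 0.0f;
    float *da, *db, *dc;
    hipMalloc(&da, sizeof(float));
    hipMalloc(&db, sizeof(float));
    hipMalloc(&dc, sizeof(float));
    hipMemcpy(da, &ha, sizeof(float), hipMemcpyHostToDevice);
    hipMemcpy(db, &hb, sizeof(float), hipMemcpyHostToDevice);
    gemm_kernel<TinyGemm><<<1, 64>>>(da, db, dc);
    hipMemcpy(&hc, dc, sizeof(float), hipMemcpyDeviceToHost);
    std::printf("%f\n", hc); // expected: 6.0
    hipFree(da);
    hipFree(db);
    hipFree(dc);
    return 0;
}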
src/targets/gpu/kernels/include/migraphx/kernels/ck.hpp (new file, mode 100644)

#ifndef MIGRAPHX_GUARD_KERNELS_CK_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_HPP

#include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/tensor_view.hpp>

#include <ck/utility/common_header.hpp>
#include <ck/tensor_description/tensor_descriptor.hpp>
#include <ck/tensor_description/tensor_descriptor_helper.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>

namespace migraphx {

namespace detail {

template <class T>
struct to_ck_type_impl
{
    using type = T;
};

template <>
struct to_ck_type_impl<migraphx::half>
{
    using type = ck::half_t;
};

template <class Shape>
constexpr bool is_row_major()
{
    constexpr auto strides = Shape{}.strides;
    MIGRAPHX_ASSERT(strides.size() >= 2);
    if(strides.back() == 1)
    {
        MIGRAPHX_ASSERT(not Shape{}.is_trasnposed());
        return true;
    }
    MIGRAPHX_ASSERT(strides[strides.size() - 2] == 1);
    return false;
}

} // namespace detail

template <class T>
using to_ck_type = typename detail::to_ck_type_impl<T>::type;

template <class Shape>
using to_ck_gemm_layout = conditional_t<detail::is_row_major<get_shape_c<Shape>>(),
                                        ck::tensor_layout::gemm::RowMajor,
                                        ck::tensor_layout::gemm::ColumnMajor>;

template <class Tensor>
constexpr auto to_ck_tensor()
{
    constexpr auto s = get_shape_c<Tensor>{};
    return sequence(s.lens.size(), [](auto... is) {
        return ck::make_naive_tensor_descriptor(ck::make_tuple(s.lens[is]...),
                                                ck::make_tuple(s.strides[is]...));
    });
}

} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_CK_HPP
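Note: the layout helper in this new header decides between RowMajor and ColumnMajor purely from the shape's strides: a unit stride in the last dimension means row-major, a unit stride in the second-to-last dimension means column-major. A small stand-alone sketch of that rule, assuming plain std::array strides rather than the MIGraphX shape types (is_row_major below is a simplified stand-in, without the assertions):

#include <array>
#include <cstddef>

// Classify a 2-D (or batched) layout from its strides alone.
template <std::size_t N>
constexpr bool is_row_major(const std::array<std::size_t, N>& strides)
{
    static_assert(N >= 2, "need at least two dimensions");
    return strides[N - 1] == 1; // otherwise strides[N - 2] == 1 => column-major
}

static_assert(is_row_major(std::array<std::size_t, 2>{8, 1}));     // M x K, row-major
static_assert(not is_row_major(std::array<std::size_t, 2>{1, 4})); // column-major

int main() { return 0; }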
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp

@@ -28,62 +28,43 @@
 #include <migraphx/kernels/algorithm.hpp>
 #include <migraphx/kernels/integral_constant.hpp>
 #include <migraphx/kernels/tensor_view.hpp>
+#include <migraphx/kernels/ck.hpp>
 #include <migraphx/kernels/ck_gemm_includes.hpp>

 namespace migraphx {

-template <class G, class T, class U, class V, class W>
-__device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, W& p_t)
-{
-    constexpr G ckdg{};
-    using GridwiseGemm = decltype(ckdg.gridwisegemm);
-    constexpr auto a_grid_desc_ak0_m_ak1 = ckdg.MakeAGridDescriptor_AK0_M_AK1();
-    constexpr auto b_grid_desc_bk0_n_bk1 = ckdg.MakeBGridDescriptor_BK0_N_BK1();
-    constexpr auto c_grid_desc_m_n = ckdg.MakeCGridDescriptor_M_N();
-    constexpr auto block_2_ctile_map = ckdg.MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
-    // static_assert(GridwiseGemm::CheckValidity(
-    //     a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, c_grid_desc_m_n, block_2_ctile_map));
-    constexpr auto c_grid_desc_mblock_mperblock_nblock_nperblock =
-        GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
-    constexpr auto K =
-        a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
-    constexpr auto a_element_op = ckdg.a_element_op;
-    constexpr auto b_element_op = ckdg.b_element_op;
-    constexpr auto c_element_op = ckdg.c_element_op;
-    if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
-    {
-        constexpr bool HasMainKBlockLoop = true;
-        GridwiseGemm::template Run<HasMainKBlockLoop>(a_t.data(),
-                                                      b_t.data(),
-                                                      c_t.data(),
-                                                      p_t.data(),
-                                                      a_element_op,
-                                                      b_element_op,
-                                                      c_element_op,
-                                                      a_grid_desc_ak0_m_ak1,
-                                                      b_grid_desc_bk0_n_bk1,
-                                                      c_grid_desc_mblock_mperblock_nblock_nperblock,
-                                                      block_2_ctile_map);
-    }
-    else
-    {
-        constexpr bool HasMainKBlockLoop = false;
-        GridwiseGemm::template Run<HasMainKBlockLoop>(a_t.data(),
-                                                      b_t.data(),
-                                                      c_t.data(),
-                                                      p_t.data(),
-                                                      a_element_op,
-                                                      b_element_op,
-                                                      c_element_op,
-                                                      a_grid_desc_ak0_m_ak1,
-                                                      b_grid_desc_bk0_n_bk1,
-                                                      c_grid_desc_mblock_mperblock_nblock_nperblock,
-                                                      block_2_ctile_map);
-    }
-}
+template <class G, class A, class B, class C>
+__device__ void ck_gemm(const A& a, const B& b, const C& c)
+{
+    constexpr auto a_desc = to_ck_tensor<A>();
+    constexpr auto b_desc = to_ck_tensor<B>();
+    constexpr auto c_desc = to_ck_tensor<C>();
+    constexpr auto block_2_ctile_map = G::MakeDefaultBlock2CTileMap(c_desc);
+    using GridwiseGemm = typename G::template Make<a_desc, b_desc, c_desc>;
+    // static_assert(GridwiseGemm::CheckValidity(a_desc, b_desc, c_desc, block_2_ctile_map));
+    constexpr auto c_grid_desc_mblock_mperblock_nblock_nperblock =
+        GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_desc);
+    constexpr auto shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
+    __shared__ char p_shared_block[shared_block_size];
+    constexpr bool HasMainKBlockLoop =
+        GridwiseGemm::CalculateHasMainKBlockLoop(A{}.get_shape().elements());
+    GridwiseGemm::template Run<HasMainKBlockLoop>(a.data(),
+                                                  b.data(),
+                                                  c.data(),
+                                                  p_shared_block,
+                                                  G::AOp(),
+                                                  G::BOp(),
+                                                  G::COp(),
+                                                  a_desc,
+                                                  b_desc,
+                                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                  block_2_ctile_map);
+}

 } // namespace migraphx
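Note: two of the simplifications in this hunk are worth calling out: the CK descriptors now come straight from the tensor types via to_ck_tensor, and the HasMainKBlockLoop decision is folded into a constexpr bool, so the Run call is no longer duplicated across a runtime if/else. A minimal sketch of that second idea, with an assumed predicate standing in for GridwiseGemm::CalculateHasMainKBlockLoop:

#include <cstdio>

// Stand-in for GridwiseGemm::template Run<HasMainKBlockLoop>(...).
template <bool HasMainKBlockLoop>
void run_gemm_body()
{
    std::printf(HasMainKBlockLoop ? "run with main K-block loop\n"
                                  : "run tail-only path\n");
}

// When the problem size is a compile-time constant, the branch collapses to a
// template argument and there is a single call site.
template <unsigned K, unsigned KPerBlock>
void run_gemm()
{
    constexpr bool has_main_loop = (K / KPerBlock) > 1; // assumed predicate
    run_gemm_body<has_main_loop>();
}

int main()
{
    run_gemm<256, 32>(); // main loop taken
    run_gemm<16, 32>();  // tail only
    return 0;
}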
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_includes.hpp

@@ -39,29 +39,6 @@
 namespace migraphx {

-static constexpr auto I0 = ck::Number<0>{};
-static constexpr auto I1 = ck::Number<1>{};
-static constexpr auto I2 = ck::Number<2>{};
-static constexpr auto I3 = ck::Number<3>{};
-static constexpr auto I4 = ck::Number<4>{};
-static constexpr auto I5 = ck::Number<5>{};
-
-static constexpr ck::index_t K1 = 1;
-static constexpr auto K1Number = ck::Number<K1>{};
-
-using Row = ck::tensor_layout::gemm::RowMajor;
-using Col = ck::tensor_layout::gemm::ColumnMajor;
-
-using F16 = ck::half_t;
-using F32 = float;
-
-using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-
-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
-
-template <ck::index_t... Is>
-using S = ck::Sequence<Is...>;
-
 template <ck::index_t MPerBlock, ck::index_t NPerBlock, typename CGridDesc_M_N>
 struct BlockToCTileMap_M00_N0_M01Adapt
 {
...
@@ -172,303 +149,12 @@ template <typename ALayout,
           ck::index_t CShuffleNXdlPerWavePerShuffle,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
-          ck::index_t MRaw,
-          ck::index_t KRaw,
-          ck::index_t NRaw,
-          ck::index_t StrideA,
-          ck::index_t StrideB,
-          ck::index_t StrideC,
           ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()>
 struct CKDeviceGemm
 {
-    // template<ck::index_t MRaw, ck::index_t KRaw, ck::index_t StrideA>
-    static constexpr auto MakeAGridDescriptor_AK0_M_AK1()
-    {
-        const auto a_grid_desc_mraw_kraw = [&]() {
-            if constexpr(ck::is_same_v<ck::tensor_layout::gemm::RowMajor, ALayout>)
-            {
-                return make_naive_tensor_descriptor(ck::make_tuple(MRaw, KRaw),
-                                                    ck::make_tuple(StrideA, I1));
-            }
-            else if constexpr(ck::is_same_v<ck::tensor_layout::gemm::ColumnMajor, ALayout>)
-            {
-                return make_naive_tensor_descriptor(ck::make_tuple(MRaw, KRaw),
-                                                    ck::make_tuple(I1, StrideA));
-            }
-        }();
-
-        const auto M = ck::math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
-        const auto K = ck::math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
-
-        const auto MPad = M - MRaw;
-        const auto KPad = K - KRaw;
-
-        if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MKPadding ||
-                     GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
-        {
-            // pad both M and K
-            static_assert(K % AK1 == 0);
-            const auto AK0 = K / AK1;
-
-            const auto a_grid_desc_m_k = transform_tensor_descriptor(
-                a_grid_desc_mraw_kraw,
-                ck::make_tuple(ck::make_right_pad_transform(MRaw, MPad),
-                               ck::make_right_pad_transform(KRaw, KPad)),
-                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
-                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
-
-            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
-                a_grid_desc_m_k,
-                ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
-                               ck::make_pass_through_transform(M)),
-                ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
-                ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
-
-            return a_grid_desc_ak0_m_ak1;
-        }
-        else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MPadding ||
-                          GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
-        {
-            // pad M, but not K
-            static_assert(KRaw % AK1 == 0);
-            const auto AK0 = KRaw / AK1;
-
-            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
-                a_grid_desc_mraw_kraw,
-                ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
-                               ck::make_right_pad_transform(MRaw, MPad)),
-                ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
-                ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
-
-            return a_grid_desc_ak0_m_ak1;
-        }
-        else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::KPadding ||
-                          GemmSpec == ck::tensor_operation::device::GemmSpecialization::NKPadding)
-        {
-            // pad K, but not M
-            static_assert(K % AK1 == 0);
-            const auto AK0 = K / AK1;
-
-            const auto a_grid_desc_m_k = transform_tensor_descriptor(
-                a_grid_desc_mraw_kraw,
-                ck::make_tuple(ck::make_pass_through_transform(MRaw),
-                               ck::make_right_pad_transform(KRaw, KPad)),
-                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
-                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
-
-            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
-                a_grid_desc_m_k,
-                ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
-                               ck::make_pass_through_transform(MRaw)),
-                ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
-                ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
-
-            return a_grid_desc_ak0_m_ak1;
-        }
-        else
-        {
-            // not pad M or K
-            static_assert(KRaw % AK1 == 0);
-            const auto AK0 = KRaw / AK1;
-
-            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
-                a_grid_desc_mraw_kraw,
-                ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
-                               ck::make_pass_through_transform(MRaw)),
-                ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
-                ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
-
-            return a_grid_desc_ak0_m_ak1;
-        }
-    }
-
-    // template<ck::index_t KRaw, ck::index_t NRaw, ck::index_t StrideB>
-    static constexpr auto MakeBGridDescriptor_BK0_N_BK1()
-    {
-        const auto b_grid_desc_nraw_kraw = [&]() {
-            if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(ck::make_tuple(NRaw, KRaw),
-                                                    ck::make_tuple(I1, StrideB));
-            }
-            else if constexpr(is_same<ck::tensor_layout::gemm::ColumnMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(ck::make_tuple(NRaw, KRaw),
-                                                    ck::make_tuple(StrideB, I1));
-            }
-        }();
-
-        const auto N = ck::math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
-        const auto K = ck::math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
-
-        const auto NPad = N - NRaw;
-        const auto KPad = K - KRaw;
-
-        if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::NKPadding ||
-                     GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
-        {
-            // pad both N and K
-            static_assert(K % BK1 == 0);
-            const auto BK0 = K / BK1;
-
-            const auto b_grid_desc_n_k = transform_tensor_descriptor(
-                b_grid_desc_nraw_kraw,
-                ck::make_tuple(ck::make_right_pad_transform(NRaw, NPad),
-                               ck::make_right_pad_transform(KRaw, KPad)),
-                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
-                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
-
-            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
-                b_grid_desc_n_k,
-                ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
-                               ck::make_pass_through_transform(N)),
-                ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
-                ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
-
-            return b_grid_desc_bk0_n_bk1;
-        }
-        else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::NPadding ||
-                          GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
-        {
-            // pad N, but not K
-            static_assert(KRaw % BK1 == 0);
-            const auto BK0 = KRaw / BK1;
-
-            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
-                b_grid_desc_nraw_kraw,
-                ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
-                               ck::make_right_pad_transform(NRaw, NPad)),
-                ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
-                ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
-
-            return b_grid_desc_bk0_n_bk1;
-        }
-        else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::KPadding ||
-                          GemmSpec == ck::tensor_operation::device::GemmSpecialization::MKPadding)
-        {
-            // pad K, but not N
-            static_assert(K % BK1 == 0);
-            const auto BK0 = K / BK1;
-
-            const auto b_grid_desc_n_k = transform_tensor_descriptor(
-                b_grid_desc_nraw_kraw,
-                ck::make_tuple(ck::make_pass_through_transform(NRaw),
-                               ck::make_right_pad_transform(KRaw, KPad)),
-                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
-                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
-
-            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
-                b_grid_desc_n_k,
-                ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
-                               ck::make_pass_through_transform(NRaw)),
-                ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
-                ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
-
-            return b_grid_desc_bk0_n_bk1;
-        }
-        else
-        {
-            // not pad N or K
-            static_assert(KRaw % BK1 == 0);
-            const auto BK0 = KRaw / BK1;
-
-            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
-                b_grid_desc_nraw_kraw,
-                ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
-                               ck::make_pass_through_transform(NRaw)),
-                ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
-                ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
-
-            return b_grid_desc_bk0_n_bk1;
-        }
-    }
-
-    // template<ck::index_t MRaw, ck::index_t NRaw, ck::index_t StrideC>
-    static constexpr auto MakeCGridDescriptor_M_N()
-    {
-        const auto c_grid_desc_mraw_nraw = [&]() {
-            if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, CLayout>::value)
-            {
-                return make_naive_tensor_descriptor(ck::make_tuple(MRaw, NRaw),
-                                                    ck::make_tuple(StrideC, I1));
-            }
-            else if constexpr(is_same<ck::tensor_layout::gemm::ColumnMajor, CLayout>::value)
-            {
-                return make_naive_tensor_descriptor(ck::make_tuple(MRaw, NRaw),
-                                                    ck::make_tuple(I1, StrideC));
-            }
-        }();
-
-        const auto M = ck::math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
-        const auto N = ck::math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
-
-        const auto MPad = M - MRaw;
-        const auto NPad = N - NRaw;
-
-        if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding ||
-                     GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
-        {
-            // pad M and N
-            return transform_tensor_descriptor(
-                c_grid_desc_mraw_nraw,
-                ck::make_tuple(ck::make_right_pad_transform(MRaw, MPad),
-                               ck::make_right_pad_transform(NRaw, NPad)),
-                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
-                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
-        }
-        else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MPadding ||
-                          GemmSpec == ck::tensor_operation::device::GemmSpecialization::MKPadding)
-        {
-            // pad M, but not N
-            return transform_tensor_descriptor(
-                c_grid_desc_mraw_nraw,
-                ck::make_tuple(ck::make_right_pad_transform(MRaw, MPad),
-                               ck::make_pass_through_transform(NRaw)),
-                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
-                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
-        }
-        else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::NPadding ||
-                          GemmSpec == ck::tensor_operation::device::GemmSpecialization::NKPadding)
-        {
-            // pad N, but not M
-            return transform_tensor_descriptor(
-                c_grid_desc_mraw_nraw,
-                ck::make_tuple(ck::make_pass_through_transform(MRaw),
-                               ck::make_right_pad_transform(NRaw, NPad)),
-                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
-                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
-        }
-        else
-        {
-            // not pad M or N
-            return c_grid_desc_mraw_nraw;
-        }
-    }
-
-    // using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1<8, 8, 8>());
-    // using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1<8, 8, 8>());
-    // using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N<8, 8, 8>());
-    using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1());
-    using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1());
-    using CGridDesc_M_N       = decltype(MakeCGridDescriptor_M_N());
-
-    // return block_id to C matrix tile idx (m0, n0) mapping
-    __host__ __device__ static constexpr auto
-    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
-    {
-        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
-            c_grid_desc_m_n);
-    }
-
-    using GridwiseGemm = ck::GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
+    template <class AGridDesc_AK0_M_AK1, class BGridDesc_BK0_N_BK1, class CGridDesc_M_N>
+    using Make = ck::GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
         ADataType, // TODO: distinguish A/B datatype
         GemmAccDataType,
         CShuffleDataType,
...
@@ -513,10 +199,17 @@ struct CKDeviceGemm
         CShuffleBlockTransferScalarPerVector_NPerBlock,
         LoopSched>;

-    GridwiseGemm gridwisegemm{};
-
-    AElementwiseOperation a_element_op{};
-    BElementwiseOperation b_element_op{};
-    CElementwiseOperation c_element_op{};
+    static constexpr auto AOp() { return AElementwiseOperation{}; }
+    static constexpr auto BOp() { return BElementwiseOperation{}; }
+    static constexpr auto COp() { return CElementwiseOperation{}; }
+
+    // return block_id to C matrix tile idx (m0, n0) mapping
+    template <class CGridDesc_M_N>
+    __host__ __device__ static constexpr auto
+    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
+    {
+        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
+            c_grid_desc_m_n);
+    }
 };

 } // namespace migraphx
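Note: most of what this file loses is the per-operand padding logic, which rounds M, N and K up to whole multiples of the block tile before building the grid descriptors. The arithmetic itself is simple; a worked example of the round-up-and-pad step used above (illustrative numbers, not taken from the commit):

#include <cstdio>

// Same formula as ck::math::integer_divide_ceil.
constexpr long integer_divide_ceil(long a, long b) { return (a + b - 1) / b; }

int main()
{
    constexpr long MRaw      = 1000; // raw problem size
    constexpr long MPerBlock = 128;  // block tile size
    constexpr long M    = integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; // 1024
    constexpr long MPad = M - MRaw;                                         // 24
    static_assert(M % MPerBlock == 0, "padded size is a whole number of tiles");
    std::printf("M=%ld MPad=%ld\n", M, MPad);
    return 0;
}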