gaoqiong / MIGraphX
"docs/source/Tutorial/QuickStart.rst" did not exist on "d1b1e7b311380d9ab0454c85c851d8159aa5b67e"
Commit 2ca29096, authored Oct 16, 2022 by Paul
Refactor to use correct descriptors
Parent: 6fda1d3e
Showing 3 changed files with 287 additions and 33 deletions:
src/targets/gpu/jit/ck_gemm.cpp (+1, -5)
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp (+14, -12)
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_includes.hpp (+272, -16)
src/targets/gpu/jit/ck_gemm.cpp
...
@@ -56,18 +56,14 @@ static const char* const ck_gemm_kernel = R"__migraphx__(
 #include <args.hpp>
 #include <migraphx/kernels/ck_gemm.hpp>
-#include <hip/hip_runtime_api.h>
 
 namespace migraphx {
 
-using gemm_t = CKDeviceGemm<${instance}>;
-
 extern "C" {
 
 __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
 {
     make_tensors()(a_p, b_p, c_p)([&](auto a, auto b, auto c) {
         ck_gemm<
-            gemm_t
+            CKDeviceGemm<${instance}>
         >(a, b, c);
     });
 }
...
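The kernel above is a raw-string template that is compiled at runtime; `${instance}` is substituted on the host with the concrete `CKDeviceGemm` template arguments, which is why the `gemm_t` alias could be dropped in favor of writing the instantiation inline. A minimal sketch of that kind of placeholder substitution (the `interpolate` helper below is hypothetical, for illustration only, not MIGraphX's actual API):

#include <string>

// Hypothetical helper: replace every "${key}" placeholder in a kernel-source
// template with a concrete value before the source is compiled.
std::string interpolate(std::string src, const std::string& key, const std::string& value)
{
    const std::string token = "${" + key + "}";
    for(auto pos = src.find(token); pos != std::string::npos; pos = src.find(token, pos))
    {
        src.replace(pos, token.size(), value);
        pos += value.size();
    }
    return src;
}

// e.g. interpolate(ck_gemm_kernel, "instance", "Row, Row, Row, float, ...")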
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
...
@@ -36,16 +36,18 @@ namespace migraphx {
 template <class G, class A, class B, class C>
 __device__ void ck_gemm(const A& a, const B& b, const C& c)
 {
-    constexpr auto a_desc = to_ck_tensor<A>();
-    constexpr auto b_desc = to_ck_tensor<B>();
-    constexpr auto c_desc = to_ck_tensor<C>();
-    constexpr auto block_2_ctile_map = G::MakeDefaultBlock2CTileMap(c_desc);
-    using GridwiseGemm =
-        typename G::template Make<decltype(a_desc), decltype(b_desc), decltype(c_desc)>;
-    // static_assert(GridwiseGemm::CheckValidity(a_desc, b_desc, c_desc, block_2_ctile_map));
+    constexpr G gemm{};
+    constexpr auto a_grid_desc_ak0_m_ak1 = gemm.MakeAGridDescriptor_AK0_M_AK1(to_ck_tensor<A>());
+    constexpr auto b_grid_desc_bk0_n_bk1 = gemm.MakeBGridDescriptor_BK0_N_BK1(to_ck_tensor<B>());
+    constexpr auto c_grid_desc_m_n = gemm.MakeCGridDescriptor_M_N(to_ck_tensor<C>());
+    constexpr auto block_2_ctile_map = gemm.MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
+    using GridwiseGemm =
+        typename G::template GridwiseGemm<decltype(a_grid_desc_ak0_m_ak1),
+                                          decltype(b_grid_desc_bk0_n_bk1),
+                                          decltype(c_grid_desc_m_n)>;
+    // static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, c_grid_desc_m_n, block_2_ctile_map));
     constexpr auto c_grid_desc_mblock_mperblock_nblock_nperblock =
-        GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_desc);
+        GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
     constexpr auto shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
     __shared__ char p_shared_block[shared_block_size];
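The heart of this hunk: instead of static calls on `G`, the kernel now value-initializes `constexpr G gemm{}` and builds every grid descriptor through `constexpr` member functions, so the descriptors and the shared-memory size remain compile-time constants. A standalone sketch of why that works (the `gemm_policy` type and `lds_size` name are illustrative, not the CK API):

// A stateless policy type: calling constexpr member functions on a constexpr
// instance still yields values usable in constant expressions.
struct gemm_policy
{
    constexpr unsigned lds_size() const { return 16384; }
};

template <class G>
__device__ void sketch()
{
    constexpr G gemm{};
    constexpr unsigned n = gemm.lds_size(); // compile-time constant
    __shared__ char p_shared_block[n];      // static shared-memory allocation, as above
    (void)p_shared_block;
}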
...
@@ -56,11 +58,11 @@ __device__ void ck_gemm(const A& a, const B& b, const C& c)
         b.data(),
         c.data(),
         p_shared_block,
-        G::AOp(),
-        G::BOp(),
-        G::COp(),
-        a_desc,
-        b_desc,
+        gemm.a_element_op,
+        gemm.b_element_op,
+        gemm.c_element_op,
+        a_grid_desc_ak0_m_ak1,
+        b_grid_desc_bk0_n_bk1,
         c_grid_desc_mblock_mperblock_nblock_nperblock,
         block_2_ctile_map);
 }
...
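Correspondingly, the `Run` arguments switch from the static `G::AOp()`/`BOp()`/`COp()` factories to default-initialized functor members on the `gemm` instance. A reduced sketch of the member-functor pattern (the `pass_through` functor and `device_gemm_sketch` struct are illustrative):

// Illustrative elementwise functor: stateless and default-constructible.
struct pass_through
{
    template <class Y, class X>
    constexpr void operator()(Y& y, const X& x) const { y = x; }
};

struct device_gemm_sketch
{
    // One constexpr instance now carries all three ops, so call sites forward
    // gemm.a_element_op etc. instead of invoking static factory functions.
    pass_through a_element_op{};
    pass_through b_element_op{};
    pass_through c_element_op{};
};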
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_includes.hpp
...
@@ -149,11 +149,275 @@ template <typename ALayout,
           ck::index_t CShuffleNXdlPerWavePerShuffle,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()>
 struct CKDeviceGemm
 {
-    template <class AGridDesc_AK0_M_AK1, class BGridDesc_BK0_N_BK1, class CGridDesc_M_N>
-    using Make = ck::GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
+    static constexpr auto I0 = ck::Number<0>{};
+    static constexpr auto I1 = ck::Number<1>{};
+    static constexpr auto I2 = ck::Number<2>{};
+    static constexpr auto I3 = ck::Number<3>{};
+
+    template <class Descriptor>
+    static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const Descriptor& a_grid_desc_mraw_kraw)
+    {
+        const auto MRaw = a_grid_desc_mraw_kraw.GetLength(I0);
+        const auto KRaw = a_grid_desc_mraw_kraw.GetLength(I1);
+
+        const auto M = ck::math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
+        const auto K = ck::math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
+
+        const auto MPad = M - MRaw;
+        const auto KPad = K - KRaw;
+
+        if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MKPadding ||
+                     GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
+        {
+            // pad both M and K
+            assert(K % AK1 == 0);
+            const auto AK0 = K / AK1;
+            const auto a_grid_desc_m_k = transform_tensor_descriptor(
+                a_grid_desc_mraw_kraw,
+                ck::make_tuple(ck::make_right_pad_transform(MRaw, MPad),
+                               ck::make_right_pad_transform(KRaw, KPad)),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
+            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
+                a_grid_desc_m_k,
+                ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
+                               ck::make_pass_through_transform(M)),
+                ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
+                ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
+            return a_grid_desc_ak0_m_ak1;
+        }
+        else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MPadding ||
+                          GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
+        {
+            // pad M, but not K
+            assert(KRaw % AK1 == 0);
+            const auto AK0 = KRaw / AK1;
+            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
+                a_grid_desc_mraw_kraw,
+                ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
+                               ck::make_right_pad_transform(MRaw, MPad)),
+                ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
+                ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
+            return a_grid_desc_ak0_m_ak1;
+        }
+        else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::KPadding ||
+                          GemmSpec == ck::tensor_operation::device::GemmSpecialization::NKPadding)
+        {
+            // pad K, but not M
+            assert(K % AK1 == 0);
+            const auto AK0 = K / AK1;
+            const auto a_grid_desc_m_k = transform_tensor_descriptor(
+                a_grid_desc_mraw_kraw,
+                ck::make_tuple(ck::make_pass_through_transform(MRaw),
+                               ck::make_right_pad_transform(KRaw, KPad)),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
+            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
+                a_grid_desc_m_k,
+                ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
+                               ck::make_pass_through_transform(MRaw)),
+                ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
+                ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
+            return a_grid_desc_ak0_m_ak1;
+        }
+        else
+        {
+            // not pad M or K
+            assert(KRaw % AK1 == 0);
+            const auto AK0 = KRaw / AK1;
+            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
+                a_grid_desc_mraw_kraw,
+                ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
+                               ck::make_pass_through_transform(MRaw)),
+                ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
+                ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
+            return a_grid_desc_ak0_m_ak1;
+        }
+    }
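The rounding at the top of MakeAGridDescriptor_AK0_M_AK1 is plain round-up-to-multiple arithmetic: M is MRaw rounded up to a multiple of MPerBlock, and MPad is the slack appended by the right-pad transform. A quick standalone check with illustrative numbers (the hunk continues below with the analogous B-side helper):

#include <cassert>

constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    constexpr int MRaw      = 1000;
    constexpr int MPerBlock = 128;
    constexpr int M         = integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
    constexpr int MPad      = M - MRaw;
    static_assert(M == 1024 && MPad == 24, "1000 rounds up to 8 blocks of 128");
    return 0;
}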
+
+    template <class Descriptor>
+    static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const Descriptor& b_grid_desc_nraw_kraw)
+    {
+        const auto NRaw = b_grid_desc_nraw_kraw.GetLength(I0);
+        const auto KRaw = b_grid_desc_nraw_kraw.GetLength(I1);
+
+        const auto N = ck::math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
+        const auto K = ck::math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
+
+        const auto NPad = N - NRaw;
+        const auto KPad = K - KRaw;
+
+        if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::NKPadding ||
+                     GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
+        {
+            // pad both N and K
+            assert(K % BK1 == 0);
+            const auto BK0 = K / BK1;
+            const auto b_grid_desc_n_k = transform_tensor_descriptor(
+                b_grid_desc_nraw_kraw,
+                ck::make_tuple(ck::make_right_pad_transform(NRaw, NPad),
+                               ck::make_right_pad_transform(KRaw, KPad)),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
+            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
+                b_grid_desc_n_k,
+                ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
+                               ck::make_pass_through_transform(N)),
+                ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
+                ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
+            return b_grid_desc_bk0_n_bk1;
+        }
+        else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::NPadding ||
+                          GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
+        {
+            // pad N, but not K
+            assert(KRaw % BK1 == 0);
+            const auto BK0 = KRaw / BK1;
+            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
+                b_grid_desc_nraw_kraw,
+                ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
+                               ck::make_right_pad_transform(NRaw, NPad)),
+                ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
+                ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
+            return b_grid_desc_bk0_n_bk1;
+        }
+        else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::KPadding ||
+                          GemmSpec == ck::tensor_operation::device::GemmSpecialization::MKPadding)
+        {
+            // pad K, but not N
+            assert(K % BK1 == 0);
+            const auto BK0 = K / BK1;
+            const auto b_grid_desc_n_k = transform_tensor_descriptor(
+                b_grid_desc_nraw_kraw,
+                ck::make_tuple(ck::make_pass_through_transform(NRaw),
+                               ck::make_right_pad_transform(KRaw, KPad)),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
+            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
+                b_grid_desc_n_k,
+                ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
+                               ck::make_pass_through_transform(NRaw)),
+                ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
+                ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
+            return b_grid_desc_bk0_n_bk1;
+        }
+        else
+        {
+            // not pad N or K
+            assert(KRaw % BK1 == 0);
+            const auto BK0 = KRaw / BK1;
+            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
+                b_grid_desc_nraw_kraw,
+                ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
+                               ck::make_pass_through_transform(NRaw)),
+                ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
+                ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
+            return b_grid_desc_bk0_n_bk1;
+        }
+    }
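MakeBGridDescriptor_BK0_N_BK1 mirrors the A-side helper with N in place of M. In both, make_unmerge_transform splits the (possibly padded) K axis into (K0, K1) so that the fast-moving K1 chunk matches the vectorized access width; the underlying index relation is k = k0 * K1 + k1. A small sketch of that split (values illustrative; the hunk continues with the C descriptor below):

#include <cassert>

int main()
{
    const int K  = 64;
    const int K1 = 8;      // innermost chunk; the asserts above require K % K1 == 0
    const int K0 = K / K1;

    for(int k = 0; k < K; ++k)
    {
        const int k0 = k / K1;
        const int k1 = k % K1;
        assert(k0 < K0 && k1 < K1);
        assert(k == k0 * K1 + k1); // the unmerge relation
    }
    return 0;
}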
+
+    template <class Descriptor>
+    static constexpr auto MakeCGridDescriptor_M_N(const Descriptor& c_grid_desc_mraw_nraw)
+    {
+        const auto MRaw = c_grid_desc_mraw_nraw.GetLength(I0);
+        const auto NRaw = c_grid_desc_mraw_nraw.GetLength(I1);
+
+        const auto M = ck::math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
+        const auto N = ck::math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
+
+        const auto MPad = M - MRaw;
+        const auto NPad = N - NRaw;
+
+        if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding ||
+                     GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
+        {
+            // pad M and N
+            return transform_tensor_descriptor(
+                c_grid_desc_mraw_nraw,
+                ck::make_tuple(ck::make_right_pad_transform(MRaw, MPad),
+                               ck::make_right_pad_transform(NRaw, NPad)),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
+        }
+        else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MPadding ||
+                          GemmSpec == ck::tensor_operation::device::GemmSpecialization::MKPadding)
+        {
+            // pad M, but not N
+            return transform_tensor_descriptor(
+                c_grid_desc_mraw_nraw,
+                ck::make_tuple(ck::make_right_pad_transform(MRaw, MPad),
+                               ck::make_pass_through_transform(NRaw)),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
+        }
+        else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::NPadding ||
+                          GemmSpec == ck::tensor_operation::device::GemmSpecialization::NKPadding)
+        {
+            // pad N, but not M
+            return transform_tensor_descriptor(
+                c_grid_desc_mraw_nraw,
+                ck::make_tuple(ck::make_pass_through_transform(MRaw),
+                               ck::make_right_pad_transform(NRaw, NPad)),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
+                ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
+        }
+        else
+        {
+            // not pad M or N
+            return c_grid_desc_mraw_nraw;
+        }
+    }
+
-    // using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1<8, 8, 8>());
-    // using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1<8, 8, 8>());
-    // using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N<8, 8, 8>());
+    // using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1());
+    // using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1());
+    // using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N());
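All three Make* helpers pick their padding branch with if constexpr on the GemmSpec non-type parameter, so only the selected descriptor construction is instantiated, and different branches may return different descriptor types. A reduced sketch of that dispatch (the enum and struct names below are illustrative, not CK's):

enum class gemm_spec { none, mk_padding };

struct raw_desc    { int m; int k; };
struct padded_desc { int m; int k; int m_pad; int k_pad; };

constexpr int round_up(int x, int to) { return ((x + to - 1) / to) * to; }

// Each if-constexpr branch can return a different type; the untaken branch
// is discarded at compile time, as in the Make* helpers above.
template <gemm_spec Spec>
constexpr auto make_desc(raw_desc d)
{
    if constexpr(Spec == gemm_spec::mk_padding)
        return padded_desc{d.m, d.k, round_up(d.m, 128) - d.m, round_up(d.k, 32) - d.k};
    else
        return d;
}

static_assert(make_desc<gemm_spec::mk_padding>(raw_desc{1000, 60}).m_pad == 24, "");
static_assert(make_desc<gemm_spec::none>(raw_desc{1000, 60}).k == 60, "");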
+
+    // return block_id to C matrix tile idx (m0, n0) mapping
+    template <class CGridDesc_M_N>
+    __host__ __device__ static constexpr auto
+    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
+    {
+        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
+            c_grid_desc_m_n);
+    }
+
+    template <class AGridDesc_AK0_M_AK1, class BGridDesc_BK0_N_BK1, class CGridDesc_M_N>
+    using GridwiseGemm = ck::GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
         ADataType, // TODO: distinguish A/B datatype
         GemmAccDataType,
         CShuffleDataType,
...
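BlockToCTileMap_M00_N0_M01Adapt maps a flat workgroup id to the (m0, n0) tile of C that the block computes. Stripped of the M01 regrouping the real adaptor applies along M for cache locality, the mapping is a divide/modulo over the tile grid; a reduced sketch:

// Simplest flat-block-id -> C tile (m0, n0) mapping over the tile grid.
// The real adaptor additionally regroups blocks along M in chunks of M01;
// that detail is omitted here.
struct block_to_ctile_map_sketch
{
    int n_tiles; // number of tiles along N

    constexpr int m0(int block_id) const { return block_id / n_tiles; }
    constexpr int n0(int block_id) const { return block_id % n_tiles; }
};

static_assert(block_to_ctile_map_sketch{4}.m0(6) == 1, "");
static_assert(block_to_ctile_map_sketch{4}.n0(6) == 2, "");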
@@ -197,18 +461,10 @@ struct CKDeviceGemm
         CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
         CShuffleBlockTransferScalarPerVector_NPerBlock,
         LoopSched>;
 
-    static constexpr auto AOp() { return AElementwiseOperation{}; }
+    AElementwiseOperation a_element_op{};
-    static constexpr auto BOp() { return BElementwiseOperation{}; }
+    BElementwiseOperation b_element_op{};
-    static constexpr auto COp() { return CElementwiseOperation{}; }
+    CElementwiseOperation c_element_op{};
-
-    // return block_id to C matrix tile idx (m0, n0) mapping
-    template <class CGridDesc_M_N>
-    __host__ __device__ static constexpr auto
-    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
-    {
-        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
-            c_grid_desc_m_n);
-    }
 };
 } // namespace migraphx
...