gaoqiong / MIGraphX

Commit d1e27426, authored Sep 13, 2022 by turneram

Move ck includes to own header file

Parent: 6fb1706a
Showing 4 changed files with 337 additions and 15 deletions.
src/targets/gpu/jit/ck_gemm.cpp                                      +112 -10
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp           +7 -3
src/targets/gpu/kernels/include/migraphx/kernels/ck_includes.hpp     +216 -0
test/verify/0a_test_ck_gemm.cpp                                        +2 -2
src/targets/gpu/jit/ck_gemm.cpp
@@ -40,8 +40,43 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 // NOLINTNEXTLINE
+// static const char* const ck_gemm_kernel = R"__migraphx__(
+// #include <migraphx/kernels/ck_gemm.hpp>
+// #include <migraphx/kernels/ops.hpp>
+// #include <migraphx/kernels/integral_constant.hpp>
+// #include <migraphx/kernels/generic_constant.hpp>
+// #include <args.hpp>
+// #include <hip/hip_runtime_api.h>
+// namespace migraphx {
+// extern "C" {
+// __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
+// {
+//     // hipDeviceProp_t hdp{};
+//     // printf("Shared mem: %i\n", int(hdp.sharedMemPerBlock));
+//     // make_tensors()(a_p, b_p, c_p)([](auto&&... xs) {
+//     //     ck_gemm(xs...);
+//     // });
+//     make_tensors()(a_p, b_p, c_p)([](auto a_t, auto b_t, auto c_t) {
+//         __shared__ float p_shared_block[512]; //[(a_t.get_shape().elements() + b_t.get_shape().elements()) * 2];
+//         ck_gemm(a_t, b_t, c_t, p_shared_block);
+//         // make_tensors()(p_shared_block)([&](auto p_t) {
+//         //     ck_gemm(a_t, b_t, c_t, p_t);
+//         // });
+//     });
+// }
+// }
+// } // namespace migraphx
+// )__migraphx__";
 static const char* const ck_gemm_kernel = R"__migraphx__(
-#include <migraphx/kernels/ck_gemm.hpp>
+#include <migraphx/kernels/ck_includes.hpp>
 #include <migraphx/kernels/ops.hpp>
 #include <migraphx/kernels/integral_constant.hpp>
 #include <migraphx/kernels/generic_constant.hpp>
...
@@ -55,16 +90,83 @@ extern "C" {
 __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
 {
-    hipDeviceProp_t hdp{};
-    printf("Shared mem: %i\n", int(hdp.sharedMemPerBlock));
-    // make_tensors()(a_p, b_p, c_p)([](auto&&... xs) {
-    //     ck_gemm(xs...);
-    // });
     make_tensors()(a_p, b_p, c_p)([](auto a_t, auto b_t, auto c_t) {
-        __shared__ void* p_shared_block[(a_t.get_shape().elements() + b_t.get_shape().elements()) * 2];
-        make_tensors()(p_shared_block)([&](auto p_t) {
-            ck_gemm(a_t, b_t, c_t, p_t);
-        });
+        constexpr auto alens    = get_shape_c<decltype(a_t)>{}.lens;
+        constexpr auto m        = alens[0];
+        constexpr auto k        = alens[1];
+        constexpr auto blens    = get_shape_c<decltype(b_t)>{}.lens;
+        constexpr auto n        = blens[1];
+        constexpr auto astrides = get_shape_c<decltype(a_t)>{}.strides;
+        constexpr auto as       = astrides[0];
+        constexpr auto bstrides = get_shape_c<decltype(b_t)>{}.strides;
+        constexpr auto bs       = bstrides[0];
+        constexpr auto cstrides = get_shape_c<decltype(c_t)>{}.strides;
+        constexpr auto cs       = cstrides[0];
+        auto a_grid_desc_k0_m_k1 = MakeAGridDescriptor_K0_M_K1(
+            static_cast<ck::index_t>(m), static_cast<ck::index_t>(k), static_cast<ck::index_t>(as));
+        auto b_grid_desc_k0_n_k1 = MakeBGridDescriptor_K0_N_K1(
+            static_cast<ck::index_t>(k), static_cast<ck::index_t>(n), static_cast<ck::index_t>(bs));
+        auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
+            static_cast<ck::index_t>(m), static_cast<ck::index_t>(n), static_cast<ck::index_t>(cs));
+        using GridwiseGemm =
+            ck::GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
+                                             ADataType,
+                                             AccDataType,
+                                             CDataType,
+                                             ck::InMemoryDataOperationEnum::Set,
+                                             AGridDesc_K0_M_K1,
+                                             BGridDesc_K0_N_K1,
+                                             CGridDesc_M_N,
+                                             MPerBlock,
+                                             NPerBlock,
+                                             K0PerBlock,
+                                             M1PerThread,
+                                             N1PerThread,
+                                             KPerThread,
+                                             M1N1ThreadClusterM1Xs,
+                                             M1N1ThreadClusterN1Xs,
+                                             ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
+                                             ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
+                                             ABlockTransferThreadClusterArrangeOrder,
+                                             ABlockTransferSrcAccessOrder,
+                                             ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
+                                             ABlockTransferSrcVectorTensorContiguousDimOrder,
+                                             ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
+                                             BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
+                                             BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
+                                             BBlockTransferThreadClusterArrangeOrder,
+                                             BBlockTransferSrcAccessOrder,
+                                             BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
+                                             BBlockTransferSrcVectorTensorContiguousDimOrder,
+                                             BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
+                                             CThreadTransferSrcDstAccessOrder,
+                                             CThreadTransferSrcDstVectorDim,
+                                             CThreadTransferDstScalarPerVector>;
+        auto a_grid_desc_k0_m0_m1_k1 =
+            GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1);
+        auto b_grid_desc_k0_n0_n1_k1 =
+            GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1);
+        auto c_grid_desc_m0_m10_m11_n0_n10_n11 =
+            GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n);
+        auto block_2_ctile_map = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
+        constexpr bool HasMainKBlockLoop       = true;
+        constexpr bool HasDoubleTailKBlockLoop = true;
+        constexpr ck::index_t shared_block_size =
+            GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(float);
+        __shared__ float p_shared_block[shared_block_size];
+        GridwiseGemm::Run(a_t.data(),
+                          b_t.data(),
+                          c_t.data(),
+                          p_shared_block,
+                          a_grid_desc_k0_m0_m1_k1,
+                          b_grid_desc_k0_n0_n1_k1,
+                          c_grid_desc_m0_m10_m11_n0_n10_n11,
+                          block_2_ctile_map,
+                          ck::integral_constant<bool, HasMainKBlockLoop>{},
+                          ck::integral_constant<bool, HasDoubleTailKBlockLoop>{});
     });
 }
...
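The key change in the kernel body above is how the LDS workspace is sized: a __shared__ array bound must be a constant expression, so the new code derives shared_block_size from GridwiseGemm::GetSharedMemoryNumberOfByte() in a constexpr context rather than from runtime tensor sizes, as the removed __shared__ void* p_shared_block[...] line attempted. A minimal HIP-style sketch of that pattern; the byte count here is an assumed illustrative value standing in for the CK query:

#include <hip/hip_runtime.h>

// Illustrative stand-in for GridwiseGemm::GetSharedMemoryNumberOfByte().
constexpr int get_lds_bytes() { return 16384; }

__global__ void lds_sizing_example(float* out)
{
    // The array bound must be a constant expression, so it is fixed at compile time.
    constexpr int lds_floats = get_lds_bytes() / sizeof(float);
    __shared__ float lds[lds_floats];
    lds[threadIdx.x % lds_floats] = threadIdx.x;
    __syncthreads();
    if(threadIdx.x == 0)
        out[blockIdx.x] = lds[0];
}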
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
@@ -212,8 +212,10 @@ using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
 using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
 using CGridDesc_M_N     = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
-template <class T, class U, class V, class W>
-__device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
+// template <class T, class U, class V, class W>
+// __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
+template <class T, class U, class V>
+__device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, float* p_t)
 {
     constexpr auto alens = get_shape_c<T>{}.lens;
     constexpr auto m     = alens[0];
...
@@ -281,10 +283,12 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
     constexpr bool HasMainKBlockLoop       = true;
     constexpr bool HasDoubleTailKBlockLoop = true;
+    auto num_bytes = GridwiseGemm::GetSharedMemoryNumberOfByte();
+    printf("Bytes: %i\n", int(num_bytes));
     GridwiseGemm::Run(a_t.data(),
                       b_t.data(),
                       c_t.data(),
-                      p_t.data(),
+                      /*p_t.data(),*/ p_t,
                       a_grid_desc_k0_m0_m1_k1,
                       b_grid_desc_k0_n0_n1_k1,
                       c_grid_desc_m0_m10_m11_n0_n10_n11,
...
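With the new signature, ck_gemm takes the LDS workspace as a raw float* rather than a tensor_view. A hedged caller sketch, lifted from the commented-out kernel string in ck_gemm.cpp above; the 512-float workspace size is illustrative only, not the size CK actually requires:

__global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
{
    make_tensors()(a_p, b_p, c_p)([](auto a_t, auto b_t, auto c_t) {
        // The workspace is now a raw __shared__ float array passed directly to ck_gemm.
        __shared__ float p_shared_block[512]; // illustrative size only
        ck_gemm(a_t, b_t, c_t, p_shared_block);
    });
}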
src/targets/gpu/kernels/include/migraphx/kernels/ck_includes.hpp
new file (mode 100644)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_CK_INCLUDES_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_INCLUDES_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
namespace migraphx {

static constexpr auto I0 = ck::Number<0>{};
static constexpr auto I1 = ck::Number<1>{};
static constexpr auto I2 = ck::Number<2>{};
static constexpr auto I3 = ck::Number<3>{};
static constexpr auto I4 = ck::Number<4>{};
static constexpr auto I5 = ck::Number<5>{};

static constexpr ck::index_t K1 = 1;
static constexpr auto K1Number  = ck::Number<K1>{};

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using ALayout = Col;
using BLayout = Row;
using CLayout = Row;

using ADataType   = float;
using BDataType   = float;
using CDataType   = float;
using AccDataType = float;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

// Values hard-coded by CK
static constexpr ck::index_t MPerBlock   = 128;
static constexpr ck::index_t NPerBlock   = 128;
static constexpr ck::index_t BlockSize   = 256;
static constexpr ck::index_t K0PerBlock  = 16;
static constexpr ck::index_t M1PerThread = 4;
static constexpr ck::index_t N1PerThread = 4;
static constexpr ck::index_t KPerThread  = 1;
using M1N1ThreadClusterM1Xs = S<8, 2>;
using M1N1ThreadClusterN1Xs = S<8, 2>;
using ABlockTransferThreadSliceLengths_K0_M0_M1_K1     = S<2, 1, 4, 1>;
using ABlockTransferThreadClusterLengths_K0_M0_M1_K1   = S<8, 1, 32, 1>;
using ABlockTransferThreadClusterArrangeOrder          = S<0, 3, 1, 2>;
using ABlockTransferSrcAccessOrder                     = S<0, 3, 1, 2>;
using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = S<1, 1, 4, 1>;
using ABlockTransferSrcVectorTensorContiguousDimOrder  = S<0, 3, 1, 2>;
using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = S<1, 1, 4, 1>;
using BBlockTransferThreadSliceLengths_K0_N0_N1_K1     = S<2, 1, 4, 1>;
using BBlockTransferThreadClusterLengths_K0_N0_N1_K1   = S<8, 1, 32, 1>;
using BBlockTransferThreadClusterArrangeOrder          = S<0, 3, 1, 2>;
using BBlockTransferSrcAccessOrder                     = S<0, 3, 1, 2>;
using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = S<1, 1, 4, 1>;
using BBlockTransferSrcVectorTensorContiguousDimOrder  = S<0, 3, 1, 2>;
using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = S<1, 1, 4, 1>;
using CThreadTransferSrcDstAccessOrder = S<0, 1, 2, 3, 4, 5>;
static constexpr ck::index_t CThreadTransferSrcDstVectorDim    = 5;
static constexpr ck::index_t CThreadTransferDstScalarPerVector = 4;

static constexpr auto
MakeAGridDescriptor_K0_M_K1(ck::index_t M, ck::index_t K, ck::index_t StrideA)
{
    assert(K % K1 == 0);

    const ck::index_t K0 = K / K1;

    const auto a_grid_desc_m_k = [&]() {
        if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value)
        {
            return make_naive_tensor_descriptor(ck::make_tuple(M, K), ck::make_tuple(StrideA, I1));
        }
        else if constexpr(is_same<ck::tensor_layout::gemm::ColumnMajor, ALayout>::value)
        {
            return make_naive_tensor_descriptor(ck::make_tuple(M, K), ck::make_tuple(I1, StrideA));
        }
    }();

    if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
    {
        const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;

        return transform_tensor_descriptor(
            a_grid_desc_m_k,
            ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
                           ck::make_right_pad_transform(M, PadM)),
            ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
            ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
    }
    else
    {
        return transform_tensor_descriptor(
            a_grid_desc_m_k,
            ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
                           ck::make_pass_through_transform(M)),
            ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
            ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
    }
}

static constexpr auto
MakeBGridDescriptor_K0_N_K1(ck::index_t K, ck::index_t N, ck::index_t StrideB)
{
    assert(K % K1 == 0);

    const ck::index_t K0 = K / K1;

    const auto b_grid_desc_k_n = [&]() {
        if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
        {
            return make_naive_tensor_descriptor(ck::make_tuple(K, N), ck::make_tuple(StrideB, I1));
        }
        else if constexpr(is_same<ck::tensor_layout::gemm::ColumnMajor, BLayout>::value)
        {
            return make_naive_tensor_descriptor(ck::make_tuple(K, N), ck::make_tuple(I1, StrideB));
        }
    }();

    if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
    {
        const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

        return transform_tensor_descriptor(
            b_grid_desc_k_n,
            ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
                           ck::make_right_pad_transform(N, PadN)),
            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
            ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
    }
    else
    {
        return transform_tensor_descriptor(
            b_grid_desc_k_n,
            ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
                           ck::make_pass_through_transform(N)),
            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
            ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
    }
}

static constexpr auto
MakeCGridDescriptor_M_N(ck::index_t M, ck::index_t N, ck::index_t StrideC)
{
    const auto c_grid_desc_m_n = [&]() {
        if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, CLayout>::value)
        {
            return make_naive_tensor_descriptor(ck::make_tuple(M, N), ck::make_tuple(StrideC, I1));
        }
        else if constexpr(is_same<ck::tensor_layout::gemm::ColumnMajor, CLayout>::value)
        {
            return make_naive_tensor_descriptor(ck::make_tuple(M, N), ck::make_tuple(I1, StrideC));
        }
    }();

    if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
    {
        const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
        const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

        return transform_tensor_descriptor(
            c_grid_desc_m_n,
            ck::make_tuple(ck::make_right_pad_transform(M, PadM),
                           ck::make_right_pad_transform(N, PadN)),
            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
    }
    else
    {
        return transform_tensor_descriptor(
            c_grid_desc_m_n,
            ck::make_tuple(ck::make_pass_through_transform(M), ck::make_pass_through_transform(N)),
            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
    }
}

using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N     = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

} // namespace migraphx
#endif
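The descriptor builders above hinge on two pieces of index arithmetic: the K dimension is split as K0 = K / K1 (K must be a multiple of K1, which is hard-coded to 1 here), and under MNPadding the M and N extents are rounded up to multiples of MPerBlock and NPerBlock via (XPerBlock - X % XPerBlock) % XPerBlock. A standalone worked example using the hard-coded block sizes and assumed sample extents:

#include <cassert>

int main()
{
    // Mirrors the constants defined in ck_includes.hpp.
    constexpr int K1        = 1;
    constexpr int MPerBlock = 128;

    int K = 256;
    int M = 200; // assumed sample extent, not taken from the commit

    assert(K % K1 == 0);
    int K0   = K / K1;                                  // 256: with K1 == 1 the split leaves K unchanged
    int PadM = (MPerBlock - M % MPerBlock) % MPerBlock; // 56, so M + PadM == 256, a multiple of 128

    return (K0 == 256 && M + PadM == 256) ? 0 : 1;
}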
test/verify/0a_test_ck_gemm.cpp
@@ -33,8 +33,8 @@ struct test_ck_gemm : verify_program<test_ck_gemm>
 {
     migraphx::program p;
     auto* mm = p.get_main_module();
-    migraphx::shape m1_shape{migraphx::shape::float_type, {128, 256}};
-    migraphx::shape m2_shape{migraphx::shape::float_type, {256, 256}};
+    migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
+    migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3}};
     auto l1 = mm->add_parameter("1", m1_shape);
     auto l2 = mm->add_parameter("2", m2_shape);
...