Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1dd11890
"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "4272fff121b178cb2f4f0c743f1396acf8cb70fa"
Commit
1dd11890
authored
Sep 22, 2022
by
turneram
Browse files
Formatting
parent
07167910
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
308 additions
and
269 deletions
+308
-269
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+0
-1
src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
...s/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
+2
-2
src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise2.hpp
.../gpu/kernels/include/migraphx/kernels/ck_elementwise2.hpp
+31
-32
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
+70
-55
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm2.hpp
...targets/gpu/kernels/include/migraphx/kernels/ck_gemm2.hpp
+54
-41
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_includes.hpp
...gpu/kernels/include/migraphx/kernels/ck_gemm_includes.hpp
+108
-96
src/targets/gpu/kernels/include/migraphx/kernels/ck_includes.hpp
...gets/gpu/kernels/include/migraphx/kernels/ck_includes.hpp
+35
-36
test/verify/0ack_test_ck_gemm.cpp
test/verify/0ack_test_ck_gemm.cpp
+8
-6
No files found.
src/targets/gpu/jit/ck_gemm.cpp
View file @
1dd11890
...
@@ -72,7 +72,6 @@ __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
...
@@ -72,7 +72,6 @@ __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
)__migraphx__"
;
)__migraphx__"
;
struct
ck_gemm_compiler
:
compiler
<
ck_gemm_compiler
>
struct
ck_gemm_compiler
:
compiler
<
ck_gemm_compiler
>
{
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"ck_gemm"
};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"ck_gemm"
};
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
View file @
1dd11890
src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise2.hpp
View file @
1dd11890
...
@@ -204,16 +204,15 @@ __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
...
@@ -204,16 +204,15 @@ __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
ck
::
Sequence
<
1
,
8
>
,
ck
::
Sequence
<
1
,
8
>
,
ck
::
Sequence
<
8
>>
;
ck
::
Sequence
<
8
>>
;
using
shapes_t
=
std
::
array
<
ck
::
index_t
,
3
>
;
using
shapes_t
=
std
::
array
<
ck
::
index_t
,
3
>
;
//shapes_t lengths_abc;
//
shapes_t lengths_abc;
//copy(c_lens.begin(), c_lens.end(), lengths_abc);
//
copy(c_lens.begin(), c_lens.end(), lengths_abc);
shapes_t
lengths_abc
=
{
c_lens
[
0
],
c_lens
[
1
],
c_lens
[
2
]};
shapes_t
lengths_abc
=
{
c_lens
[
0
],
c_lens
[
1
],
c_lens
[
2
]};
//constexpr auto lengths_abc = static_cast<shapes_t>(c_lens[0], c_lens[1], c_lens[2]);
//
constexpr auto lengths_abc = static_cast<shapes_t>(c_lens[0], c_lens[1], c_lens[2]);
constexpr
auto
strides_a
=
static_cast
<
shapes_t
>
(
a_strides
);
constexpr
auto
strides_a
=
static_cast
<
shapes_t
>
(
a_strides
);
constexpr
auto
strides_b
=
static_cast
<
shapes_t
>
(
b_strides
);
constexpr
auto
strides_b
=
static_cast
<
shapes_t
>
(
b_strides
);
constexpr
auto
strides_c
=
static_cast
<
shapes_t
>
(
c_strides
);
constexpr
auto
strides_c
=
static_cast
<
shapes_t
>
(
c_strides
);
std
::
array
<
const
void
*
,
2
>
input
=
{
a_t
.
data
(),
std
::
array
<
const
void
*
,
2
>
input
=
{
a_t
.
data
(),
b_t
.
data
()};
b_t
.
data
()};
std
::
array
<
void
*
,
1
>
output
=
{
c_t
.
data
()};
std
::
array
<
void
*
,
1
>
output
=
{
c_t
.
data
()};
auto
ck_add
=
DeviceElementwiseAddInstance
{};
auto
ck_add
=
DeviceElementwiseAddInstance
{};
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
View file @
1dd11890
...
@@ -60,19 +60,27 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
...
@@ -60,19 +60,27 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
if
(
idx
.
global
==
0
)
if
(
idx
.
global
==
0
)
{
{
printf
(
"a_grid_desc_k0_m0_m1_k1{%i, %i, %i}
\n
"
,
int
(
a_grid_desc_k0_m_k1
.
GetLength
(
I0
)),
int
(
a_grid_desc_k0_m_k1
.
GetLength
(
I1
)),
int
(
a_grid_desc_k0_m_k1
.
GetLength
(
I2
)));
printf
(
"a_grid_desc_k0_m0_m1_k1{%i, %i, %i}
\n
"
,
printf
(
"b_grid_desc_k0_n0_n1_k1{%i, %i, %i}
\n
"
,
int
(
b_grid_desc_k0_n_k1
.
GetLength
(
I0
)),
int
(
b_grid_desc_k0_n_k1
.
GetLength
(
I1
)),
int
(
b_grid_desc_k0_n_k1
.
GetLength
(
I2
)));
int
(
a_grid_desc_k0_m_k1
.
GetLength
(
I0
)),
printf
(
"c_grid_desc_m_n{%i, %i}
\n
"
,
int
(
c_grid_desc_m_n
.
GetLength
(
I0
)),
int
(
c_grid_desc_m_n
.
GetLength
(
I1
)));
int
(
a_grid_desc_k0_m_k1
.
GetLength
(
I1
)),
int
(
a_grid_desc_k0_m_k1
.
GetLength
(
I2
)));
printf
(
"b_grid_desc_k0_n0_n1_k1{%i, %i, %i}
\n
"
,
int
(
b_grid_desc_k0_n_k1
.
GetLength
(
I0
)),
int
(
b_grid_desc_k0_n_k1
.
GetLength
(
I1
)),
int
(
b_grid_desc_k0_n_k1
.
GetLength
(
I2
)));
printf
(
"c_grid_desc_m_n{%i, %i}
\n
"
,
int
(
c_grid_desc_m_n
.
GetLength
(
I0
)),
int
(
c_grid_desc_m_n
.
GetLength
(
I1
)));
}
}
AGridDesc_K0_M0_M1_K1
a_grid_desc_k0_m0_m1_k1
;
AGridDesc_K0_M0_M1_K1
a_grid_desc_k0_m0_m1_k1
;
BGridDesc_K0_N0_N1_K1
b_grid_desc_k0_n0_n1_k1
;
BGridDesc_K0_N0_N1_K1
b_grid_desc_k0_n0_n1_k1
;
CGridDesc_M0_M10_M11_N0_N10_N11
c_grid_desc_m0_m10_m11_n0_n10_n11
;
CGridDesc_M0_M10_M11_N0_N10_N11
c_grid_desc_m0_m10_m11_n0_n10_n11
;
DefaultBlock2CTileMap
block_2_ctile_map
;
DefaultBlock2CTileMap
block_2_ctile_map
;
if
(
true
or
GridwiseGemm
::
CheckValidity
(
if
(
true
or
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
c_grid_desc_m_n
))
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
c_grid_desc_m_n
))
{
{
//printf("Is valid\n");
//
printf("Is valid\n");
a_grid_desc_k0_m0_m1_k1
=
a_grid_desc_k0_m0_m1_k1
=
GridwiseGemm
::
MakeAGridDescriptor_K0_M0_M1_K1
(
a_grid_desc_k0_m_k1
);
GridwiseGemm
::
MakeAGridDescriptor_K0_M0_M1_K1
(
a_grid_desc_k0_m_k1
);
b_grid_desc_k0_n0_n1_k1
=
b_grid_desc_k0_n0_n1_k1
=
...
@@ -83,20 +91,27 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
...
@@ -83,20 +91,27 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
}
}
else
else
{
{
//printf("Not valid\n");
//
printf("Not valid\n");
}
}
if
(
idx
.
global
==
0
)
if
(
idx
.
global
==
0
)
{
{
printf
(
"a_grid_desc_k0_m0_m1_k1{%i, %i, %i}
\n
"
,
int
(
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I0
)),
int
(
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I1
)),
int
(
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I2
)));
printf
(
"a_grid_desc_k0_m0_m1_k1{%i, %i, %i}
\n
"
,
printf
(
"b_grid_desc_k0_n0_n1_k1{%i, %i, %i}
\n
"
,
int
(
b_grid_desc_k0_n0_n1_k1
.
GetLength
(
I0
)),
int
(
b_grid_desc_k0_n0_n1_k1
.
GetLength
(
I1
)),
int
(
b_grid_desc_k0_n0_n1_k1
.
GetLength
(
I2
)));
int
(
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I0
)),
printf
(
"c_grid_desc_m0_m10_m11_n0_n10_n11{%i, %i}
\n
"
,
int
(
c_grid_desc_m0_m10_m11_n0_n10_n11
.
GetLength
(
I0
)),
int
(
c_grid_desc_m0_m10_m11_n0_n10_n11
.
GetLength
(
I1
)));
int
(
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I1
)),
int
(
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I2
)));
printf
(
"b_grid_desc_k0_n0_n1_k1{%i, %i, %i}
\n
"
,
int
(
b_grid_desc_k0_n0_n1_k1
.
GetLength
(
I0
)),
int
(
b_grid_desc_k0_n0_n1_k1
.
GetLength
(
I1
)),
int
(
b_grid_desc_k0_n0_n1_k1
.
GetLength
(
I2
)));
printf
(
"c_grid_desc_m0_m10_m11_n0_n10_n11{%i, %i}
\n
"
,
int
(
c_grid_desc_m0_m10_m11_n0_n10_n11
.
GetLength
(
I0
)),
int
(
c_grid_desc_m0_m10_m11_n0_n10_n11
.
GetLength
(
I1
)));
}
}
const
auto
K0
=
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I0
);
const
auto
K0
=
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I0
);
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K0
);
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K0
);
const
bool
has_double_tail_k_block_loop
=
const
bool
has_double_tail_k_block_loop
=
GridwiseGemm
::
CalculateHasDoubleTailKBlockLoop
(
K0
);
GridwiseGemm
::
CalculateHasDoubleTailKBlockLoop
(
K0
);
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
{
constexpr
bool
HasMainKBlockLoop
=
true
;
constexpr
bool
HasMainKBlockLoop
=
true
;
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm2.hpp
View file @
1dd11890
...
@@ -53,38 +53,51 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
...
@@ -53,38 +53,51 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
if
(
idx
.
global
==
0
)
if
(
idx
.
global
==
0
)
printf
(
"%i %i %i, %i %i %i
\n
"
,
int
(
m
),
int
(
n
),
int
(
k
),
int
(
as
),
int
(
bs
),
int
(
cs
));
printf
(
"%i %i %i, %i %i %i
\n
"
,
int
(
m
),
int
(
n
),
int
(
k
),
int
(
as
),
int
(
bs
),
int
(
cs
));
constexpr
auto
a_grid_desc_ak0_m_ak1
=
tp
.
MakeAGridDescriptor_AK0_M_AK1
(
static_cast
<
ck
::
index_t
>
(
m
),
static_cast
<
ck
::
index_t
>
(
k
),
static_cast
<
ck
::
index_t
>
(
as
));
constexpr
auto
a_grid_desc_ak0_m_ak1
=
tp
.
MakeAGridDescriptor_AK0_M_AK1
(
constexpr
auto
b_grid_desc_bk0_n_bk1
=
tp
.
MakeBGridDescriptor_BK0_N_BK1
(
static_cast
<
ck
::
index_t
>
(
k
),
static_cast
<
ck
::
index_t
>
(
n
),
static_cast
<
ck
::
index_t
>
(
bs
));
static_cast
<
ck
::
index_t
>
(
m
),
static_cast
<
ck
::
index_t
>
(
k
),
static_cast
<
ck
::
index_t
>
(
as
));
constexpr
auto
c_grid_desc_m_n
=
tp
.
MakeCGridDescriptor_M_N
(
static_cast
<
ck
::
index_t
>
(
m
),
static_cast
<
ck
::
index_t
>
(
n
),
static_cast
<
ck
::
index_t
>
(
cs
));
constexpr
auto
b_grid_desc_bk0_n_bk1
=
tp
.
MakeBGridDescriptor_BK0_N_BK1
(
static_cast
<
ck
::
index_t
>
(
k
),
static_cast
<
ck
::
index_t
>
(
n
),
static_cast
<
ck
::
index_t
>
(
bs
));
constexpr
auto
c_grid_desc_m_n
=
tp
.
MakeCGridDescriptor_M_N
(
static_cast
<
ck
::
index_t
>
(
m
),
static_cast
<
ck
::
index_t
>
(
n
),
static_cast
<
ck
::
index_t
>
(
cs
));
/* constexpr */
auto
block_2_ctile_map
=
tp
.
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n
);
/* constexpr */
auto
block_2_ctile_map
=
tp
.
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n
);
if
(
idx
.
global
==
0
)
if
(
idx
.
global
==
0
)
{
{
printf
(
"a_grid_desc_ak0_m_ak1{%i, %i, %i}
\n
"
,
int
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)),
int
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I1
)),
int
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
)));
printf
(
"a_grid_desc_ak0_m_ak1{%i, %i, %i}
\n
"
,
printf
(
"b_grid_desc_bk0_n_bk1{%i, %i, %i}
\n
"
,
int
(
b_grid_desc_bk0_n_bk1
.
GetLength
(
I0
)),
int
(
b_grid_desc_bk0_n_bk1
.
GetLength
(
I1
)),
int
(
b_grid_desc_bk0_n_bk1
.
GetLength
(
I2
)));
int
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)),
printf
(
"c_grid_desc_m_n{%i, %i}
\n
"
,
int
(
c_grid_desc_m_n
.
GetLength
(
I0
)),
int
(
c_grid_desc_m_n
.
GetLength
(
I1
)));
int
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I1
)),
int
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
)));
printf
(
"b_grid_desc_bk0_n_bk1{%i, %i, %i}
\n
"
,
int
(
b_grid_desc_bk0_n_bk1
.
GetLength
(
I0
)),
int
(
b_grid_desc_bk0_n_bk1
.
GetLength
(
I1
)),
int
(
b_grid_desc_bk0_n_bk1
.
GetLength
(
I2
)));
printf
(
"c_grid_desc_m_n{%i, %i}
\n
"
,
int
(
c_grid_desc_m_n
.
GetLength
(
I0
)),
int
(
c_grid_desc_m_n
.
GetLength
(
I1
)));
}
}
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
{};
c_grid_desc_mblock_mperblock_nblock_nperblock
{};
if
(
true
or
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1
,
if
(
true
or
b_grid_desc_bk0_n_bk1
,
GridwiseGemm
::
CheckValidity
(
c_grid_desc_m_n
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m_n
,
block_2_ctile_map
))
block_2_ctile_map
))
{
{
c_grid_desc_mblock_mperblock_nblock_nperblock
=
c_grid_desc_mblock_mperblock_nblock_nperblock
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
);
c_grid_desc_m_n
);
}
}
// if(idx.global == 0)
// if(idx.global == 0)
// {
// {
// printf("a_grid_desc_k0_m0_m1_k1{%i, %i, %i}\n", int(a_grid_desc_k0_m0_m1_k1.GetLength(I0)), int(a_grid_desc_k0_m0_m1_k1.GetLength(I1)), int(a_grid_desc_k0_m0_m1_k1.GetLength(I2)));
// printf("a_grid_desc_k0_m0_m1_k1{%i, %i, %i}\n",
// printf("b_grid_desc_k0_n0_n1_k1{%i, %i, %i}\n", int(b_grid_desc_k0_n0_n1_k1.GetLength(I0)), int(b_grid_desc_k0_n0_n1_k1.GetLength(I1)), int(b_grid_desc_k0_n0_n1_k1.GetLength(I2)));
// int(a_grid_desc_k0_m0_m1_k1.GetLength(I0)), int(a_grid_desc_k0_m0_m1_k1.GetLength(I1)),
// printf("c_grid_desc_m0_m10_m11_n0_n10_n11{%i, %i}\n", int(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0)), int(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I1)));
// int(a_grid_desc_k0_m0_m1_k1.GetLength(I2))); printf("b_grid_desc_k0_n0_n1_k1{%i, %i,
// %i}\n", int(b_grid_desc_k0_n0_n1_k1.GetLength(I0)),
// int(b_grid_desc_k0_n0_n1_k1.GetLength(I1)), int(b_grid_desc_k0_n0_n1_k1.GetLength(I2)));
// printf("c_grid_desc_m0_m10_m11_n0_n10_n11{%i, %i}\n",
// int(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0)),
// int(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I1)));
// }
// }
const
auto
K
=
const
auto
K
=
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
);
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
);
auto
a_element_op
=
tp
.
a_element_op
;
auto
a_element_op
=
tp
.
a_element_op
;
auto
b_element_op
=
tp
.
b_element_op
;
auto
b_element_op
=
tp
.
b_element_op
;
auto
c_element_op
=
tp
.
c_element_op
;
auto
c_element_op
=
tp
.
c_element_op
;
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_includes.hpp
View file @
1dd11890
...
@@ -126,7 +126,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt
...
@@ -126,7 +126,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt
{
{
}
}
__host__
__device__
constexpr
ck
::
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
const
__host__
__device__
constexpr
ck
::
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
const
{
{
const
auto
M0
=
ck
::
math
::
integer_divide_ceil
(
c_grid_desc_m_n
.
GetLength
(
I0
),
MPerBlock
);
const
auto
M0
=
ck
::
math
::
integer_divide_ceil
(
c_grid_desc_m_n
.
GetLength
(
I0
),
MPerBlock
);
const
auto
N0
=
ck
::
math
::
integer_divide_ceil
(
c_grid_desc_m_n
.
GetLength
(
I1
),
NPerBlock
);
const
auto
N0
=
ck
::
math
::
integer_divide_ceil
(
c_grid_desc_m_n
.
GetLength
(
I1
),
NPerBlock
);
...
@@ -166,7 +167,10 @@ struct BlockToCTileMap_M00_N0_M01Adapt
...
@@ -166,7 +167,10 @@ struct BlockToCTileMap_M00_N0_M01Adapt
return
true
;
// always valid provided that user gets grid size from CalculateGridSize()
return
true
;
// always valid provided that user gets grid size from CalculateGridSize()
}
}
__host__
__device__
bool
CheckValidity
(
const
CGridDesc_M_N
&
/* c_grid_desc_m_n */
)
const
{
return
true
;
}
__host__
__device__
bool
CheckValidity
(
const
CGridDesc_M_N
&
/* c_grid_desc_m_n */
)
const
{
return
true
;
}
private:
private:
ck
::
index_t
M01_
;
ck
::
index_t
M01_
;
...
@@ -217,7 +221,8 @@ template <typename ALayout,
...
@@ -217,7 +221,8 @@ template <typename ALayout,
ck
::
LoopScheduler
LoopSched
=
ck
::
make_default_loop_scheduler
()>
ck
::
LoopScheduler
LoopSched
=
ck
::
make_default_loop_scheduler
()>
struct
TuningParams
struct
TuningParams
{
{
static
constexpr
auto
MakeAGridDescriptor_AK0_M_AK1
(
ck
::
index_t
MRaw
,
ck
::
index_t
KRaw
,
ck
::
index_t
StrideA
)
static
constexpr
auto
MakeAGridDescriptor_AK0_M_AK1
(
ck
::
index_t
MRaw
,
ck
::
index_t
KRaw
,
ck
::
index_t
StrideA
)
{
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
if
constexpr
(
ck
::
is_same_v
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
if
constexpr
(
ck
::
is_same_v
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
...
@@ -242,19 +247,19 @@ struct TuningParams
...
@@ -242,19 +247,19 @@ struct TuningParams
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
{
{
// pad both M and K
// pad both M and K
//assert(K % AK1 == 0);
//
assert(K % AK1 == 0);
const
auto
AK0
=
K
/
AK1
;
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
a_grid_desc_mraw_kraw
,
ck
::
make_tuple
(
ck
::
make_right_pad_transform
(
MRaw
,
MPad
),
ck
::
make_tuple
(
ck
::
make_right_pad_transform
(
MRaw
,
MPad
),
ck
::
make_right_pad_transform
(
KRaw
,
KPad
)),
ck
::
make_right_pad_transform
(
KRaw
,
KPad
)),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
a_grid_desc_m_k
,
a_grid_desc_m_k
,
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
AK0
,
AK1
)),
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
AK0
,
AK1
)),
ck
::
make_pass_through_transform
(
M
)),
ck
::
make_pass_through_transform
(
M
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
...
@@ -266,12 +271,12 @@ struct TuningParams
...
@@ -266,12 +271,12 @@ struct TuningParams
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
{
{
// pad M, but not K
// pad M, but not K
//assert(KRaw % AK1 == 0);
//
assert(KRaw % AK1 == 0);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
a_grid_desc_mraw_kraw
,
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
AK0
,
AK1
)),
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
AK0
,
AK1
)),
ck
::
make_right_pad_transform
(
MRaw
,
MPad
)),
ck
::
make_right_pad_transform
(
MRaw
,
MPad
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
...
@@ -283,18 +288,19 @@ struct TuningParams
...
@@ -283,18 +288,19 @@ struct TuningParams
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
)
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
)
{
{
// pad K, but not M
// pad K, but not M
//assert(K % AK1 == 0);
//
assert(K % AK1 == 0);
const
auto
AK0
=
K
/
AK1
;
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
a_grid_desc_mraw_kraw
,
ck
::
make_tuple
(
ck
::
make_pass_through_transform
(
MRaw
),
ck
::
make_right_pad_transform
(
KRaw
,
KPad
)),
ck
::
make_tuple
(
ck
::
make_pass_through_transform
(
MRaw
),
ck
::
make_right_pad_transform
(
KRaw
,
KPad
)),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
a_grid_desc_m_k
,
a_grid_desc_m_k
,
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
AK0
,
AK1
)),
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
AK0
,
AK1
)),
ck
::
make_pass_through_transform
(
MRaw
)),
ck
::
make_pass_through_transform
(
MRaw
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
...
@@ -305,12 +311,12 @@ struct TuningParams
...
@@ -305,12 +311,12 @@ struct TuningParams
else
else
{
{
// not pad M or K
// not pad M or K
//assert(KRaw % AK1 == 0);
//
assert(KRaw % AK1 == 0);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
a_grid_desc_mraw_kraw
,
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
AK0
,
AK1
)),
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
AK0
,
AK1
)),
ck
::
make_pass_through_transform
(
MRaw
)),
ck
::
make_pass_through_transform
(
MRaw
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
...
@@ -320,7 +326,8 @@ struct TuningParams
...
@@ -320,7 +326,8 @@ struct TuningParams
}
}
}
}
static
constexpr
auto
MakeBGridDescriptor_BK0_N_BK1
(
ck
::
index_t
KRaw
,
ck
::
index_t
NRaw
,
ck
::
index_t
StrideB
)
static
constexpr
auto
MakeBGridDescriptor_BK0_N_BK1
(
ck
::
index_t
KRaw
,
ck
::
index_t
NRaw
,
ck
::
index_t
StrideB
)
{
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
if
constexpr
(
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
...
@@ -345,19 +352,19 @@ struct TuningParams
...
@@ -345,19 +352,19 @@ struct TuningParams
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
{
{
// pad both N and K
// pad both N and K
//assert(K % BK1 == 0);
//
assert(K % BK1 == 0);
const
auto
BK0
=
K
/
BK1
;
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
b_grid_desc_nraw_kraw
,
ck
::
make_tuple
(
ck
::
make_right_pad_transform
(
NRaw
,
NPad
),
ck
::
make_tuple
(
ck
::
make_right_pad_transform
(
NRaw
,
NPad
),
ck
::
make_right_pad_transform
(
KRaw
,
KPad
)),
ck
::
make_right_pad_transform
(
KRaw
,
KPad
)),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
b_grid_desc_n_k
,
b_grid_desc_n_k
,
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
BK0
,
BK1
)),
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
BK0
,
BK1
)),
ck
::
make_pass_through_transform
(
N
)),
ck
::
make_pass_through_transform
(
N
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
...
@@ -369,12 +376,12 @@ struct TuningParams
...
@@ -369,12 +376,12 @@ struct TuningParams
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
{
{
// pad N, but not K
// pad N, but not K
//assert(KRaw % BK1 == 0);
//
assert(KRaw % BK1 == 0);
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
b_grid_desc_nraw_kraw
,
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
BK0
,
BK1
)),
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
BK0
,
BK1
)),
ck
::
make_right_pad_transform
(
NRaw
,
NPad
)),
ck
::
make_right_pad_transform
(
NRaw
,
NPad
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
...
@@ -386,18 +393,19 @@ struct TuningParams
...
@@ -386,18 +393,19 @@ struct TuningParams
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
)
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
)
{
{
// pad K, but not N
// pad K, but not N
//assert(K % BK1 == 0);
//
assert(K % BK1 == 0);
const
auto
BK0
=
K
/
BK1
;
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
b_grid_desc_nraw_kraw
,
ck
::
make_tuple
(
ck
::
make_pass_through_transform
(
NRaw
),
ck
::
make_right_pad_transform
(
KRaw
,
KPad
)),
ck
::
make_tuple
(
ck
::
make_pass_through_transform
(
NRaw
),
ck
::
make_right_pad_transform
(
KRaw
,
KPad
)),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
b_grid_desc_n_k
,
b_grid_desc_n_k
,
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
BK0
,
BK1
)),
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
BK0
,
BK1
)),
ck
::
make_pass_through_transform
(
NRaw
)),
ck
::
make_pass_through_transform
(
NRaw
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
...
@@ -408,12 +416,12 @@ struct TuningParams
...
@@ -408,12 +416,12 @@ struct TuningParams
else
else
{
{
// not pad N or K
// not pad N or K
//assert(KRaw % BK1 == 0);
//
assert(KRaw % BK1 == 0);
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
b_grid_desc_nraw_kraw
,
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
BK0
,
BK1
)),
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
BK0
,
BK1
)),
ck
::
make_pass_through_transform
(
NRaw
)),
ck
::
make_pass_through_transform
(
NRaw
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
...
@@ -423,7 +431,8 @@ struct TuningParams
...
@@ -423,7 +431,8 @@ struct TuningParams
}
}
}
}
static
constexpr
auto
MakeCGridDescriptor_M_N
(
ck
::
index_t
MRaw
,
ck
::
index_t
NRaw
,
ck
::
index_t
StrideC
)
static
constexpr
auto
MakeCGridDescriptor_M_N
(
ck
::
index_t
MRaw
,
ck
::
index_t
NRaw
,
ck
::
index_t
StrideC
)
{
{
const
auto
c_grid_desc_mraw_nraw
=
[
&
]()
{
const
auto
c_grid_desc_mraw_nraw
=
[
&
]()
{
if
constexpr
(
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
if
constexpr
(
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
...
@@ -448,7 +457,8 @@ struct TuningParams
...
@@ -448,7 +457,8 @@ struct TuningParams
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
{
{
// pad M and N
// pad M and N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
ck
::
make_tuple
(
ck
::
make_right_pad_transform
(
MRaw
,
MPad
),
ck
::
make_tuple
(
ck
::
make_right_pad_transform
(
MRaw
,
MPad
),
ck
::
make_right_pad_transform
(
NRaw
,
NPad
)),
ck
::
make_right_pad_transform
(
NRaw
,
NPad
)),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
...
@@ -460,7 +470,8 @@ struct TuningParams
...
@@ -460,7 +470,8 @@ struct TuningParams
// pad M, but not N
// pad M, but not N
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
c_grid_desc_mraw_nraw
,
ck
::
make_tuple
(
ck
::
make_right_pad_transform
(
MRaw
,
MPad
),
ck
::
make_pass_through_transform
(
NRaw
)),
ck
::
make_tuple
(
ck
::
make_right_pad_transform
(
MRaw
,
MPad
),
ck
::
make_pass_through_transform
(
NRaw
)),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
}
}
...
@@ -470,7 +481,8 @@ struct TuningParams
...
@@ -470,7 +481,8 @@ struct TuningParams
// pad N, but not M
// pad N, but not M
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
c_grid_desc_mraw_nraw
,
ck
::
make_tuple
(
ck
::
make_pass_through_transform
(
MRaw
),
ck
::
make_right_pad_transform
(
NRaw
,
NPad
)),
ck
::
make_tuple
(
ck
::
make_pass_through_transform
(
MRaw
),
ck
::
make_right_pad_transform
(
NRaw
,
NPad
)),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
}
}
...
@@ -544,7 +556,7 @@ struct TuningParams
...
@@ -544,7 +556,7 @@ struct TuningParams
};
};
using
gemm
=
TuningParams
using
gemm
=
TuningParams
// clang-format off
// clang-format off
//| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_includes.hpp
View file @
1dd11890
...
@@ -52,7 +52,7 @@ static constexpr auto K1Number = ck::Number<K1>{};
...
@@ -52,7 +52,7 @@ static constexpr auto K1Number = ck::Number<K1>{};
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
ALayout
=
Row
;
//Col;
using
ALayout
=
Row
;
//
Col;
using
BLayout
=
Row
;
using
BLayout
=
Row
;
using
CLayout
=
Row
;
using
CLayout
=
Row
;
...
@@ -267,8 +267,7 @@ using BGridDesc_K0_N0_N1_K1 =
...
@@ -267,8 +267,7 @@ using BGridDesc_K0_N0_N1_K1 =
decltype
(
GridwiseGemm
::
MakeBGridDescriptor_K0_N0_N1_K1
(
BGridDesc_K0_N_K1
{}));
decltype
(
GridwiseGemm
::
MakeBGridDescriptor_K0_N0_N1_K1
(
BGridDesc_K0_N_K1
{}));
using
CGridDesc_M0_M10_M11_N0_N10_N11
=
using
CGridDesc_M0_M10_M11_N0_N10_N11
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11
(
CGridDesc_M_N
{}));
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11
(
CGridDesc_M_N
{}));
using
DefaultBlock2CTileMap
=
using
DefaultBlock2CTileMap
=
decltype
(
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{}));
decltype
(
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{}));
}
// namespace migraphx
}
// namespace migraphx
#endif
#endif
test/verify/0ack_test_ck_gemm.cpp
View file @
1dd11890
...
@@ -37,8 +37,10 @@
...
@@ -37,8 +37,10 @@
// migraphx::shape m2_shape{migraphx::shape::float_type, {4096, 4096}};
// migraphx::shape m2_shape{migraphx::shape::float_type, {4096, 4096}};
// auto l1 = mm->add_parameter("1", m1_shape);
// auto l1 = mm->add_parameter("1", m1_shape);
// auto l2 = mm->add_parameter("2", m2_shape);
// auto l2 = mm->add_parameter("2", m2_shape);
// // l1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
// // l1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}),
// // l2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
// l1);
// // l2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}),
// l2);
// mm->add_instruction(migraphx::make_op("ck_gemm"), l1, l2);
// mm->add_instruction(migraphx::make_op("ck_gemm"), l1, l2);
...
@@ -54,15 +56,15 @@ struct test_ck_gemm : verify_program<test_ck_gemm>
...
@@ -54,15 +56,15 @@ struct test_ck_gemm : verify_program<test_ck_gemm>
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
half_type
,
{
2
,
3
}};
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
half_type
,
{
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
half_type
,
{
3
,
4
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
half_type
,
{
3
,
4
}};
std
::
vector
<
float
>
v1
(
2
*
3
,
1
);
std
::
vector
<
float
>
v1
(
2
*
3
,
1
);
std
::
iota
(
v1
.
begin
(),
v1
.
end
(),
1
);
std
::
iota
(
v1
.
begin
(),
v1
.
end
(),
1
);
std
::
vector
<
float
>
v2
(
3
*
4
,
1
);
std
::
vector
<
float
>
v2
(
3
*
4
,
1
);
//std::iota(v2.begin(), v2.end(), 1);
//
std::iota(v2.begin(), v2.end(), 1);
auto
l1
=
mm
->
add_literal
(
migraphx
::
literal
{
m1_shape
,
v1
});
auto
l1
=
mm
->
add_literal
(
migraphx
::
literal
{
m1_shape
,
v1
});
auto
l2
=
mm
->
add_literal
(
migraphx
::
literal
{
m2_shape
,
v2
});
auto
l2
=
mm
->
add_literal
(
migraphx
::
literal
{
m2_shape
,
v2
});
// auto l1 = mm->add_parameter("1", m1_shape);
// auto l1 = mm->add_parameter("1", m1_shape);
// auto l2 = mm->add_parameter("2", m2_shape);
// auto l2 = mm->add_parameter("2", m2_shape);
//l1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
//
l1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
// l2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
// l2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm
->
add_instruction
(
migraphx
::
make_op
(
"ck_gemm"
),
l1
,
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"ck_gemm"
),
l1
,
l2
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment