Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1dd11890
Commit
1dd11890
authored
Sep 22, 2022
by
turneram
Browse files
Formatting
parent
07167910
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
308 additions
and
269 deletions
+308
-269
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+0
-1
src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
...s/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
+2
-2
src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise2.hpp
.../gpu/kernels/include/migraphx/kernels/ck_elementwise2.hpp
+31
-32
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
+70
-55
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm2.hpp
...targets/gpu/kernels/include/migraphx/kernels/ck_gemm2.hpp
+54
-41
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_includes.hpp
...gpu/kernels/include/migraphx/kernels/ck_gemm_includes.hpp
+108
-96
src/targets/gpu/kernels/include/migraphx/kernels/ck_includes.hpp
...gets/gpu/kernels/include/migraphx/kernels/ck_includes.hpp
+35
-36
test/verify/0ack_test_ck_gemm.cpp
test/verify/0ack_test_ck_gemm.cpp
+8
-6
No files found.
src/targets/gpu/jit/ck_gemm.cpp
View file @
1dd11890
...
...
@@ -72,7 +72,6 @@ __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
)__migraphx__"
;
struct
ck_gemm_compiler
:
compiler
<
ck_gemm_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"ck_gemm"
};
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
View file @
1dd11890
...
...
@@ -213,8 +213,8 @@ __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
AScalarPerVector
,
BScalarPerVector
,
CScalarPerVector
>
;
auto
op
=
Add
{};
auto
op
=
Add
{};
GridwiseBinEltwise
::
Run
(
a_t
.
data
(),
b_t
.
data
(),
c_t
.
data
(),
a_desc
,
b_desc
,
c_desc
,
op
);
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise2.hpp
View file @
1dd11890
...
...
@@ -164,12 +164,12 @@ struct Div
};
};
using
InDataTypeTuple
=
ck
::
Tuple
<
ABDataType
,
ABDataType
>
;
using
OutDataTypeTuple
=
ck
::
Tuple
<
CDataType
>
;
using
ElementwiseOperation
=
Add
;
using
InDataTypeTuple
=
ck
::
Tuple
<
ABDataType
,
ABDataType
>
;
using
OutDataTypeTuple
=
ck
::
Tuple
<
CDataType
>
;
using
ElementwiseOperation
=
Add
;
static
constexpr
auto
MPerThread
=
8
;
using
InScalarPerVectorSeq
=
ck
::
Sequence
<
1
,
8
>
;
using
OutScalarPerVectorSeq
=
ck
::
Sequence
<
8
>
;
using
InScalarPerVectorSeq
=
ck
::
Sequence
<
1
,
8
>
;
using
OutScalarPerVectorSeq
=
ck
::
Sequence
<
8
>
;
// using DeviceElementwiseAddInstance =
// ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ABDataType, ABDataType>,
...
...
@@ -186,7 +186,7 @@ __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
// auto idx = make_index();
constexpr
auto
a_lens
=
get_shape_c
<
T
>
{}.
lens
;
constexpr
auto
a_strides
=
get_shape_c
<
T
>
{}.
strides
;
constexpr
ck
::
index_t
ndim
=
a_lens
.
size
();
constexpr
ck
::
index_t
ndim
=
a_lens
.
size
();
constexpr
auto
b_lens
=
get_shape_c
<
U
>
{}.
lens
;
constexpr
auto
b_strides
=
get_shape_c
<
U
>
{}.
strides
;
constexpr
ck
::
index_t
b_ndim
=
b_lens
.
size
();
...
...
@@ -197,47 +197,46 @@ __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
DeviceElementwise
<
ck
::
Tuple
<
ABDataType
,
ABDataType
>
,
ck
::
Tuple
<
CDataType
>
,
Add
,
ndim
,
8
,
ck
::
Sequence
<
1
,
8
>
,
ck
::
Sequence
<
8
>>
;
ck
::
Tuple
<
CDataType
>
,
Add
,
ndim
,
8
,
ck
::
Sequence
<
1
,
8
>
,
ck
::
Sequence
<
8
>>
;
using
shapes_t
=
std
::
array
<
ck
::
index_t
,
3
>
;
//shapes_t lengths_abc;
//copy(c_lens.begin(), c_lens.end(), lengths_abc);
//
shapes_t lengths_abc;
//
copy(c_lens.begin(), c_lens.end(), lengths_abc);
shapes_t
lengths_abc
=
{
c_lens
[
0
],
c_lens
[
1
],
c_lens
[
2
]};
//constexpr auto lengths_abc = static_cast<shapes_t>(c_lens[0], c_lens[1], c_lens[2]);
//
constexpr auto lengths_abc = static_cast<shapes_t>(c_lens[0], c_lens[1], c_lens[2]);
constexpr
auto
strides_a
=
static_cast
<
shapes_t
>
(
a_strides
);
constexpr
auto
strides_b
=
static_cast
<
shapes_t
>
(
b_strides
);
constexpr
auto
strides_c
=
static_cast
<
shapes_t
>
(
c_strides
);
std
::
array
<
const
void
*
,
2
>
input
=
{
a_t
.
data
(),
b_t
.
data
()};
std
::
array
<
const
void
*
,
2
>
input
=
{
a_t
.
data
(),
b_t
.
data
()};
std
::
array
<
void
*
,
1
>
output
=
{
c_t
.
data
()};
auto
ck_add
=
DeviceElementwiseAddInstance
{};
auto
argument
=
ck_add
.
MakeArgumentPointer
(
auto
ck_add
=
DeviceElementwiseAddInstance
{};
auto
argument
=
ck_add
.
MakeArgumentPointer
(
lengths_abc
,
{
strides_a
,
strides_b
},
{
strides_c
},
input
,
output
,
Add
{});
using
InGrid1dDescTuple
=
decltype
(
ck_add
.
GenerateInOutGrid1dDescTuple
(
ck
::
Number
<
ndim
>
{}));
using
OutGrid1dDescTuple
=
decltype
(
ck_add
.
GenerateInOutGrid1dDescTuple
(
ck
::
Number
<
ndim
>
{}));
using
InDataTypePointerTuple
=
decltype
(
ck_add
.
GenerateInDataTypePointerTuple
());
using
OutDataTypePointerTuple
=
decltype
(
ck_add
.
GenerateOutDataTypePointerTuple
());
using
GridwiseElementwise
=
ck
::
GridwiseElementwise_1D
<
InGrid1dDescTuple
,
OutGrid1dDescTuple
,
InDataTypePointerTuple
,
OutDataTypePointerTuple
,
ElementwiseOperation
,
MPerThread
,
InScalarPerVectorSeq
,
OutScalarPerVectorSeq
>
;
using
GridwiseElementwise
=
ck
::
GridwiseElementwise_1D
<
InGrid1dDescTuple
,
OutGrid1dDescTuple
,
InDataTypePointerTuple
,
OutDataTypePointerTuple
,
ElementwiseOperation
,
MPerThread
,
InScalarPerVectorSeq
,
OutScalarPerVectorSeq
>
;
GridwiseElementwise
::
Run
(
argument
.
in_grid_1d_desc_tuple_
,
argument
.
out_grid_1d_desc_tuple_
,
argument
.
in_dev_buffers_
,
argument
.
out_dev_buffers_
,
argument
.
elementwise_op_
);
argument
.
out_grid_1d_desc_tuple_
,
argument
.
in_dev_buffers_
,
argument
.
out_dev_buffers_
,
argument
.
elementwise_op_
);
}
}
// namespace migraphx
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
View file @
1dd11890
...
...
@@ -60,19 +60,27 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
if
(
idx
.
global
==
0
)
{
printf
(
"a_grid_desc_k0_m0_m1_k1{%i, %i, %i}
\n
"
,
int
(
a_grid_desc_k0_m_k1
.
GetLength
(
I0
)),
int
(
a_grid_desc_k0_m_k1
.
GetLength
(
I1
)),
int
(
a_grid_desc_k0_m_k1
.
GetLength
(
I2
)));
printf
(
"b_grid_desc_k0_n0_n1_k1{%i, %i, %i}
\n
"
,
int
(
b_grid_desc_k0_n_k1
.
GetLength
(
I0
)),
int
(
b_grid_desc_k0_n_k1
.
GetLength
(
I1
)),
int
(
b_grid_desc_k0_n_k1
.
GetLength
(
I2
)));
printf
(
"c_grid_desc_m_n{%i, %i}
\n
"
,
int
(
c_grid_desc_m_n
.
GetLength
(
I0
)),
int
(
c_grid_desc_m_n
.
GetLength
(
I1
)));
printf
(
"a_grid_desc_k0_m0_m1_k1{%i, %i, %i}
\n
"
,
int
(
a_grid_desc_k0_m_k1
.
GetLength
(
I0
)),
int
(
a_grid_desc_k0_m_k1
.
GetLength
(
I1
)),
int
(
a_grid_desc_k0_m_k1
.
GetLength
(
I2
)));
printf
(
"b_grid_desc_k0_n0_n1_k1{%i, %i, %i}
\n
"
,
int
(
b_grid_desc_k0_n_k1
.
GetLength
(
I0
)),
int
(
b_grid_desc_k0_n_k1
.
GetLength
(
I1
)),
int
(
b_grid_desc_k0_n_k1
.
GetLength
(
I2
)));
printf
(
"c_grid_desc_m_n{%i, %i}
\n
"
,
int
(
c_grid_desc_m_n
.
GetLength
(
I0
)),
int
(
c_grid_desc_m_n
.
GetLength
(
I1
)));
}
AGridDesc_K0_M0_M1_K1
a_grid_desc_k0_m0_m1_k1
;
BGridDesc_K0_N0_N1_K1
b_grid_desc_k0_n0_n1_k1
;
CGridDesc_M0_M10_M11_N0_N10_N11
c_grid_desc_m0_m10_m11_n0_n10_n11
;
DefaultBlock2CTileMap
block_2_ctile_map
;
if
(
true
or
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
c_grid_desc_m_n
))
if
(
true
or
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
c_grid_desc_m_n
))
{
//printf("Is valid\n");
//
printf("Is valid\n");
a_grid_desc_k0_m0_m1_k1
=
GridwiseGemm
::
MakeAGridDescriptor_K0_M0_M1_K1
(
a_grid_desc_k0_m_k1
);
b_grid_desc_k0_n0_n1_k1
=
...
...
@@ -83,79 +91,86 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
}
else
{
//printf("Not valid\n");
//
printf("Not valid\n");
}
if
(
idx
.
global
==
0
)
{
printf
(
"a_grid_desc_k0_m0_m1_k1{%i, %i, %i}
\n
"
,
int
(
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I0
)),
int
(
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I1
)),
int
(
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I2
)));
printf
(
"b_grid_desc_k0_n0_n1_k1{%i, %i, %i}
\n
"
,
int
(
b_grid_desc_k0_n0_n1_k1
.
GetLength
(
I0
)),
int
(
b_grid_desc_k0_n0_n1_k1
.
GetLength
(
I1
)),
int
(
b_grid_desc_k0_n0_n1_k1
.
GetLength
(
I2
)));
printf
(
"c_grid_desc_m0_m10_m11_n0_n10_n11{%i, %i}
\n
"
,
int
(
c_grid_desc_m0_m10_m11_n0_n10_n11
.
GetLength
(
I0
)),
int
(
c_grid_desc_m0_m10_m11_n0_n10_n11
.
GetLength
(
I1
)));
printf
(
"a_grid_desc_k0_m0_m1_k1{%i, %i, %i}
\n
"
,
int
(
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I0
)),
int
(
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I1
)),
int
(
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I2
)));
printf
(
"b_grid_desc_k0_n0_n1_k1{%i, %i, %i}
\n
"
,
int
(
b_grid_desc_k0_n0_n1_k1
.
GetLength
(
I0
)),
int
(
b_grid_desc_k0_n0_n1_k1
.
GetLength
(
I1
)),
int
(
b_grid_desc_k0_n0_n1_k1
.
GetLength
(
I2
)));
printf
(
"c_grid_desc_m0_m10_m11_n0_n10_n11{%i, %i}
\n
"
,
int
(
c_grid_desc_m0_m10_m11_n0_n10_n11
.
GetLength
(
I0
)),
int
(
c_grid_desc_m0_m10_m11_n0_n10_n11
.
GetLength
(
I1
)));
}
const
auto
K0
=
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I0
);
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K0
);
const
bool
has_double_tail_k_block_loop
=
GridwiseGemm
::
CalculateHasDoubleTailKBlockLoop
(
K0
);
const
auto
K0
=
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I0
);
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K0
);
const
bool
has_double_tail_k_block_loop
=
GridwiseGemm
::
CalculateHasDoubleTailKBlockLoop
(
K0
);
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
constexpr
bool
HasMainKBlockLoop
=
true
;
constexpr
bool
HasMainKBlockLoop
=
true
;
constexpr
bool
HasDoubleTailKBlockLoop
=
true
;
GridwiseGemm
::
Run
(
a_t
.
data
(),
b_t
.
data
(),
c_t
.
data
(),
p_t
.
data
(),
a_grid_desc_k0_m0_m1_k1
,
b_grid_desc_k0_n0_n1_k1
,
c_grid_desc_m0_m10_m11_n0_n10_n11
,
block_2_ctile_map
,
ck
::
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
ck
::
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
b_t
.
data
(),
c_t
.
data
(),
p_t
.
data
(),
a_grid_desc_k0_m0_m1_k1
,
b_grid_desc_k0_n0_n1_k1
,
c_grid_desc_m0_m10_m11_n0_n10_n11
,
block_2_ctile_map
,
ck
::
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
ck
::
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
constexpr
bool
HasMainKBlockLoop
=
true
;
constexpr
bool
HasMainKBlockLoop
=
true
;
constexpr
bool
HasDoubleTailKBlockLoop
=
false
;
GridwiseGemm
::
Run
(
a_t
.
data
(),
b_t
.
data
(),
c_t
.
data
(),
p_t
.
data
(),
a_grid_desc_k0_m0_m1_k1
,
b_grid_desc_k0_n0_n1_k1
,
c_grid_desc_m0_m10_m11_n0_n10_n11
,
block_2_ctile_map
,
ck
::
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
ck
::
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
b_t
.
data
(),
c_t
.
data
(),
p_t
.
data
(),
a_grid_desc_k0_m0_m1_k1
,
b_grid_desc_k0_n0_n1_k1
,
c_grid_desc_m0_m10_m11_n0_n10_n11
,
block_2_ctile_map
,
ck
::
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
ck
::
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
constexpr
bool
HasMainKBlockLoop
=
false
;
constexpr
bool
HasMainKBlockLoop
=
false
;
constexpr
bool
HasDoubleTailKBlockLoop
=
true
;
GridwiseGemm
::
Run
(
a_t
.
data
(),
b_t
.
data
(),
c_t
.
data
(),
p_t
.
data
(),
a_grid_desc_k0_m0_m1_k1
,
b_grid_desc_k0_n0_n1_k1
,
c_grid_desc_m0_m10_m11_n0_n10_n11
,
block_2_ctile_map
,
ck
::
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
ck
::
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
b_t
.
data
(),
c_t
.
data
(),
p_t
.
data
(),
a_grid_desc_k0_m0_m1_k1
,
b_grid_desc_k0_n0_n1_k1
,
c_grid_desc_m0_m10_m11_n0_n10_n11
,
block_2_ctile_map
,
ck
::
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
ck
::
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
}
else
else
{
constexpr
bool
HasMainKBlockLoop
=
false
;
constexpr
bool
HasMainKBlockLoop
=
false
;
constexpr
bool
HasDoubleTailKBlockLoop
=
false
;
GridwiseGemm
::
Run
(
a_t
.
data
(),
b_t
.
data
(),
c_t
.
data
(),
p_t
.
data
(),
a_grid_desc_k0_m0_m1_k1
,
b_grid_desc_k0_n0_n1_k1
,
c_grid_desc_m0_m10_m11_n0_n10_n11
,
block_2_ctile_map
,
ck
::
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
ck
::
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
b_t
.
data
(),
c_t
.
data
(),
p_t
.
data
(),
a_grid_desc_k0_m0_m1_k1
,
b_grid_desc_k0_n0_n1_k1
,
c_grid_desc_m0_m10_m11_n0_n10_n11
,
block_2_ctile_map
,
ck
::
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
ck
::
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
}
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm2.hpp
View file @
1dd11890
...
...
@@ -37,7 +37,7 @@ template <class T, class U, class V, class W>
__device__
void
ck_gemm
(
const
T
&
a_t
,
const
U
&
b_t
,
const
V
&
c_t
,
const
W
&
p_t
)
{
static
gemm
tp
{};
using
GridwiseGemm
=
decltype
(
tp
.
gg
);
using
GridwiseGemm
=
decltype
(
tp
.
gg
);
constexpr
auto
alens
=
get_shape_c
<
T
>
{}.
lens
;
constexpr
auto
m
=
alens
[
0
];
constexpr
auto
k
=
alens
[
1
];
...
...
@@ -53,38 +53,51 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
if
(
idx
.
global
==
0
)
printf
(
"%i %i %i, %i %i %i
\n
"
,
int
(
m
),
int
(
n
),
int
(
k
),
int
(
as
),
int
(
bs
),
int
(
cs
));
constexpr
auto
a_grid_desc_ak0_m_ak1
=
tp
.
MakeAGridDescriptor_AK0_M_AK1
(
static_cast
<
ck
::
index_t
>
(
m
),
static_cast
<
ck
::
index_t
>
(
k
),
static_cast
<
ck
::
index_t
>
(
as
));
constexpr
auto
b_grid_desc_bk0_n_bk1
=
tp
.
MakeBGridDescriptor_BK0_N_BK1
(
static_cast
<
ck
::
index_t
>
(
k
),
static_cast
<
ck
::
index_t
>
(
n
),
static_cast
<
ck
::
index_t
>
(
bs
));
constexpr
auto
c_grid_desc_m_n
=
tp
.
MakeCGridDescriptor_M_N
(
static_cast
<
ck
::
index_t
>
(
m
),
static_cast
<
ck
::
index_t
>
(
n
),
static_cast
<
ck
::
index_t
>
(
cs
));
constexpr
auto
a_grid_desc_ak0_m_ak1
=
tp
.
MakeAGridDescriptor_AK0_M_AK1
(
static_cast
<
ck
::
index_t
>
(
m
),
static_cast
<
ck
::
index_t
>
(
k
),
static_cast
<
ck
::
index_t
>
(
as
));
constexpr
auto
b_grid_desc_bk0_n_bk1
=
tp
.
MakeBGridDescriptor_BK0_N_BK1
(
static_cast
<
ck
::
index_t
>
(
k
),
static_cast
<
ck
::
index_t
>
(
n
),
static_cast
<
ck
::
index_t
>
(
bs
));
constexpr
auto
c_grid_desc_m_n
=
tp
.
MakeCGridDescriptor_M_N
(
static_cast
<
ck
::
index_t
>
(
m
),
static_cast
<
ck
::
index_t
>
(
n
),
static_cast
<
ck
::
index_t
>
(
cs
));
/* constexpr */
auto
block_2_ctile_map
=
tp
.
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n
);
if
(
idx
.
global
==
0
)
{
printf
(
"a_grid_desc_ak0_m_ak1{%i, %i, %i}
\n
"
,
int
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)),
int
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I1
)),
int
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
)));
printf
(
"b_grid_desc_bk0_n_bk1{%i, %i, %i}
\n
"
,
int
(
b_grid_desc_bk0_n_bk1
.
GetLength
(
I0
)),
int
(
b_grid_desc_bk0_n_bk1
.
GetLength
(
I1
)),
int
(
b_grid_desc_bk0_n_bk1
.
GetLength
(
I2
)));
printf
(
"c_grid_desc_m_n{%i, %i}
\n
"
,
int
(
c_grid_desc_m_n
.
GetLength
(
I0
)),
int
(
c_grid_desc_m_n
.
GetLength
(
I1
)));
printf
(
"a_grid_desc_ak0_m_ak1{%i, %i, %i}
\n
"
,
int
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)),
int
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I1
)),
int
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
)));
printf
(
"b_grid_desc_bk0_n_bk1{%i, %i, %i}
\n
"
,
int
(
b_grid_desc_bk0_n_bk1
.
GetLength
(
I0
)),
int
(
b_grid_desc_bk0_n_bk1
.
GetLength
(
I1
)),
int
(
b_grid_desc_bk0_n_bk1
.
GetLength
(
I2
)));
printf
(
"c_grid_desc_m_n{%i, %i}
\n
"
,
int
(
c_grid_desc_m_n
.
GetLength
(
I0
)),
int
(
c_grid_desc_m_n
.
GetLength
(
I1
)));
}
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
{};
if
(
true
or
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m_n
,
block_2_ctile_map
))
c_grid_desc_mblock_mperblock_nblock_nperblock
{};
if
(
true
or
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m_n
,
block_2_ctile_map
))
{
c_grid_desc_mblock_mperblock_nblock_nperblock
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
);
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
);
}
// if(idx.global == 0)
// {
// printf("a_grid_desc_k0_m0_m1_k1{%i, %i, %i}\n", int(a_grid_desc_k0_m0_m1_k1.GetLength(I0)), int(a_grid_desc_k0_m0_m1_k1.GetLength(I1)), int(a_grid_desc_k0_m0_m1_k1.GetLength(I2)));
// printf("b_grid_desc_k0_n0_n1_k1{%i, %i, %i}\n", int(b_grid_desc_k0_n0_n1_k1.GetLength(I0)), int(b_grid_desc_k0_n0_n1_k1.GetLength(I1)), int(b_grid_desc_k0_n0_n1_k1.GetLength(I2)));
// printf("c_grid_desc_m0_m10_m11_n0_n10_n11{%i, %i}\n", int(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0)), int(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I1)));
// printf("a_grid_desc_k0_m0_m1_k1{%i, %i, %i}\n",
// int(a_grid_desc_k0_m0_m1_k1.GetLength(I0)), int(a_grid_desc_k0_m0_m1_k1.GetLength(I1)),
// int(a_grid_desc_k0_m0_m1_k1.GetLength(I2))); printf("b_grid_desc_k0_n0_n1_k1{%i, %i,
// %i}\n", int(b_grid_desc_k0_n0_n1_k1.GetLength(I0)),
// int(b_grid_desc_k0_n0_n1_k1.GetLength(I1)), int(b_grid_desc_k0_n0_n1_k1.GetLength(I2)));
// printf("c_grid_desc_m0_m10_m11_n0_n10_n11{%i, %i}\n",
// int(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0)),
// int(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I1)));
// }
const
auto
K
=
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
);
const
auto
K
=
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
);
auto
a_element_op
=
tp
.
a_element_op
;
auto
b_element_op
=
tp
.
b_element_op
;
auto
c_element_op
=
tp
.
c_element_op
;
...
...
@@ -93,31 +106,31 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
{
constexpr
bool
HasMainKBlockLoop
=
true
;
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
a_t
.
data
(),
b_t
.
data
(),
c_t
.
data
(),
p_t
.
data
(),
a_element_op
,
b_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
block_2_ctile_map
);
b_t
.
data
(),
c_t
.
data
(),
p_t
.
data
(),
a_element_op
,
b_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
block_2_ctile_map
);
}
else
else
{
constexpr
bool
HasMainKBlockLoop
=
false
;
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
a_t
.
data
(),
b_t
.
data
(),
c_t
.
data
(),
p_t
.
data
(),
a_element_op
,
b_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
block_2_ctile_map
);
b_t
.
data
(),
c_t
.
data
(),
p_t
.
data
(),
a_element_op
,
b_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
block_2_ctile_map
);
}
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_includes.hpp
View file @
1dd11890
...
...
@@ -51,8 +51,8 @@ static constexpr auto I5 = ck::Number<5>{};
static
constexpr
ck
::
index_t
K1
=
1
;
static
constexpr
auto
K1Number
=
ck
::
Number
<
K1
>
{};
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
// using ALayout = Row;
// using BLayout = Row;
// using CLayout = Row;
...
...
@@ -86,7 +86,7 @@ using S = ck::Sequence<Is...>;
// static constexpr ck::index_t NPerBlock = 128;
// static constexpr ck::index_t KPerBlock = 32;
// static constexpr ck::index_t AK1 = 8;
// static constexpr ck::index_t BK1 = 2;
// static constexpr ck::index_t BK1 = 2;
// static constexpr ck::index_t MPerXDL = 32;
// static constexpr ck::index_t NPerXDL = 32;
// static constexpr ck::index_t MXdlPerWave = 4;
...
...
@@ -126,7 +126,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt
{
}
__host__
__device__
constexpr
ck
::
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
const
__host__
__device__
constexpr
ck
::
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
const
{
const
auto
M0
=
ck
::
math
::
integer_divide_ceil
(
c_grid_desc_m_n
.
GetLength
(
I0
),
MPerBlock
);
const
auto
N0
=
ck
::
math
::
integer_divide_ceil
(
c_grid_desc_m_n
.
GetLength
(
I1
),
NPerBlock
);
...
...
@@ -156,7 +157,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt
ck
::
index_t
idx_N0_M01_local
=
idx_N0
+
idx_M01
*
N0
;
return
ck
::
make_tuple
(
idx_N0_M01_local
%
M01_adapt
+
idx_M00
*
M01_
,
idx_N0_M01_local
/
M01_adapt
);
idx_N0_M01_local
/
M01_adapt
);
}
template
<
typename
CTileIdx
,
typename
CTileDim
>
...
...
@@ -166,7 +167,10 @@ struct BlockToCTileMap_M00_N0_M01Adapt
return
true
;
// always valid provided that user gets grid size from CalculateGridSize()
}
__host__
__device__
bool
CheckValidity
(
const
CGridDesc_M_N
&
/* c_grid_desc_m_n */
)
const
{
return
true
;
}
__host__
__device__
bool
CheckValidity
(
const
CGridDesc_M_N
&
/* c_grid_desc_m_n */
)
const
{
return
true
;
}
private:
ck
::
index_t
M01_
;
...
...
@@ -217,7 +221,8 @@ template <typename ALayout,
ck
::
LoopScheduler
LoopSched
=
ck
::
make_default_loop_scheduler
()>
struct
TuningParams
{
static
constexpr
auto
MakeAGridDescriptor_AK0_M_AK1
(
ck
::
index_t
MRaw
,
ck
::
index_t
KRaw
,
ck
::
index_t
StrideA
)
static
constexpr
auto
MakeAGridDescriptor_AK0_M_AK1
(
ck
::
index_t
MRaw
,
ck
::
index_t
KRaw
,
ck
::
index_t
StrideA
)
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
if
constexpr
(
ck
::
is_same_v
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
...
...
@@ -239,88 +244,90 @@ struct TuningParams
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
||
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
{
// pad both M and K
//assert(K % AK1 == 0);
//
assert(K % AK1 == 0);
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
ck
::
make_tuple
(
ck
::
make_right_pad_transform
(
MRaw
,
MPad
),
ck
::
make_right_pad_transform
(
KRaw
,
KPad
)),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
ck
::
make_tuple
(
ck
::
make_right_pad_transform
(
MRaw
,
MPad
),
ck
::
make_right_pad_transform
(
KRaw
,
KPad
)),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
AK0
,
AK1
)),
ck
::
make_pass_through_transform
(
M
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
2
>
{},
ck
::
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
AK0
,
AK1
)),
ck
::
make_pass_through_transform
(
M
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
2
>
{},
ck
::
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MPadding
||
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
{
// pad M, but not K
//assert(KRaw % AK1 == 0);
//
assert(KRaw % AK1 == 0);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
AK0
,
AK1
)),
ck
::
make_right_pad_transform
(
MRaw
,
MPad
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
2
>
{},
ck
::
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
AK0
,
AK1
)),
ck
::
make_right_pad_transform
(
MRaw
,
MPad
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
2
>
{},
ck
::
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
KPadding
||
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
)
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
)
{
// pad K, but not M
//assert(K % AK1 == 0);
//
assert(K % AK1 == 0);
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
ck
::
make_tuple
(
ck
::
make_pass_through_transform
(
MRaw
),
ck
::
make_right_pad_transform
(
KRaw
,
KPad
)),
ck
::
make_tuple
(
ck
::
make_pass_through_transform
(
MRaw
),
ck
::
make_right_pad_transform
(
KRaw
,
KPad
)),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
AK0
,
AK1
)),
ck
::
make_pass_through_transform
(
MRaw
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
2
>
{},
ck
::
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
AK0
,
AK1
)),
ck
::
make_pass_through_transform
(
MRaw
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
2
>
{},
ck
::
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
{
// not pad M or K
//assert(KRaw % AK1 == 0);
//
assert(KRaw % AK1 == 0);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
AK0
,
AK1
)),
ck
::
make_pass_through_transform
(
MRaw
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
2
>
{},
ck
::
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
AK0
,
AK1
)),
ck
::
make_pass_through_transform
(
MRaw
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
2
>
{},
ck
::
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
}
static
constexpr
auto
MakeBGridDescriptor_BK0_N_BK1
(
ck
::
index_t
KRaw
,
ck
::
index_t
NRaw
,
ck
::
index_t
StrideB
)
static
constexpr
auto
MakeBGridDescriptor_BK0_N_BK1
(
ck
::
index_t
KRaw
,
ck
::
index_t
NRaw
,
ck
::
index_t
StrideB
)
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
...
...
@@ -342,88 +349,90 @@ struct TuningParams
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
{
// pad both N and K
//assert(K % BK1 == 0);
//
assert(K % BK1 == 0);
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
ck
::
make_tuple
(
ck
::
make_right_pad_transform
(
NRaw
,
NPad
),
ck
::
make_right_pad_transform
(
KRaw
,
KPad
)),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
ck
::
make_tuple
(
ck
::
make_right_pad_transform
(
NRaw
,
NPad
),
ck
::
make_right_pad_transform
(
KRaw
,
KPad
)),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
BK0
,
BK1
)),
ck
::
make_pass_through_transform
(
N
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
2
>
{},
ck
::
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
BK0
,
BK1
)),
ck
::
make_pass_through_transform
(
N
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
2
>
{},
ck
::
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
NPadding
||
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
{
// pad N, but not K
//assert(KRaw % BK1 == 0);
//
assert(KRaw % BK1 == 0);
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
BK0
,
BK1
)),
ck
::
make_right_pad_transform
(
NRaw
,
NPad
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
2
>
{},
ck
::
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
BK0
,
BK1
)),
ck
::
make_right_pad_transform
(
NRaw
,
NPad
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
2
>
{},
ck
::
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
KPadding
||
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
)
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
)
{
// pad K, but not N
//assert(K % BK1 == 0);
//
assert(K % BK1 == 0);
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
ck
::
make_tuple
(
ck
::
make_pass_through_transform
(
NRaw
),
ck
::
make_right_pad_transform
(
KRaw
,
KPad
)),
ck
::
make_tuple
(
ck
::
make_pass_through_transform
(
NRaw
),
ck
::
make_right_pad_transform
(
KRaw
,
KPad
)),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
BK0
,
BK1
)),
ck
::
make_pass_through_transform
(
NRaw
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
2
>
{},
ck
::
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
BK0
,
BK1
)),
ck
::
make_pass_through_transform
(
NRaw
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
2
>
{},
ck
::
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
{
// not pad N or K
//assert(KRaw % BK1 == 0);
//
assert(KRaw % BK1 == 0);
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
BK0
,
BK1
)),
ck
::
make_pass_through_transform
(
NRaw
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
2
>
{},
ck
::
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
ck
::
make_tuple
(
make_unmerge_transform
(
ck
::
make_tuple
(
BK0
,
BK1
)),
ck
::
make_pass_through_transform
(
NRaw
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
2
>
{},
ck
::
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
}
static
constexpr
auto
MakeCGridDescriptor_M_N
(
ck
::
index_t
MRaw
,
ck
::
index_t
NRaw
,
ck
::
index_t
StrideC
)
static
constexpr
auto
MakeCGridDescriptor_M_N
(
ck
::
index_t
MRaw
,
ck
::
index_t
NRaw
,
ck
::
index_t
StrideC
)
{
const
auto
c_grid_desc_mraw_nraw
=
[
&
]()
{
if
constexpr
(
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
...
...
@@ -445,32 +454,35 @@ struct TuningParams
const
auto
NPad
=
N
-
NRaw
;
if
constexpr
(
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
{
// pad M and N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
ck
::
make_tuple
(
ck
::
make_right_pad_transform
(
MRaw
,
MPad
),
ck
::
make_right_pad_transform
(
NRaw
,
NPad
)),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
ck
::
make_tuple
(
ck
::
make_right_pad_transform
(
MRaw
,
MPad
),
ck
::
make_right_pad_transform
(
NRaw
,
NPad
)),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MPadding
||
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
)
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
)
{
// pad M, but not N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
ck
::
make_tuple
(
ck
::
make_right_pad_transform
(
MRaw
,
MPad
),
ck
::
make_pass_through_transform
(
NRaw
)),
ck
::
make_tuple
(
ck
::
make_right_pad_transform
(
MRaw
,
MPad
),
ck
::
make_pass_through_transform
(
NRaw
)),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
NPadding
||
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
)
GemmSpec
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
)
{
// pad N, but not M
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
ck
::
make_tuple
(
ck
::
make_pass_through_transform
(
MRaw
),
ck
::
make_right_pad_transform
(
NRaw
,
NPad
)),
ck
::
make_tuple
(
ck
::
make_pass_through_transform
(
MRaw
),
ck
::
make_right_pad_transform
(
NRaw
,
NPad
)),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
}
...
...
@@ -544,7 +556,7 @@ struct TuningParams
};
using
gemm
=
TuningParams
// clang-format off
// clang-format off
//| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_includes.hpp
View file @
1dd11890
...
...
@@ -52,7 +52,7 @@ static constexpr auto K1Number = ck::Number<K1>{};
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
ALayout
=
Row
;
//Col;
using
ALayout
=
Row
;
//
Col;
using
BLayout
=
Row
;
using
CLayout
=
Row
;
...
...
@@ -216,39 +216,39 @@ using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
GridwiseGemm
=
ck
::
GridwiseGemmDl_km_kn_mn_v1r3
<
BlockSize
,
ADataType
,
AccDataType
,
CDataType
,
ck
::
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
MPerBlock
,
NPerBlock
,
K0PerBlock
,
M1PerThread
,
N1PerThread
,
KPerThread
,
M1N1ThreadClusterM1Xs
,
M1N1ThreadClusterN1Xs
,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1
,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
,
ABlockTransferSrcVectorTensorContiguousDimOrder
,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1
,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
,
BBlockTransferSrcVectorTensorContiguousDimOrder
,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
>
;
ck
::
GridwiseGemmDl_km_kn_mn_v1r3
<
BlockSize
,
ADataType
,
AccDataType
,
CDataType
,
ck
::
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
MPerBlock
,
NPerBlock
,
K0PerBlock
,
M1PerThread
,
N1PerThread
,
KPerThread
,
M1N1ThreadClusterM1Xs
,
M1N1ThreadClusterN1Xs
,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1
,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
,
ABlockTransferSrcVectorTensorContiguousDimOrder
,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1
,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
,
BBlockTransferSrcVectorTensorContiguousDimOrder
,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
>
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
...
...
@@ -267,8 +267,7 @@ using BGridDesc_K0_N0_N1_K1 =
decltype
(
GridwiseGemm
::
MakeBGridDescriptor_K0_N0_N1_K1
(
BGridDesc_K0_N_K1
{}));
using
CGridDesc_M0_M10_M11_N0_N10_N11
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11
(
CGridDesc_M_N
{}));
using
DefaultBlock2CTileMap
=
decltype
(
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{}));
using
DefaultBlock2CTileMap
=
decltype
(
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{}));
}
// namespace migraphx
#endif
test/verify/0ack_test_ck_gemm.cpp
View file @
1dd11890
...
...
@@ -37,8 +37,10 @@
// migraphx::shape m2_shape{migraphx::shape::float_type, {4096, 4096}};
// auto l1 = mm->add_parameter("1", m1_shape);
// auto l2 = mm->add_parameter("2", m2_shape);
// // l1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
// // l2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
// // l1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}),
// l1);
// // l2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}),
// l2);
// mm->add_instruction(migraphx::make_op("ck_gemm"), l1, l2);
...
...
@@ -54,15 +56,15 @@ struct test_ck_gemm : verify_program<test_ck_gemm>
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
half_type
,
{
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
half_type
,
{
3
,
4
}};
std
::
vector
<
float
>
v1
(
2
*
3
,
1
);
std
::
vector
<
float
>
v1
(
2
*
3
,
1
);
std
::
iota
(
v1
.
begin
(),
v1
.
end
(),
1
);
std
::
vector
<
float
>
v2
(
3
*
4
,
1
);
//std::iota(v2.begin(), v2.end(), 1);
std
::
vector
<
float
>
v2
(
3
*
4
,
1
);
//
std::iota(v2.begin(), v2.end(), 1);
auto
l1
=
mm
->
add_literal
(
migraphx
::
literal
{
m1_shape
,
v1
});
auto
l2
=
mm
->
add_literal
(
migraphx
::
literal
{
m2_shape
,
v2
});
// auto l1 = mm->add_parameter("1", m1_shape);
// auto l2 = mm->add_parameter("2", m2_shape);
//l1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
//
l1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
// l2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm
->
add_instruction
(
migraphx
::
make_op
(
"ck_gemm"
),
l1
,
l2
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment