Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1dd11890
"docs/vscode:/vscode.git/clone" did not exist on "cbf281f0c51b84f3cb4a561b68993651493034e9"
Commit
1dd11890
authored
Sep 22, 2022
by
turneram
Browse files
Formatting
parent
07167910
Changes
8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
308 additions
and
269 deletions
+308
-269
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+0
-1
src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
...s/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
+2
-2
src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise2.hpp
.../gpu/kernels/include/migraphx/kernels/ck_elementwise2.hpp
+31
-32
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
+70
-55
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm2.hpp
...targets/gpu/kernels/include/migraphx/kernels/ck_gemm2.hpp
+54
-41
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_includes.hpp
...gpu/kernels/include/migraphx/kernels/ck_gemm_includes.hpp
+108
-96
src/targets/gpu/kernels/include/migraphx/kernels/ck_includes.hpp
...gets/gpu/kernels/include/migraphx/kernels/ck_includes.hpp
+35
-36
test/verify/0ack_test_ck_gemm.cpp
test/verify/0ack_test_ck_gemm.cpp
+8
-6
No files found.
src/targets/gpu/jit/ck_gemm.cpp
View file @
1dd11890
...
@@ -72,7 +72,6 @@ __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
...
@@ -72,7 +72,6 @@ __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
)__migraphx__"
;
)__migraphx__"
;
struct
ck_gemm_compiler
:
compiler
<
ck_gemm_compiler
>
struct
ck_gemm_compiler
:
compiler
<
ck_gemm_compiler
>
{
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"ck_gemm"
};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"ck_gemm"
};
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
View file @
1dd11890
...
@@ -213,8 +213,8 @@ __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
...
@@ -213,8 +213,8 @@ __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
AScalarPerVector
,
AScalarPerVector
,
BScalarPerVector
,
BScalarPerVector
,
CScalarPerVector
>
;
CScalarPerVector
>
;
auto
op
=
Add
{};
auto
op
=
Add
{};
GridwiseBinEltwise
::
Run
(
a_t
.
data
(),
b_t
.
data
(),
c_t
.
data
(),
a_desc
,
b_desc
,
c_desc
,
op
);
GridwiseBinEltwise
::
Run
(
a_t
.
data
(),
b_t
.
data
(),
c_t
.
data
(),
a_desc
,
b_desc
,
c_desc
,
op
);
}
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise2.hpp
View file @
1dd11890
...
@@ -164,12 +164,12 @@ struct Div
...
@@ -164,12 +164,12 @@ struct Div
};
};
};
};
using
InDataTypeTuple
=
ck
::
Tuple
<
ABDataType
,
ABDataType
>
;
using
InDataTypeTuple
=
ck
::
Tuple
<
ABDataType
,
ABDataType
>
;
using
OutDataTypeTuple
=
ck
::
Tuple
<
CDataType
>
;
using
OutDataTypeTuple
=
ck
::
Tuple
<
CDataType
>
;
using
ElementwiseOperation
=
Add
;
using
ElementwiseOperation
=
Add
;
static
constexpr
auto
MPerThread
=
8
;
static
constexpr
auto
MPerThread
=
8
;
using
InScalarPerVectorSeq
=
ck
::
Sequence
<
1
,
8
>
;
using
InScalarPerVectorSeq
=
ck
::
Sequence
<
1
,
8
>
;
using
OutScalarPerVectorSeq
=
ck
::
Sequence
<
8
>
;
using
OutScalarPerVectorSeq
=
ck
::
Sequence
<
8
>
;
// using DeviceElementwiseAddInstance =
// using DeviceElementwiseAddInstance =
// ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ABDataType, ABDataType>,
// ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ABDataType, ABDataType>,
...
@@ -186,7 +186,7 @@ __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
...
@@ -186,7 +186,7 @@ __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
// auto idx = make_index();
// auto idx = make_index();
constexpr
auto
a_lens
=
get_shape_c
<
T
>
{}.
lens
;
constexpr
auto
a_lens
=
get_shape_c
<
T
>
{}.
lens
;
constexpr
auto
a_strides
=
get_shape_c
<
T
>
{}.
strides
;
constexpr
auto
a_strides
=
get_shape_c
<
T
>
{}.
strides
;
constexpr
ck
::
index_t
ndim
=
a_lens
.
size
();
constexpr
ck
::
index_t
ndim
=
a_lens
.
size
();
constexpr
auto
b_lens
=
get_shape_c
<
U
>
{}.
lens
;
constexpr
auto
b_lens
=
get_shape_c
<
U
>
{}.
lens
;
constexpr
auto
b_strides
=
get_shape_c
<
U
>
{}.
strides
;
constexpr
auto
b_strides
=
get_shape_c
<
U
>
{}.
strides
;
constexpr
ck
::
index_t
b_ndim
=
b_lens
.
size
();
constexpr
ck
::
index_t
b_ndim
=
b_lens
.
size
();
...
@@ -197,47 +197,46 @@ __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
...
@@ -197,47 +197,46 @@ __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
using
DeviceElementwiseAddInstance
=
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
DeviceElementwise
<
ck
::
Tuple
<
ABDataType
,
ABDataType
>
,
ck
::
tensor_operation
::
device
::
DeviceElementwise
<
ck
::
Tuple
<
ABDataType
,
ABDataType
>
,
ck
::
Tuple
<
CDataType
>
,
ck
::
Tuple
<
CDataType
>
,
Add
,
Add
,
ndim
,
ndim
,
8
,
8
,
ck
::
Sequence
<
1
,
8
>
,
ck
::
Sequence
<
1
,
8
>
,
ck
::
Sequence
<
8
>>
;
ck
::
Sequence
<
8
>>
;
using
shapes_t
=
std
::
array
<
ck
::
index_t
,
3
>
;
using
shapes_t
=
std
::
array
<
ck
::
index_t
,
3
>
;
//shapes_t lengths_abc;
//
shapes_t lengths_abc;
//copy(c_lens.begin(), c_lens.end(), lengths_abc);
//
copy(c_lens.begin(), c_lens.end(), lengths_abc);
shapes_t
lengths_abc
=
{
c_lens
[
0
],
c_lens
[
1
],
c_lens
[
2
]};
shapes_t
lengths_abc
=
{
c_lens
[
0
],
c_lens
[
1
],
c_lens
[
2
]};
//constexpr auto lengths_abc = static_cast<shapes_t>(c_lens[0], c_lens[1], c_lens[2]);
//
constexpr auto lengths_abc = static_cast<shapes_t>(c_lens[0], c_lens[1], c_lens[2]);
constexpr
auto
strides_a
=
static_cast
<
shapes_t
>
(
a_strides
);
constexpr
auto
strides_a
=
static_cast
<
shapes_t
>
(
a_strides
);
constexpr
auto
strides_b
=
static_cast
<
shapes_t
>
(
b_strides
);
constexpr
auto
strides_b
=
static_cast
<
shapes_t
>
(
b_strides
);
constexpr
auto
strides_c
=
static_cast
<
shapes_t
>
(
c_strides
);
constexpr
auto
strides_c
=
static_cast
<
shapes_t
>
(
c_strides
);
std
::
array
<
const
void
*
,
2
>
input
=
{
a_t
.
data
(),
std
::
array
<
const
void
*
,
2
>
input
=
{
a_t
.
data
(),
b_t
.
data
()};
b_t
.
data
()};
std
::
array
<
void
*
,
1
>
output
=
{
c_t
.
data
()};
std
::
array
<
void
*
,
1
>
output
=
{
c_t
.
data
()};
auto
ck_add
=
DeviceElementwiseAddInstance
{};
auto
ck_add
=
DeviceElementwiseAddInstance
{};
auto
argument
=
ck_add
.
MakeArgumentPointer
(
auto
argument
=
ck_add
.
MakeArgumentPointer
(
lengths_abc
,
{
strides_a
,
strides_b
},
{
strides_c
},
input
,
output
,
Add
{});
lengths_abc
,
{
strides_a
,
strides_b
},
{
strides_c
},
input
,
output
,
Add
{});
using
InGrid1dDescTuple
=
decltype
(
ck_add
.
GenerateInOutGrid1dDescTuple
(
ck
::
Number
<
ndim
>
{}));
using
InGrid1dDescTuple
=
decltype
(
ck_add
.
GenerateInOutGrid1dDescTuple
(
ck
::
Number
<
ndim
>
{}));
using
OutGrid1dDescTuple
=
decltype
(
ck_add
.
GenerateInOutGrid1dDescTuple
(
ck
::
Number
<
ndim
>
{}));
using
OutGrid1dDescTuple
=
decltype
(
ck_add
.
GenerateInOutGrid1dDescTuple
(
ck
::
Number
<
ndim
>
{}));
using
InDataTypePointerTuple
=
decltype
(
ck_add
.
GenerateInDataTypePointerTuple
());
using
InDataTypePointerTuple
=
decltype
(
ck_add
.
GenerateInDataTypePointerTuple
());
using
OutDataTypePointerTuple
=
decltype
(
ck_add
.
GenerateOutDataTypePointerTuple
());
using
OutDataTypePointerTuple
=
decltype
(
ck_add
.
GenerateOutDataTypePointerTuple
());
using
GridwiseElementwise
=
ck
::
GridwiseElementwise_1D
<
InGrid1dDescTuple
,
using
GridwiseElementwise
=
ck
::
GridwiseElementwise_1D
<
InGrid1dDescTuple
,
OutGrid1dDescTuple
,
OutGrid1dDescTuple
,
InDataTypePointerTuple
,
InDataTypePointerTuple
,
OutDataTypePointerTuple
,
OutDataTypePointerTuple
,
ElementwiseOperation
,
ElementwiseOperation
,
MPerThread
,
MPerThread
,
InScalarPerVectorSeq
,
InScalarPerVectorSeq
,
OutScalarPerVectorSeq
>
;
OutScalarPerVectorSeq
>
;
GridwiseElementwise
::
Run
(
argument
.
in_grid_1d_desc_tuple_
,
GridwiseElementwise
::
Run
(
argument
.
in_grid_1d_desc_tuple_
,
argument
.
out_grid_1d_desc_tuple_
,
argument
.
out_grid_1d_desc_tuple_
,
argument
.
in_dev_buffers_
,
argument
.
in_dev_buffers_
,
argument
.
out_dev_buffers_
,
argument
.
out_dev_buffers_
,
argument
.
elementwise_op_
);
argument
.
elementwise_op_
);
}
}
}
// namespace migraphx
}
// namespace migraphx
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
View file @
1dd11890
...
@@ -60,19 +60,27 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
...
@@ -60,19 +60,27 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
if
(
idx
.
global
==
0
)
if
(
idx
.
global
==
0
)
{
{
printf
(
"a_grid_desc_k0_m0_m1_k1{%i, %i, %i}
\n
"
,
int
(
a_grid_desc_k0_m_k1
.
GetLength
(
I0
)),
int
(
a_grid_desc_k0_m_k1
.
GetLength
(
I1
)),
int
(
a_grid_desc_k0_m_k1
.
GetLength
(
I2
)));
printf
(
"a_grid_desc_k0_m0_m1_k1{%i, %i, %i}
\n
"
,
printf
(
"b_grid_desc_k0_n0_n1_k1{%i, %i, %i}
\n
"
,
int
(
b_grid_desc_k0_n_k1
.
GetLength
(
I0
)),
int
(
b_grid_desc_k0_n_k1
.
GetLength
(
I1
)),
int
(
b_grid_desc_k0_n_k1
.
GetLength
(
I2
)));
int
(
a_grid_desc_k0_m_k1
.
GetLength
(
I0
)),
printf
(
"c_grid_desc_m_n{%i, %i}
\n
"
,
int
(
c_grid_desc_m_n
.
GetLength
(
I0
)),
int
(
c_grid_desc_m_n
.
GetLength
(
I1
)));
int
(
a_grid_desc_k0_m_k1
.
GetLength
(
I1
)),
int
(
a_grid_desc_k0_m_k1
.
GetLength
(
I2
)));
printf
(
"b_grid_desc_k0_n0_n1_k1{%i, %i, %i}
\n
"
,
int
(
b_grid_desc_k0_n_k1
.
GetLength
(
I0
)),
int
(
b_grid_desc_k0_n_k1
.
GetLength
(
I1
)),
int
(
b_grid_desc_k0_n_k1
.
GetLength
(
I2
)));
printf
(
"c_grid_desc_m_n{%i, %i}
\n
"
,
int
(
c_grid_desc_m_n
.
GetLength
(
I0
)),
int
(
c_grid_desc_m_n
.
GetLength
(
I1
)));
}
}
AGridDesc_K0_M0_M1_K1
a_grid_desc_k0_m0_m1_k1
;
AGridDesc_K0_M0_M1_K1
a_grid_desc_k0_m0_m1_k1
;
BGridDesc_K0_N0_N1_K1
b_grid_desc_k0_n0_n1_k1
;
BGridDesc_K0_N0_N1_K1
b_grid_desc_k0_n0_n1_k1
;
CGridDesc_M0_M10_M11_N0_N10_N11
c_grid_desc_m0_m10_m11_n0_n10_n11
;
CGridDesc_M0_M10_M11_N0_N10_N11
c_grid_desc_m0_m10_m11_n0_n10_n11
;
DefaultBlock2CTileMap
block_2_ctile_map
;
DefaultBlock2CTileMap
block_2_ctile_map
;
if
(
true
or
GridwiseGemm
::
CheckValidity
(
if
(
true
or
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
c_grid_desc_m_n
))
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
c_grid_desc_m_n
))
{
{
//printf("Is valid\n");
//
printf("Is valid\n");
a_grid_desc_k0_m0_m1_k1
=
a_grid_desc_k0_m0_m1_k1
=
GridwiseGemm
::
MakeAGridDescriptor_K0_M0_M1_K1
(
a_grid_desc_k0_m_k1
);
GridwiseGemm
::
MakeAGridDescriptor_K0_M0_M1_K1
(
a_grid_desc_k0_m_k1
);
b_grid_desc_k0_n0_n1_k1
=
b_grid_desc_k0_n0_n1_k1
=
...
@@ -83,79 +91,86 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
...
@@ -83,79 +91,86 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
}
}
else
else
{
{
//printf("Not valid\n");
//
printf("Not valid\n");
}
}
if
(
idx
.
global
==
0
)
if
(
idx
.
global
==
0
)
{
{
printf
(
"a_grid_desc_k0_m0_m1_k1{%i, %i, %i}
\n
"
,
int
(
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I0
)),
int
(
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I1
)),
int
(
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I2
)));
printf
(
"a_grid_desc_k0_m0_m1_k1{%i, %i, %i}
\n
"
,
printf
(
"b_grid_desc_k0_n0_n1_k1{%i, %i, %i}
\n
"
,
int
(
b_grid_desc_k0_n0_n1_k1
.
GetLength
(
I0
)),
int
(
b_grid_desc_k0_n0_n1_k1
.
GetLength
(
I1
)),
int
(
b_grid_desc_k0_n0_n1_k1
.
GetLength
(
I2
)));
int
(
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I0
)),
printf
(
"c_grid_desc_m0_m10_m11_n0_n10_n11{%i, %i}
\n
"
,
int
(
c_grid_desc_m0_m10_m11_n0_n10_n11
.
GetLength
(
I0
)),
int
(
c_grid_desc_m0_m10_m11_n0_n10_n11
.
GetLength
(
I1
)));
int
(
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I1
)),
int
(
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I2
)));
printf
(
"b_grid_desc_k0_n0_n1_k1{%i, %i, %i}
\n
"
,
int
(
b_grid_desc_k0_n0_n1_k1
.
GetLength
(
I0
)),
int
(
b_grid_desc_k0_n0_n1_k1
.
GetLength
(
I1
)),
int
(
b_grid_desc_k0_n0_n1_k1
.
GetLength
(
I2
)));
printf
(
"c_grid_desc_m0_m10_m11_n0_n10_n11{%i, %i}
\n
"
,
int
(
c_grid_desc_m0_m10_m11_n0_n10_n11
.
GetLength
(
I0
)),
int
(
c_grid_desc_m0_m10_m11_n0_n10_n11
.
GetLength
(
I1
)));
}
}
const
auto
K0
=
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I0
);
const
auto
K0
=
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I0
);
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K0
);
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K0
);
const
bool
has_double_tail_k_block_loop
=
const
bool
has_double_tail_k_block_loop
=
GridwiseGemm
::
CalculateHasDoubleTailKBlockLoop
(
K0
);
GridwiseGemm
::
CalculateHasDoubleTailKBlockLoop
(
K0
);
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
{
constexpr
bool
HasMainKBlockLoop
=
true
;
constexpr
bool
HasMainKBlockLoop
=
true
;
constexpr
bool
HasDoubleTailKBlockLoop
=
true
;
constexpr
bool
HasDoubleTailKBlockLoop
=
true
;
GridwiseGemm
::
Run
(
a_t
.
data
(),
GridwiseGemm
::
Run
(
a_t
.
data
(),
b_t
.
data
(),
b_t
.
data
(),
c_t
.
data
(),
c_t
.
data
(),
p_t
.
data
(),
p_t
.
data
(),
a_grid_desc_k0_m0_m1_k1
,
a_grid_desc_k0_m0_m1_k1
,
b_grid_desc_k0_n0_n1_k1
,
b_grid_desc_k0_n0_n1_k1
,
c_grid_desc_m0_m10_m11_n0_n10_n11
,
c_grid_desc_m0_m10_m11_n0_n10_n11
,
block_2_ctile_map
,
block_2_ctile_map
,
ck
::
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
ck
::
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
ck
::
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
ck
::
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
}
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
{
constexpr
bool
HasMainKBlockLoop
=
true
;
constexpr
bool
HasMainKBlockLoop
=
true
;
constexpr
bool
HasDoubleTailKBlockLoop
=
false
;
constexpr
bool
HasDoubleTailKBlockLoop
=
false
;
GridwiseGemm
::
Run
(
a_t
.
data
(),
GridwiseGemm
::
Run
(
a_t
.
data
(),
b_t
.
data
(),
b_t
.
data
(),
c_t
.
data
(),
c_t
.
data
(),
p_t
.
data
(),
p_t
.
data
(),
a_grid_desc_k0_m0_m1_k1
,
a_grid_desc_k0_m0_m1_k1
,
b_grid_desc_k0_n0_n1_k1
,
b_grid_desc_k0_n0_n1_k1
,
c_grid_desc_m0_m10_m11_n0_n10_n11
,
c_grid_desc_m0_m10_m11_n0_n10_n11
,
block_2_ctile_map
,
block_2_ctile_map
,
ck
::
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
ck
::
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
ck
::
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
ck
::
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
}
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
{
constexpr
bool
HasMainKBlockLoop
=
false
;
constexpr
bool
HasMainKBlockLoop
=
false
;
constexpr
bool
HasDoubleTailKBlockLoop
=
true
;
constexpr
bool
HasDoubleTailKBlockLoop
=
true
;
GridwiseGemm
::
Run
(
a_t
.
data
(),
GridwiseGemm
::
Run
(
a_t
.
data
(),
b_t
.
data
(),
b_t
.
data
(),
c_t
.
data
(),
c_t
.
data
(),
p_t
.
data
(),
p_t
.
data
(),
a_grid_desc_k0_m0_m1_k1
,
a_grid_desc_k0_m0_m1_k1
,
b_grid_desc_k0_n0_n1_k1
,
b_grid_desc_k0_n0_n1_k1
,
c_grid_desc_m0_m10_m11_n0_n10_n11
,
c_grid_desc_m0_m10_m11_n0_n10_n11
,
block_2_ctile_map
,
block_2_ctile_map
,
ck
::
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
ck
::
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
ck
::
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
ck
::
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
}
}
else
else
{
{
constexpr
bool
HasMainKBlockLoop
=
false
;
constexpr
bool
HasMainKBlockLoop
=
false
;
constexpr
bool
HasDoubleTailKBlockLoop
=
false
;
constexpr
bool
HasDoubleTailKBlockLoop
=
false
;
GridwiseGemm
::
Run
(
a_t
.
data
(),
GridwiseGemm
::
Run
(
a_t
.
data
(),
b_t
.
data
(),
b_t
.
data
(),
c_t
.
data
(),
c_t
.
data
(),
p_t
.
data
(),
p_t
.
data
(),
a_grid_desc_k0_m0_m1_k1
,
a_grid_desc_k0_m0_m1_k1
,
b_grid_desc_k0_n0_n1_k1
,
b_grid_desc_k0_n0_n1_k1
,
c_grid_desc_m0_m10_m11_n0_n10_n11
,
c_grid_desc_m0_m10_m11_n0_n10_n11
,
block_2_ctile_map
,
block_2_ctile_map
,
ck
::
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
ck
::
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
ck
::
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
ck
::
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
}
}
}
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm2.hpp
View file @
1dd11890
...
@@ -37,7 +37,7 @@ template <class T, class U, class V, class W>
...
@@ -37,7 +37,7 @@ template <class T, class U, class V, class W>
__device__
void
ck_gemm
(
const
T
&
a_t
,
const
U
&
b_t
,
const
V
&
c_t
,
const
W
&
p_t
)
__device__
void
ck_gemm
(
const
T
&
a_t
,
const
U
&
b_t
,
const
V
&
c_t
,
const
W
&
p_t
)
{
{
static
gemm
tp
{};
static
gemm
tp
{};
using
GridwiseGemm
=
decltype
(
tp
.
gg
);
using
GridwiseGemm
=
decltype
(
tp
.
gg
);
constexpr
auto
alens
=
get_shape_c
<
T
>
{}.
lens
;
constexpr
auto
alens
=
get_shape_c
<
T
>
{}.
lens
;
constexpr
auto
m
=
alens
[
0
];
constexpr
auto
m
=
alens
[
0
];
constexpr
auto
k
=
alens
[
1
];
constexpr
auto
k
=
alens
[
1
];
...
@@ -53,38 +53,51 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
...
@@ -53,38 +53,51 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
if
(
idx
.
global
==
0
)
if
(
idx
.
global
==
0
)
printf
(
"%i %i %i, %i %i %i
\n
"
,
int
(
m
),
int
(
n
),
int
(
k
),
int
(
as
),
int
(
bs
),
int
(
cs
));
printf
(
"%i %i %i, %i %i %i
\n
"
,
int
(
m
),
int
(
n
),
int
(
k
),
int
(
as
),
int
(
bs
),
int
(
cs
));
constexpr
auto
a_grid_desc_ak0_m_ak1
=
tp
.
MakeAGridDescriptor_AK0_M_AK1
(
static_cast
<
ck
::
index_t
>
(
m
),
static_cast
<
ck
::
index_t
>
(
k
),
static_cast
<
ck
::
index_t
>
(
as
));
constexpr
auto
a_grid_desc_ak0_m_ak1
=
tp
.
MakeAGridDescriptor_AK0_M_AK1
(
constexpr
auto
b_grid_desc_bk0_n_bk1
=
tp
.
MakeBGridDescriptor_BK0_N_BK1
(
static_cast
<
ck
::
index_t
>
(
k
),
static_cast
<
ck
::
index_t
>
(
n
),
static_cast
<
ck
::
index_t
>
(
bs
));
static_cast
<
ck
::
index_t
>
(
m
),
static_cast
<
ck
::
index_t
>
(
k
),
static_cast
<
ck
::
index_t
>
(
as
));
constexpr
auto
c_grid_desc_m_n
=
tp
.
MakeCGridDescriptor_M_N
(
static_cast
<
ck
::
index_t
>
(
m
),
static_cast
<
ck
::
index_t
>
(
n
),
static_cast
<
ck
::
index_t
>
(
cs
));
constexpr
auto
b_grid_desc_bk0_n_bk1
=
tp
.
MakeBGridDescriptor_BK0_N_BK1
(
static_cast
<
ck
::
index_t
>
(
k
),
static_cast
<
ck
::
index_t
>
(
n
),
static_cast
<
ck
::
index_t
>
(
bs
));
constexpr
auto
c_grid_desc_m_n
=
tp
.
MakeCGridDescriptor_M_N
(
static_cast
<
ck
::
index_t
>
(
m
),
static_cast
<
ck
::
index_t
>
(
n
),
static_cast
<
ck
::
index_t
>
(
cs
));
/* constexpr */
auto
block_2_ctile_map
=
tp
.
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n
);
/* constexpr */
auto
block_2_ctile_map
=
tp
.
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n
);
if
(
idx
.
global
==
0
)
if
(
idx
.
global
==
0
)
{
{
printf
(
"a_grid_desc_ak0_m_ak1{%i, %i, %i}
\n
"
,
int
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)),
int
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I1
)),
int
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
)));
printf
(
"a_grid_desc_ak0_m_ak1{%i, %i, %i}
\n
"
,
printf
(
"b_grid_desc_bk0_n_bk1{%i, %i, %i}
\n
"
,
int
(
b_grid_desc_bk0_n_bk1
.
GetLength
(
I0
)),
int
(
b_grid_desc_bk0_n_bk1
.
GetLength
(
I1
)),
int
(
b_grid_desc_bk0_n_bk1
.
GetLength
(
I2
)));
int
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)),
printf
(
"c_grid_desc_m_n{%i, %i}
\n
"
,
int
(
c_grid_desc_m_n
.
GetLength
(
I0
)),
int
(
c_grid_desc_m_n
.
GetLength
(
I1
)));
int
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I1
)),
int
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
)));
printf
(
"b_grid_desc_bk0_n_bk1{%i, %i, %i}
\n
"
,
int
(
b_grid_desc_bk0_n_bk1
.
GetLength
(
I0
)),
int
(
b_grid_desc_bk0_n_bk1
.
GetLength
(
I1
)),
int
(
b_grid_desc_bk0_n_bk1
.
GetLength
(
I2
)));
printf
(
"c_grid_desc_m_n{%i, %i}
\n
"
,
int
(
c_grid_desc_m_n
.
GetLength
(
I0
)),
int
(
c_grid_desc_m_n
.
GetLength
(
I1
)));
}
}
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
{};
c_grid_desc_mblock_mperblock_nblock_nperblock
{};
if
(
true
or
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1
,
if
(
true
or
b_grid_desc_bk0_n_bk1
,
GridwiseGemm
::
CheckValidity
(
c_grid_desc_m_n
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m_n
,
block_2_ctile_map
))
block_2_ctile_map
))
{
{
c_grid_desc_mblock_mperblock_nblock_nperblock
=
c_grid_desc_mblock_mperblock_nblock_nperblock
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
);
c_grid_desc_m_n
);
}
}
// if(idx.global == 0)
// if(idx.global == 0)
// {
// {
// printf("a_grid_desc_k0_m0_m1_k1{%i, %i, %i}\n", int(a_grid_desc_k0_m0_m1_k1.GetLength(I0)), int(a_grid_desc_k0_m0_m1_k1.GetLength(I1)), int(a_grid_desc_k0_m0_m1_k1.GetLength(I2)));
// printf("a_grid_desc_k0_m0_m1_k1{%i, %i, %i}\n",
// printf("b_grid_desc_k0_n0_n1_k1{%i, %i, %i}\n", int(b_grid_desc_k0_n0_n1_k1.GetLength(I0)), int(b_grid_desc_k0_n0_n1_k1.GetLength(I1)), int(b_grid_desc_k0_n0_n1_k1.GetLength(I2)));
// int(a_grid_desc_k0_m0_m1_k1.GetLength(I0)), int(a_grid_desc_k0_m0_m1_k1.GetLength(I1)),
// printf("c_grid_desc_m0_m10_m11_n0_n10_n11{%i, %i}\n", int(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0)), int(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I1)));
// int(a_grid_desc_k0_m0_m1_k1.GetLength(I2))); printf("b_grid_desc_k0_n0_n1_k1{%i, %i,
// %i}\n", int(b_grid_desc_k0_n0_n1_k1.GetLength(I0)),
// int(b_grid_desc_k0_n0_n1_k1.GetLength(I1)), int(b_grid_desc_k0_n0_n1_k1.GetLength(I2)));
// printf("c_grid_desc_m0_m10_m11_n0_n10_n11{%i, %i}\n",
// int(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0)),
// int(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I1)));
// }
// }
const
auto
K
=
const
auto
K
=
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
);
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
);
auto
a_element_op
=
tp
.
a_element_op
;
auto
a_element_op
=
tp
.
a_element_op
;
auto
b_element_op
=
tp
.
b_element_op
;
auto
b_element_op
=
tp
.
b_element_op
;
auto
c_element_op
=
tp
.
c_element_op
;
auto
c_element_op
=
tp
.
c_element_op
;
...
@@ -93,31 +106,31 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
...
@@ -93,31 +106,31 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
{
{
constexpr
bool
HasMainKBlockLoop
=
true
;
constexpr
bool
HasMainKBlockLoop
=
true
;
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
a_t
.
data
(),
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
a_t
.
data
(),
b_t
.
data
(),
b_t
.
data
(),
c_t
.
data
(),
c_t
.
data
(),
p_t
.
data
(),
p_t
.
data
(),
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
block_2_ctile_map
);
block_2_ctile_map
);
}
}
else
else
{
{
constexpr
bool
HasMainKBlockLoop
=
false
;
constexpr
bool
HasMainKBlockLoop
=
false
;
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
a_t
.
data
(),
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
a_t
.
data
(),
b_t
.
data
(),
b_t
.
data
(),
c_t
.
data
(),
c_t
.
data
(),
p_t
.
data
(),
p_t
.
data
(),
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
block_2_ctile_map
);
block_2_ctile_map
);
}
}
}
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_includes.hpp
View file @
1dd11890
This diff is collapsed.
Click to expand it.
src/targets/gpu/kernels/include/migraphx/kernels/ck_includes.hpp
View file @
1dd11890
...
@@ -52,7 +52,7 @@ static constexpr auto K1Number = ck::Number<K1>{};
...
@@ -52,7 +52,7 @@ static constexpr auto K1Number = ck::Number<K1>{};
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
ALayout
=
Row
;
//Col;
using
ALayout
=
Row
;
//
Col;
using
BLayout
=
Row
;
using
BLayout
=
Row
;
using
CLayout
=
Row
;
using
CLayout
=
Row
;
...
@@ -216,39 +216,39 @@ using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
...
@@ -216,39 +216,39 @@ using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
GridwiseGemm
=
using
GridwiseGemm
=
ck
::
GridwiseGemmDl_km_kn_mn_v1r3
<
BlockSize
,
ck
::
GridwiseGemmDl_km_kn_mn_v1r3
<
BlockSize
,
ADataType
,
ADataType
,
AccDataType
,
AccDataType
,
CDataType
,
CDataType
,
ck
::
InMemoryDataOperationEnum
::
Set
,
ck
::
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
CGridDesc_M_N
,
MPerBlock
,
MPerBlock
,
NPerBlock
,
NPerBlock
,
K0PerBlock
,
K0PerBlock
,
M1PerThread
,
M1PerThread
,
N1PerThread
,
N1PerThread
,
KPerThread
,
KPerThread
,
M1N1ThreadClusterM1Xs
,
M1N1ThreadClusterM1Xs
,
M1N1ThreadClusterN1Xs
,
M1N1ThreadClusterN1Xs
,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1
,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1
,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1
,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
,
ABlockTransferSrcVectorTensorContiguousDimOrder
,
ABlockTransferSrcVectorTensorContiguousDimOrder
,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1
,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1
,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1
,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
,
BBlockTransferSrcVectorTensorContiguousDimOrder
,
BBlockTransferSrcVectorTensorContiguousDimOrder
,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
>
;
CThreadTransferDstScalarPerVector
>
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
...
@@ -267,8 +267,7 @@ using BGridDesc_K0_N0_N1_K1 =
...
@@ -267,8 +267,7 @@ using BGridDesc_K0_N0_N1_K1 =
decltype
(
GridwiseGemm
::
MakeBGridDescriptor_K0_N0_N1_K1
(
BGridDesc_K0_N_K1
{}));
decltype
(
GridwiseGemm
::
MakeBGridDescriptor_K0_N0_N1_K1
(
BGridDesc_K0_N_K1
{}));
using
CGridDesc_M0_M10_M11_N0_N10_N11
=
using
CGridDesc_M0_M10_M11_N0_N10_N11
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11
(
CGridDesc_M_N
{}));
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11
(
CGridDesc_M_N
{}));
using
DefaultBlock2CTileMap
=
using
DefaultBlock2CTileMap
=
decltype
(
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{}));
decltype
(
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{}));
}
// namespace migraphx
}
// namespace migraphx
#endif
#endif
test/verify/0ack_test_ck_gemm.cpp
View file @
1dd11890
...
@@ -37,8 +37,10 @@
...
@@ -37,8 +37,10 @@
// migraphx::shape m2_shape{migraphx::shape::float_type, {4096, 4096}};
// migraphx::shape m2_shape{migraphx::shape::float_type, {4096, 4096}};
// auto l1 = mm->add_parameter("1", m1_shape);
// auto l1 = mm->add_parameter("1", m1_shape);
// auto l2 = mm->add_parameter("2", m2_shape);
// auto l2 = mm->add_parameter("2", m2_shape);
// // l1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
// // l1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}),
// // l2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
// l1);
// // l2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}),
// l2);
// mm->add_instruction(migraphx::make_op("ck_gemm"), l1, l2);
// mm->add_instruction(migraphx::make_op("ck_gemm"), l1, l2);
...
@@ -54,15 +56,15 @@ struct test_ck_gemm : verify_program<test_ck_gemm>
...
@@ -54,15 +56,15 @@ struct test_ck_gemm : verify_program<test_ck_gemm>
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
half_type
,
{
2
,
3
}};
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
half_type
,
{
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
half_type
,
{
3
,
4
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
half_type
,
{
3
,
4
}};
std
::
vector
<
float
>
v1
(
2
*
3
,
1
);
std
::
vector
<
float
>
v1
(
2
*
3
,
1
);
std
::
iota
(
v1
.
begin
(),
v1
.
end
(),
1
);
std
::
iota
(
v1
.
begin
(),
v1
.
end
(),
1
);
std
::
vector
<
float
>
v2
(
3
*
4
,
1
);
std
::
vector
<
float
>
v2
(
3
*
4
,
1
);
//std::iota(v2.begin(), v2.end(), 1);
//
std::iota(v2.begin(), v2.end(), 1);
auto
l1
=
mm
->
add_literal
(
migraphx
::
literal
{
m1_shape
,
v1
});
auto
l1
=
mm
->
add_literal
(
migraphx
::
literal
{
m1_shape
,
v1
});
auto
l2
=
mm
->
add_literal
(
migraphx
::
literal
{
m2_shape
,
v2
});
auto
l2
=
mm
->
add_literal
(
migraphx
::
literal
{
m2_shape
,
v2
});
// auto l1 = mm->add_parameter("1", m1_shape);
// auto l1 = mm->add_parameter("1", m1_shape);
// auto l2 = mm->add_parameter("2", m2_shape);
// auto l2 = mm->add_parameter("2", m2_shape);
//l1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
//
l1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
// l2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
// l2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm
->
add_instruction
(
migraphx
::
make_op
(
"ck_gemm"
),
l1
,
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"ck_gemm"
),
l1
,
l2
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment