Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8d378877
Commit
8d378877
authored
Aug 25, 2022
by
turneram
Browse files
Switch to elementwise
parent
b41a56cf
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
161 additions
and
85 deletions
+161
-85
src/CMakeLists.txt
src/CMakeLists.txt
+1
-1
src/include/migraphx/op/ck_elementwise.hpp
src/include/migraphx/op/ck_elementwise.hpp
+24
-31
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+1
-1
src/targets/gpu/jit/ck_elementwise.cpp
src/targets/gpu/jit/ck_elementwise.cpp
+127
-0
src/targets/gpu/kernels/include/migraphx/kernels/ck_device_gemm.hpp
...s/gpu/kernels/include/migraphx/kernels/ck_device_gemm.hpp
+0
-45
src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
...s/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
+4
-3
test/verify/0ck_elementwise_test.cpp
test/verify/0ck_elementwise_test.cpp
+4
-4
No files found.
src/CMakeLists.txt
View file @
8d378877
...
...
@@ -118,7 +118,7 @@ register_migraphx_ops(
broadcast
capture
ceil
ck_
gemm
ck_
elementwise
clip
concat
contiguous
...
...
src/include/migraphx/op/ck_
gemm
.hpp
→
src/include/migraphx/op/ck_
elementwise
.hpp
View file @
8d378877
...
...
@@ -21,59 +21,52 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_CK_
GEMM
_HPP
#define MIGRAPHX_GUARD_OPERATORS_CK_
GEMM
_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_CK_
ELEMENTWISE
_HPP
#define MIGRAPHX_GUARD_OPERATORS_CK_
ELEMENTWISE
_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/
gemm
.hpp>
#include <migraphx/
par_for
.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
ck_
gemm
struct
ck_
elementwise
{
std
::
string
name
()
const
{
return
"ck_
gemm
"
;
}
std
::
string
name
()
const
{
return
"ck_
elementwise
"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
same_type
().
has
(
2
);
const
shape
&
a
=
inputs
.
at
(
0
);
const
shape
&
b
=
inputs
.
at
(
1
);
auto
t
=
a
.
type
();
if
(
!
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
same_type
().
same_dims
();
auto
s0
=
inputs
.
at
(
0
);
auto
s1
=
inputs
.
at
(
1
);
if
(
s0
==
s1
and
s0
.
packed
())
{
MIGRAPHX_THROW
(
"DOT: dot only accept 2 or more dims operands"
)
;
return
s0
;
}
// only handle the case that the batch size of a and b are the same
if
(
!
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
else
if
(
s0
.
packed
()
!=
s1
.
packed
())
{
MIGRAPHX_THROW
(
"DOT: batch size of A and B mismatch: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
return
s0
.
packed
()
?
s0
:
s1
;
}
std
::
size_t
dim_0
=
a
.
lens
().
size
()
-
2
;
std
::
size_t
dim_1
=
a
.
lens
().
size
()
-
1
;
if
(
a
.
lens
()[
dim_1
]
!=
b
.
lens
()[
dim_0
])
else
if
(
s0
.
broadcasted
()
!=
s1
.
broadcasted
())
{
MIGRAPHX_THROW
(
"DOT: inner dimensions do not match: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
return
s0
.
broadcasted
()
?
s1
.
with_lens
(
s0
.
lens
())
:
s0
.
with_lens
(
s0
.
lens
());
}
else
{
return
{
s0
.
type
(),
s0
.
lens
()};
}
auto
out_lens
=
a
.
lens
();
out_lens
[
dim_1
]
=
b
.
lens
()[
dim_1
];
return
{
t
,
out_lens
};
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
=
argument
{
output_shape
};
visit_all
(
result
,
args
[
0
],
args
[
1
])(
[
&
](
auto
cmat
,
auto
amat
,
auto
bmat
)
{
gemm
(
cmat
,
amat
,
bmat
,
1.0
f
,
0.0
f
);
});
argument
result
{
output_shape
};
visit_all
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
const
auto
i
)
{
output
[
i
]
=
input1
[
i
]
+
input2
[
i
];
});
});
return
result
;
}
};
...
...
src/include/migraphx/operators.hpp
View file @
8d378877
...
...
@@ -40,7 +40,7 @@
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/capture.hpp>
#include <migraphx/op/ceil.hpp>
#include <migraphx/op/ck_
gemm
.hpp>
#include <migraphx/op/ck_
elementwise
.hpp>
#include <migraphx/op/clip.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/concat.hpp>
...
...
src/targets/gpu/jit/ck_
gemm
.cpp
→
src/targets/gpu/jit/ck_
elementwise
.cpp
View file @
8d378877
...
...
@@ -40,139 +40,57 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
// NOLINTNEXTLINE
static
const
char
*
const
ck_
gemm
_kernel
=
R"__migraphx__(
#include <migraphx/kernels/ck_
gemm
.hpp>
static
const
char
*
const
ck_
elementwise
_kernel
=
R"__migraphx__(
#include <migraphx/kernels/ck_
elementwise
.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
namespace migraphx {
extern "C" {
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
// clang-format off
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
// clang-format on
__global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
__global__ void ck_elementwise_kernel(void* a_p, void* b_p, void* c_p)
{
// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CShuffleDataType,
CDataType,
AElementOp,
BElementOp,
CElementOp,
ck::InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
const auto kernel = kernel_gemm_xdlops_v2r3<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementOp,
BElementOp,
CElementOp,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
kernel<<<1, 1, 1, 0>>>(p_a, p_b, p_c);
using F16 = ck::half_t;
using F32 = float;
using ABDataType = F16;
using CDataType = F16;
using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::element_wise::Add;
using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
ABDataType,
CDataType,
EltwiseComputeDataType,
Add,
1,
8,
8,
8,
8>;
ck::index_t M = 1024;
std::array<const void*, 2> input = {a_p,
b_p};
std::array<void*, 1> output = {c_p};
std::vector<ck::index_t> a_strides = {1};
std::vector<ck::index_t> b_strides = {1};
std::vector<ck::index_t> c_strides = {1};
auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument = broadcastAdd.MakeArgumentPointer(
input, output, {M}, {{a_strides}, b_strides}, {c_strides}, Add{});
}
}
...
...
@@ -181,9 +99,9 @@ __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
)__migraphx__"
;
struct
ck_
gemm
_compiler
:
compiler
<
ck_
gemm
_compiler
>
struct
ck_
elementwise
_compiler
:
compiler
<
ck_
elementwise
_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"ck_
gemm
"
};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"ck_
elementwise
"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
...
...
@@ -192,10 +110,10 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
out_s
.
elements
()));
options
.
inputs
=
inputs
;
options
.
output
=
out_s
;
options
.
kernel_name
=
"ck_
gemm
_kernel"
;
options
.
kernel_name
=
"ck_
elementwise
_kernel"
;
options
.
virtual_inputs
=
inputs
;
return
compile_hip_code_object
(
ck_
gemm
_kernel
,
options
);
return
compile_hip_code_object
(
ck_
elementwise
_kernel
,
options
);
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_device_gemm.hpp
deleted
100644 → 0
View file @
b41a56cf
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGemm
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
\ No newline at end of file
src/targets/gpu/kernels/include/migraphx/kernels/ck_
gemm
.hpp
→
src/targets/gpu/kernels/include/migraphx/kernels/ck_
elementwise
.hpp
View file @
8d378877
...
...
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_CK_
GEMM
_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_
GEMM
_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_CK_
ELEMENTWISE
_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_
ELEMENTWISE
_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
...
...
@@ -30,9 +30,10 @@
namespace
migraphx
{
template
<
class
T
,
class
U
,
class
V
>
__device__
void
ck_
gemm
(
const
T
&
/* data_t */
,
const
U
&
/* indices_t */
,
const
V
&
/* output_t */
)
__device__
void
ck_
elementwise
(
const
T
&
/* data_t */
,
const
U
&
/* indices_t */
,
const
V
&
/* output_t */
)
{
}
}
// namespace migraphx
#endif
test/verify/0ck_
gemm
_test.cpp
→
test/verify/0ck_
elementwise
_test.cpp
View file @
8d378877
...
...
@@ -27,18 +27,18 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
ck_
gemm
:
verify_program
<
ck_
gemm
>
struct
ck_
elementwise
:
verify_program
<
ck_
elementwise
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_type
,
{
10
,
20
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_type
,
{
20
,
10
}};
//
migraphx::shape m2_shape{migraphx::shape::float_type, {20, 10}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m
2
_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m
1
_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"ck_
gemm
"
),
l1
,
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"ck_
elementwise
"
),
l1
,
l2
);
return
p
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment