Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8d378877
Commit
8d378877
authored
Aug 25, 2022
by
turneram
Browse files
Switch to elementwise
parent
b41a56cf
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
161 additions
and
85 deletions
+161
-85
src/CMakeLists.txt
src/CMakeLists.txt
+1
-1
src/include/migraphx/op/ck_elementwise.hpp
src/include/migraphx/op/ck_elementwise.hpp
+24
-31
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+1
-1
src/targets/gpu/jit/ck_elementwise.cpp
src/targets/gpu/jit/ck_elementwise.cpp
+127
-0
src/targets/gpu/kernels/include/migraphx/kernels/ck_device_gemm.hpp
...s/gpu/kernels/include/migraphx/kernels/ck_device_gemm.hpp
+0
-45
src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
...s/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
+4
-3
test/verify/0ck_elementwise_test.cpp
test/verify/0ck_elementwise_test.cpp
+4
-4
No files found.
src/CMakeLists.txt
View file @
8d378877
...
...
@@ -118,7 +118,7 @@ register_migraphx_ops(
broadcast
capture
ceil
ck_
gemm
ck_
elementwise
clip
concat
contiguous
...
...
src/include/migraphx/op/ck_
gemm
.hpp
→
src/include/migraphx/op/ck_
elementwise
.hpp
View file @
8d378877
...
...
@@ -21,59 +21,52 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_CK_
GEMM
_HPP
#define MIGRAPHX_GUARD_OPERATORS_CK_
GEMM
_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_CK_
ELEMENTWISE
_HPP
#define MIGRAPHX_GUARD_OPERATORS_CK_
ELEMENTWISE
_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/
gemm
.hpp>
#include <migraphx/
par_for
.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
ck_
gemm
struct
ck_
elementwise
{
std
::
string
name
()
const
{
return
"ck_
gemm
"
;
}
std
::
string
name
()
const
{
return
"ck_
elementwise
"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
same_type
().
has
(
2
);
const
shape
&
a
=
inputs
.
at
(
0
);
const
shape
&
b
=
inputs
.
at
(
1
);
auto
t
=
a
.
type
();
if
(
!
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
same_type
().
same_dims
();
auto
s0
=
inputs
.
at
(
0
);
auto
s1
=
inputs
.
at
(
1
);
if
(
s0
==
s1
and
s0
.
packed
())
{
MIGRAPHX_THROW
(
"DOT: dot only accept 2 or more dims operands"
)
;
return
s0
;
}
// only handle the case that the batch size of a and b are the same
if
(
!
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
else
if
(
s0
.
packed
()
!=
s1
.
packed
())
{
MIGRAPHX_THROW
(
"DOT: batch size of A and B mismatch: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
return
s0
.
packed
()
?
s0
:
s1
;
}
std
::
size_t
dim_0
=
a
.
lens
().
size
()
-
2
;
std
::
size_t
dim_1
=
a
.
lens
().
size
()
-
1
;
if
(
a
.
lens
()[
dim_1
]
!=
b
.
lens
()[
dim_0
])
else
if
(
s0
.
broadcasted
()
!=
s1
.
broadcasted
())
{
MIGRAPHX_THROW
(
"DOT: inner dimensions do not match: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
return
s0
.
broadcasted
()
?
s1
.
with_lens
(
s0
.
lens
())
:
s0
.
with_lens
(
s0
.
lens
());
}
else
{
return
{
s0
.
type
(),
s0
.
lens
()};
}
auto
out_lens
=
a
.
lens
();
out_lens
[
dim_1
]
=
b
.
lens
()[
dim_1
];
return
{
t
,
out_lens
};
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
=
argument
{
output_shape
};
visit_all
(
result
,
args
[
0
],
args
[
1
])(
[
&
](
auto
cmat
,
auto
amat
,
auto
bmat
)
{
gemm
(
cmat
,
amat
,
bmat
,
1.0
f
,
0.0
f
);
});
argument
result
{
output_shape
};
visit_all
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
const
auto
i
)
{
output
[
i
]
=
input1
[
i
]
+
input2
[
i
];
});
});
return
result
;
}
};
...
...
src/include/migraphx/operators.hpp
View file @
8d378877
...
...
@@ -40,7 +40,7 @@
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/capture.hpp>
#include <migraphx/op/ceil.hpp>
#include <migraphx/op/ck_
gemm
.hpp>
#include <migraphx/op/ck_
elementwise
.hpp>
#include <migraphx/op/clip.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/concat.hpp>
...
...
src/targets/gpu/jit/ck_
gemm
.cpp
→
src/targets/gpu/jit/ck_
elementwise
.cpp
View file @
8d378877
...
...
@@ -40,139 +40,57 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
// NOLINTNEXTLINE
static
const
char
*
const
ck_
gemm
_kernel
=
R"__migraphx__(
#include <migraphx/kernels/ck_
gemm
.hpp>
static
const
char
*
const
ck_
elementwise
_kernel
=
R"__migraphx__(
#include <migraphx/kernels/ck_
elementwise
.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
namespace migraphx {
extern "C" {
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
// clang-format off
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
// clang-format on
__global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
__global__ void ck_elementwise_kernel(void* a_p, void* b_p, void* c_p)
{
// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CShuffleDataType,
CDataType,
AElementOp,
BElementOp,
CElementOp,
ck::InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
const auto kernel = kernel_gemm_xdlops_v2r3<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementOp,
BElementOp,
CElementOp,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
kernel<<<1, 1, 1, 0>>>(p_a, p_b, p_c);
using F16 = ck::half_t;
using F32 = float;
using ABDataType = F16;
using CDataType = F16;
using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::element_wise::Add;
using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
ABDataType,
CDataType,
EltwiseComputeDataType,
Add,
1,
8,
8,
8,
8>;
ck::index_t M = 1024;
std::array<const void*, 2> input = {a_p,
b_p};
std::array<void*, 1> output = {c_p};
std::vector<ck::index_t> a_strides = {1};
std::vector<ck::index_t> b_strides = {1};
std::vector<ck::index_t> c_strides = {1};
auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument = broadcastAdd.MakeArgumentPointer(
input, output, {M}, {{a_strides}, b_strides}, {c_strides}, Add{});
}
}
...
...
@@ -181,9 +99,9 @@ __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
)__migraphx__"
;
struct
ck_
gemm
_compiler
:
compiler
<
ck_
gemm
_compiler
>
struct
ck_
elementwise
_compiler
:
compiler
<
ck_
elementwise
_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"ck_
gemm
"
};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"ck_
elementwise
"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
...
...
@@ -192,10 +110,10 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
out_s
.
elements
()));
options
.
inputs
=
inputs
;
options
.
output
=
out_s
;
options
.
kernel_name
=
"ck_
gemm
_kernel"
;
options
.
kernel_name
=
"ck_
elementwise
_kernel"
;
options
.
virtual_inputs
=
inputs
;
return
compile_hip_code_object
(
ck_
gemm
_kernel
,
options
);
return
compile_hip_code_object
(
ck_
elementwise
_kernel
,
options
);
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_device_gemm.hpp
deleted
100644 → 0
View file @
b41a56cf
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGemm
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
\ No newline at end of file
src/targets/gpu/kernels/include/migraphx/kernels/ck_
gemm
.hpp
→
src/targets/gpu/kernels/include/migraphx/kernels/ck_
elementwise
.hpp
View file @
8d378877
...
...
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_CK_
GEMM
_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_
GEMM
_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_CK_
ELEMENTWISE
_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_
ELEMENTWISE
_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
...
...
@@ -30,9 +30,10 @@
namespace
migraphx
{
template
<
class
T
,
class
U
,
class
V
>
__device__
void
ck_
gemm
(
const
T
&
/* data_t */
,
const
U
&
/* indices_t */
,
const
V
&
/* output_t */
)
__device__
void
ck_
elementwise
(
const
T
&
/* data_t */
,
const
U
&
/* indices_t */
,
const
V
&
/* output_t */
)
{
}
}
// namespace migraphx
#endif
test/verify/0ck_
gemm
_test.cpp
→
test/verify/0ck_
elementwise
_test.cpp
View file @
8d378877
...
...
@@ -27,18 +27,18 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
ck_
gemm
:
verify_program
<
ck_
gemm
>
struct
ck_
elementwise
:
verify_program
<
ck_
elementwise
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_type
,
{
10
,
20
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_type
,
{
20
,
10
}};
//
migraphx::shape m2_shape{migraphx::shape::float_type, {20, 10}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m
2
_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m
1
_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"ck_
gemm
"
),
l1
,
l2
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"ck_
elementwise
"
),
l1
,
l2
);
return
p
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment