Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
961cf059
Commit
961cf059
authored
Sep 16, 2022
by
turneram
Browse files
Remove ck from cmakelists
parent
2593dd60
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
188 additions
and
167 deletions
+188
-167
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+2
-3
src/targets/gpu/jit/ck_elementwise.cpp
src/targets/gpu/jit/ck_elementwise.cpp
+8
-109
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+26
-0
src/targets/gpu/kernels/include/migraphx/kernels/ck_device_gemm.hpp
...s/gpu/kernels/include/migraphx/kernels/ck_device_gemm.hpp
+0
-45
src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
...s/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
+83
-9
test/verify/0ck_elementwise_half_test.cpp
test/verify/0ck_elementwise_half_test.cpp
+69
-1
test/verify/0ck_test_ck_gemm.cpp
test/verify/0ck_test_ck_gemm.cpp
+0
-0
No files found.
src/targets/gpu/CMakeLists.txt
View file @
961cf059
...
...
@@ -25,7 +25,6 @@
list
(
APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/hip /opt/rocm/hcc /opt/rocm/ck
)
find_package
(
miopen
)
find_package
(
composable_kernel 1.0.0 COMPONENTS device_operations
)
find_package
(
hip REQUIRED PATHS /opt/rocm
)
# rocblas
...
...
@@ -393,8 +392,8 @@ endif()
# Workaround broken rocblas headers
target_compile_definitions
(
migraphx_gpu PUBLIC -D__HIP_PLATFORM_HCC__=1
)
target_link_libraries
(
migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas
)
target_link_libraries
(
migraphx_gpu PRIVATE migraphx_device migraphx_kernels
composable_kernel::device_operations
)
target_include_directories
(
migraphx_gpu PRIVATE /opt/rocm/ck/include
)
target_link_libraries
(
migraphx_gpu PRIVATE migraphx_device migraphx_kernels
)
#
target_include_directories(migraphx_gpu PRIVATE /opt/rocm/ck/include)
add_subdirectory
(
driver
)
...
...
src/targets/gpu/jit/ck_elementwise.cpp
View file @
961cf059
...
...
@@ -43,111 +43,6 @@ namespace gpu {
using
namespace
migraphx
::
gpu
::
gen
;
// NOLINT
// static const char* const ck_elementwise_kernel = R"__migraphx__(
// //#include <migraphx/kernels/ck_elementwise.hpp>
// #include <migraphx/kernels/ops.hpp>
// #include <migraphx/kernels/integral_constant.hpp>
// #include <migraphx/kernels/generic_constant.hpp>
// #include <args.hpp>
// #include <migraphx/kernels/index.hpp>
// #include <migraphx/kernels/algorithm.hpp>
// #include <migraphx/kernels/integral_constant.hpp>
// #include <migraphx/kernels/tensor_view.hpp>
// #include "ck/device_utility/device_prop.hpp"
// #include "ck/device_utility/kernel_launch.hpp"
// #include "ck/tensor_operation/gpu/device/device_base.hpp"
// #include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
// #include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp"
// namespace migraphx {
// using ADataType = float;
// using BDataType = float;
// using CDataType = float;
// using ElementwiseFunctor = float;
// static constexpr auto I0 = ck::Number<0>{};
// template <class L, class S, class N>
// constexpr auto MakeDescriptor_M(const L& lengths, const S& strides, const N& ndim)
// {
// auto gridSize = 72;
// auto blockSize = 1024;
// //constexpr auto ndim = 1;
// // auto idx = make_index();
// auto tupleOfShape = generate_tuple([&](auto I) { return static_cast<ck::index_t>(lengths[I]);
// },
// ck::Number<ndim>{});
// auto tupleOfStride = generate_tuple(
// [&](auto I) { return static_cast<ck::index_t>(strides[I]); }, ck::Number<1>{});
// const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
// auto desc_m = desc;
// // merge nd to 1d desc - [s0 * s1 * ...]
// if constexpr(ndim > 1)
// {
// desc_m = transform_tensor_descriptor(
// desc,
// make_tuple(make_merge_transform(tupleOfShape)),
// make_tuple(generate_sequence_v2([&](auto I) { return I; }, ck::Number<ndim>{})),
// make_tuple(ck::Sequence<0>{}));
// }
// const auto M = desc_m.GetLength(I0);
// const ck::index_t loop_step = /* idx.nglobal(); // */ gridSize * blockSize /* * MPerThread
// */; const auto pad = ck::math::integer_least_multiple(M, loop_step) - M; const
// auto desc_m_pad =
// transform_tensor_descriptor(desc_m,
// make_tuple(ck::make_right_pad_transform(M, pad)),
// make_tuple(ck::Sequence<0>{}),
// make_tuple(ck::Sequence<0>{}));
// return desc_m_pad;
// }
// struct Add
// {
// template <typename Y, typename X0, typename X1>
// __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
// {
// y = x0 + x1;
// };
// };
// extern "C" {
// __global__ void ck_elementwise_kernel(void* a_p, void* b_p, void* c_p)
// {
// make_tensors()(a_p, b_p, c_p)([](auto a_t, auto b_t, auto c_t) {
// constexpr auto lengths = get_shape_c<decltype(a_t)>{}.lens;
// constexpr auto strides = get_shape_c<decltype(a_t)>{}.strides;
// constexpr auto ndim = _c<decltype(lengths.size()){}>[1];
// constexpr auto a_desc = MakeDescriptor_M(lengths, strides, ndim);
// using AGridDesc_M = decltype(a_desc);
// using GridwiseBinEltwise = ck::GridwiseBinaryElementwise_1D<ADataType,
// BDataType,
// CDataType,
// CDataType,
// AGridDesc_M,
// AGridDesc_M,
// AGridDesc_M,
// Add,
// 1,
// 1,
// 1,
// 1>;
// auto op = Add{};
// GridwiseBinEltwise::Run(a_t.data(), b_t.data(), c_t.data(), a_desc, a_desc, a_desc, op);
// });
// }
// }
// } // namespace migraphx
// )__migraphx__";
// NOLINTNEXTLINE
static
const
char
*
const
ck_elementwise_kernel
=
R"__migraphx__(
#include <migraphx/kernels/ck_elementwise.hpp>
...
...
@@ -190,11 +85,15 @@ struct ck_elementwise_compiler : compiler<ck_elementwise_compiler>
hip_compile_options
options
;
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
virtual_inputs
=
reduce_dims
(
inputs
);
//options.virtual_inputs = reduce_dims(inputs);
//std::cout << options.virtual_inputs << std::endl;
options
.
params
=
"-Wno-float-equal"
;
auto
axis
=
find_fast_axis
(
options
.
virtual_inputs
);
auto
vec
=
vectorize
::
elements
(
axis
,
options
.
virtual_inputs
);
auto
preloads
=
preload
::
broadcasts
(
axis
,
options
.
virtual_inputs
);
// auto axis = find_fast_axis(options.virtual_inputs);
// auto vec = vectorize::elements(axis, options.virtual_inputs);
// auto preloads = preload::broadcasts(axis, options.virtual_inputs);
auto
axis
=
find_fast_axis
(
inputs
);
auto
vec
=
vectorize
::
elements
(
axis
,
inputs
);
auto
preloads
=
preload
::
broadcasts
(
axis
,
inputs
);
options
.
kernel_name
=
"ck_elementwise_kernel"
;
options
.
set_launch_params
(
v
,
...
...
src/targets/gpu/jit/ck_gemm.cpp
View file @
961cf059
...
...
@@ -167,6 +167,32 @@ __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
block_2_ctile_map,
ck::integral_constant<bool, HasMainKBlockLoop>{},
ck::integral_constant<bool, HasDoubleTailKBlockLoop>{});
// using AGridDesc_K0_M0_M1_K1 =
// decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
// using BGridDesc_K0_N0_N1_K1 =
// decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
// using CGridDesc_M0_M10_M11_N0_N10_N11 =
// decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
// using DefaultBlock2CTileMap =
// decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
// const auto kernel = ck::kernel_gemm_dl_v1r3<GridwiseGemm,
// ADataType,
// CDataType,
// remove_reference_t<AGridDesc_K0_M0_M1_K1>,
// remove_reference_t<BGridDesc_K0_N0_N1_K1>,
// remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
// remove_reference_t<DefaultBlock2CTileMap>,
// true,
// true>;
// kernel(a_t.data(),
// b_t.data(),
// c_t.data(),
// a_grid_desc_k0_m0_m1_k1,
// b_grid_desc_k0_n0_n1_k1,
// c_grid_desc_m0_m10_m11_n0_n10_n11,
// block_2_ctile_map);
});
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_device_gemm.hpp
deleted
100644 → 0
View file @
2593dd60
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {

/// Abstract interface for a device GEMM operation.
///
/// The template parameters (three layouts, three data types and three
/// element-wise operation types) are not used by the interface body other
/// than the operation types, which appear as by-value parameters of
/// MakeArgumentPointer. The struct only declares two pure virtual factory
/// hooks; it holds no state of its own.
template <class ALayout,
          class BLayout,
          class CLayout,
          class ADataType,
          class BDataType,
          class CDataType,
          class AElementwiseOperation,
          class BElementwiseOperation,
          class CElementwiseOperation>
struct DeviceGemm : public BaseOperator
{
    /// Creates the argument object for one GEMM invocation.
    ///
    /// p_a and p_b are taken as read-only untyped pointers, p_c as a
    /// writable untyped pointer. M, N, K and the three strides are passed
    /// through as ck::index_t values -- presumably the problem sizes and
    /// leading dimensions of A, B and C (confirm against implementations).
    /// The element-wise operations are received by value.
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        void* p_c,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t StrideA,
                        ck::index_t StrideB,
                        ck::index_t StrideC,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CElementwiseOperation c_element_op) = 0;

    /// Creates the invoker object; takes no arguments.
    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
\ No newline at end of file
src/targets/gpu/kernels/include/migraphx/kernels/ck_elementwise.hpp
View file @
961cf059
...
...
@@ -24,6 +24,8 @@
#ifndef MIGRAPHX_GUARD_KERNELS_CK_ELEMENTWISE_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_ELEMENTWISE_HPP
#include <stdio.h>
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/integral_constant.hpp>
...
...
@@ -48,7 +50,7 @@ template <ck::index_t ndim>
struct
CKBinaryElementwise
{
template
<
class
Desc_M
>
constexpr
auto
PadDescriptor_M_1d
(
Desc_M
desc_m
)
__device__
constexpr
auto
PadDescriptor_M_1d
(
Desc_M
desc_m
)
{
auto
gridSize
=
72
;
auto
blockSize
=
1024
;
...
...
@@ -65,7 +67,7 @@ struct CKBinaryElementwise
}
template
<
class
L
,
class
S
>
constexpr
auto
MakeDescriptor_M
(
const
L
&
lengths
,
const
S
&
strides
)
__device__
constexpr
auto
MakeDescriptor_M
(
const
L
&
lengths
,
const
S
&
strides
)
{
auto
tupleOfShape
=
generate_tuple
(
[
&
](
auto
I
)
{
return
static_cast
<
ck
::
index_t
>
(
lengths
[
I
]);
},
ck
::
Number
<
ndim
>
{});
...
...
@@ -89,6 +91,51 @@ struct CKBinaryElementwise
}
};
// Builds padded 1-D tensor descriptors for an ndim-dimensional operand.
// Unlike a constexpr variant, both members here are plain __device__
// functions (the constexpr specifier is commented out in this version).
template <ck::index_t ndim>
struct CKBinaryElementwise2
{
    // Right-pads the single dimension of desc_m so that its length becomes
    // a multiple of grid_size * block_size * m_per_thread, and returns the
    // padded descriptor. I0 is expected to be provided by the enclosing
    // scope.
    template <class Desc_M>
    /* constexpr */ __device__ auto PadDescriptor_M_1d(Desc_M desc_m)
    {
        // Hard-coded launch geometry used only to size the padding.
        auto grid_size    = 72;
        auto block_size   = 1024;
        auto m_per_thread = 8;

        const auto length = desc_m.GetLength(I0);

        const ck::index_t step = grid_size * block_size * m_per_thread;
        const auto padding     = ck::math::integer_least_multiple(length, step) - length;

        const auto padded =
            transform_tensor_descriptor(desc_m,
                                        make_tuple(ck::make_right_pad_transform(length, padding)),
                                        make_tuple(ck::Sequence<0>{}),
                                        make_tuple(ck::Sequence<0>{}));
        return padded;
    }

    // Creates a naive descriptor from the first ndim entries of lengths and
    // strides, merges all dimensions into one when ndim > 1, and returns the
    // result after padding through PadDescriptor_M_1d.
    template <class L, class S>
    /* constexpr */ __device__ auto MakeDescriptor_M(const L& lengths, const S& strides)
    {
        auto shape_tuple = generate_tuple(
            [&](auto I) { return static_cast<ck::index_t>(lengths[I]); }, ck::Number<ndim>{});

        auto stride_tuple = generate_tuple(
            [&](auto I) {
                // NOTE(review): unguarded device-side print, emitted for every
                // stride each time this function runs -- looks like debug
                // output; confirm whether it should remain.
                printf("Stride %i: %i\n", int(I), int(strides[I]));
                return static_cast<ck::index_t>(strides[I]);
            },
            ck::Number<ndim>{});

        const auto naive = make_naive_tensor_descriptor(shape_tuple, stride_tuple);

        if constexpr(ndim <= 1)
        {
            // Already one-dimensional: only padding is required.
            return PadDescriptor_M_1d(naive);
        }
        else
        {
            // merge nd to 1d desc - [s0 * s1 * ...]
            const auto merged = transform_tensor_descriptor(
                naive,
                make_tuple(make_merge_transform(shape_tuple)),
                make_tuple(generate_sequence_v2([&](auto I) { return I; }, ck::Number<ndim>{})),
                make_tuple(ck::Sequence<0>{}));
            return PadDescriptor_M_1d(merged);
        }
    }
};
struct
Add
{
template
<
typename
Y
,
typename
X0
,
typename
X1
>
...
...
@@ -98,30 +145,57 @@ struct Add
};
};
// Binary element-wise functor: stores the product of its two inputs in the
// output reference.
struct Mul
{
    // Assigns x0 * x1 to y. The three argument types are deduced
    // independently, so any conversion is whatever the assignment performs.
    template <class Out, class Lhs, class Rhs>
    __device__ constexpr void operator()(Out& y, const Lhs& x0, const Rhs& x1) const
    {
        y = x0 * x1;
    }
};
// Binary element-wise functor: stores the quotient of its two inputs in the
// output reference.
struct Div
{
    // Assigns x0 / x1 to y. No check is made on the divisor; the result is
    // whatever operator/ yields for the deduced operand types.
    template <class Out, class Lhs, class Rhs>
    __device__ constexpr void operator()(Out& y, const Lhs& x0, const Rhs& x1) const
    {
        y = x0 / x1;
    }
};
template
<
class
T
,
class
U
,
class
V
>
__device__
void
ck_elementwise
(
const
T
&
a_t
,
const
U
&
b_t
,
const
V
&
c_t
)
{
//auto idx = make_index();
constexpr
auto
a_lens
=
get_shape_c
<
T
>
{}.
lens
;
constexpr
auto
a_strides
=
get_shape_c
<
T
>
{}.
strides
;
constexpr
ck
::
index_t
a_ndim
=
decltype
(
a_lens
.
size
()){};
constexpr
ck
::
index_t
a_ndim
=
a_lens
.
size
();
//decltype(a_lens.size()){};
// if (idx.global == 0)
// printf("a_ndim: %i\n", int(a_ndim));
auto
a_bin_op
=
CKBinaryElementwise
<
a_ndim
>
{};
constexpr
auto
a_desc
=
a_bin_op
.
MakeDescriptor_M
(
a_lens
,
a_strides
);
constexpr
auto
b_lens
=
get_shape_c
<
U
>
{}.
lens
;
constexpr
auto
b_strides
=
get_shape_c
<
U
>
{}.
strides
;
constexpr
ck
::
index_t
b_ndim
=
decltype
(
b_lens
.
size
()){};
constexpr
ck
::
index_t
b_ndim
=
b_lens
.
size
();
//decltype(b_lens.size()){};
// if (idx.global == 0)
// printf("b_ndim: %i\n", int(b_ndim));
auto
b_bin_op
=
CKBinaryElementwise
<
b_ndim
>
{};
constexpr
auto
b_desc
=
b_bin_op
.
MakeDescriptor_M
(
b_lens
,
b_strides
);
constexpr
auto
c_lens
=
get_shape_c
<
V
>
{}.
lens
;
constexpr
auto
c_strides
=
get_shape_c
<
V
>
{}.
strides
;
constexpr
ck
::
index_t
c_ndim
=
decltype
(
c_lens
.
size
()){};
constexpr
ck
::
index_t
c_ndim
=
c_lens
.
size
();
//
decltype(c_lens.size()){};
auto
c_bin_op
=
CKBinaryElementwise
<
c_ndim
>
{};
constexpr
auto
c_desc
=
c_bin_op
.
MakeDescriptor_M
(
c_lens
,
c_strides
);
using
AGridDesc_M
=
decltype
(
a_desc
);
using
BGridDesc_M
=
decltype
(
b_desc
);
using
CGridDesc_M
=
decltype
(
c_desc
);
constexpr
ck
::
index_t
MPerThread
=
8
;
constexpr
ck
::
index_t
AScalarPerVector
=
8
;
constexpr
ck
::
index_t
BScalarPerVector
=
8
;
constexpr
ck
::
index_t
CScalarPerVector
=
8
;
using
GridwiseBinEltwise
=
ck
::
GridwiseBinaryElementwise_1D
<
ADataType
,
BDataType
,
CDataType
,
...
...
@@ -130,10 +204,10 @@ __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
BGridDesc_M
,
CGridDesc_M
,
Add
,
8
,
8
,
8
,
8
>
;
MPerThread
,
AScalarPerVector
,
BScalarPerVector
,
CScalarPerVector
>
;
auto
op
=
Add
{};
GridwiseBinEltwise
::
Run
(
a_t
.
data
(),
b_t
.
data
(),
c_t
.
data
(),
a_desc
,
b_desc
,
c_desc
,
op
);
}
...
...
test/verify/0ck_elementwise_half_test.cpp
View file @
961cf059
...
...
@@ -27,6 +27,53 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// struct ck_elementwise_half : verify_program<ck_elementwise_half>
// {
// migraphx::program create_program() const
// {
// migraphx::program p;
// auto* mm = p.get_main_module();
// migraphx::shape m1_shape{migraphx::shape::half_type, {1, 1, 2, 2, 2}};
// migraphx::shape m2_shape{migraphx::shape::half_type, {1, 1, 1, 2, 1}};
// std::vector<float> v1(8, 1);
// std::vector<float> v2(2);
// std::iota(v2.begin(), v2.end(), 1);
// auto l1 = mm->add_literal(migraphx::literal{m1_shape, v1});
// auto l2 = mm->add_literal(migraphx::literal{m2_shape, v2});
// l2 = mm->add_instruction(
// migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 2, 2, 2}}}), l2);
// //l2 = mm->add_instruction(migraphx::make_op("contiguous"), l2);
// mm->add_instruction(migraphx::make_op("ck_elementwise"), l1, l2);
// mm->debug_print();
// return p;
// }
// };
// struct ck_elementwise_half : verify_program<ck_elementwise_half>
// {
// migraphx::program create_program() const
// {
// migraphx::program p;
// auto* mm = p.get_main_module();
// migraphx::shape m1_shape{migraphx::shape::half_type, {2, 384, 3072}};
// migraphx::shape m2_shape{migraphx::shape::half_type, {1, 384, 1}};
// std::vector<float> v1(2*384*3072, 1);
// std::vector<float> v2(384, 2.54);
// auto l1 = mm->add_literal(migraphx::literal{m1_shape, v1});
// auto l2 = mm->add_literal(migraphx::literal{m2_shape, v2});
// l2 = mm->add_instruction(
// migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 3072}}}), l2);
// //l2 = mm->add_instruction(migraphx::make_op("contiguous"), l2);
// mm->add_instruction(migraphx::make_op("ck_elementwise"), l1, l2);
// mm->debug_print();
// return p;
// }
// };
struct
ck_elementwise_half
:
verify_program
<
ck_elementwise_half
>
{
migraphx
::
program
create_program
()
const
...
...
@@ -34,7 +81,7 @@ struct ck_elementwise_half : verify_program<ck_elementwise_half>
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
half_type
,
{
2
,
384
,
3072
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
half_type
,
{
3072
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
half_type
,
{
1
,
384
,
1
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
l2
=
mm
->
add_instruction
(
...
...
@@ -45,3 +92,24 @@ struct ck_elementwise_half : verify_program<ck_elementwise_half>
return
p
;
}
};
// struct ck_elementwise_half : verify_program<ck_elementwise_half>
// {
// migraphx::program create_program() const
// {
// migraphx::program p;
// auto* mm = p.get_main_module();
// migraphx::shape m1_shape{migraphx::shape::half_type, {3072}};
// migraphx::shape m2_shape{migraphx::shape::half_type, {1}};
// auto l1 = mm->add_parameter("1", m1_shape);
// auto l2 = mm->add_parameter("2", m2_shape);
// l2 = mm->add_instruction(
// migraphx::make_op("multibroadcast", {{"out_lens", {3072}}}), l2);
// //l2 = mm->add_instruction(migraphx::make_op("contiguous"), l2);
// mm->add_instruction(migraphx::make_op("ck_elementwise"), l1, l2);
// //mm->debug_print();
// return p;
// }
// };
test/verify/0
a
_test_ck_gemm.cpp
→
test/verify/0
ck
_test_ck_gemm.cpp
View file @
961cf059
File moved
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment