gaoqiong / MIGraphX · Commits · cd4ab535
"app/vscode:/vscode.git/clone" did not exist on "558a54b098dc5044f5c4167ede5c327c970185a2"
Commit cd4ab535, authored Jun 20, 2023 by Khalique Ahmed

manual merge

Parents: 3891ee58, a0fa3742
Changes: 279
Showing 20 changed files with 432 additions and 244 deletions (+432 −244)
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp    +19  −2
src/targets/gpu/kernels/include/migraphx/kernels/gemm_batcher.hpp  +92  −0
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp           +0   −4
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp         +14  −0
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp          +5   −15
src/targets/gpu/kernels/include/migraphx/kernels/print.hpp         +2   −2
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp        +52  −46
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp   +11  −3
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp           +1   −1
src/targets/gpu/lowering.cpp                                       +2   −1
src/targets/gpu/mlir.cpp                                           +137 −26
src/targets/gpu/rocblas.cpp                                        +2   −17
src/targets/gpu/target.cpp                                         +22  −12
src/targets/gpu/time_op.cpp                                        +1   −3
src/targets/ref/CMakeLists.txt                                     +1   −2
src/targets/ref/lowering.cpp                                       +2   −108
src/tf/op_parser.cpp                                               +1   −0
src/tf/tf.cpp                                                      +3   −0
src/value.cpp                                                      +33  −0
test/CMakeLists.txt                                                +32  −2
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp

@@ -32,8 +32,17 @@
 // NOLINTNEXTLINE
 #define MIGRAPHX_LIFT(...)                                        \
-    [](auto&&... private_lisft_xs) MIGRAPHX_RETURNS(              \
-        (__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...))
+    [](auto&&... private_lifts_xs) MIGRAPHX_RETURNS(              \
+        (__VA_ARGS__)(static_cast<decltype(private_lifts_xs)>(private_lifts_xs)...))
+
+// NOLINTNEXTLINE
+#define MIGRAPHX_LIFT_CLASS(name, ...)                                                         \
+    struct name                                                                                \
+    {                                                                                          \
+        template <class... PrivateLiftTs>                                                     \
+        constexpr auto operator()(PrivateLiftTs&&... private_lifts_xs) const MIGRAPHX_RETURNS( \
+            (__VA_ARGS__)(static_cast<decltype(private_lifts_xs)>(private_lifts_xs)...))       \
+    }
 
 namespace migraphx {

@@ -195,6 +204,14 @@ constexpr auto compose(Fs... fs)
     })(fs...);
 }
 
+template <class F>
+constexpr auto partial(F f)
+{
+    return [=](auto... xs) {
+        return [=](auto&&... ys) { return f(xs..., static_cast<decltype(ys)>(ys)...); };
+    };
+}
+
 template <class... Ts>
 constexpr auto pack(Ts... xs)
 {
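The new partial helper curries a callable: the first argument pack is captured by value and the second is perfect-forwarded at the later call. A minimal host-side sketch of the same shape (standalone C++, independent of the kernel headers; add3 is a made-up example function):

#include <cassert>

// Same shape as the partial() added above: capture xs now, forward ys later.
template <class F>
constexpr auto partial(F f)
{
    return [=](auto... xs) {
        return [=](auto&&... ys) { return f(xs..., static_cast<decltype(ys)>(ys)...); };
    };
}

int main()
{
    auto add3 = [](int a, int b, int c) { return a + b + c; };
    // Bind the first argument now, supply the rest at the second call.
    assert(partial(add3)(1)(2, 3) == 6);
}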
src/targets/gpu/kernels/include/migraphx/kernels/gemm_batcher.hpp (new file, 0 → 100644)

/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#ifndef MIGRAPHX_GUARD_KERNELS_GEMM_BATCHER_HPP
#define MIGRAPHX_GUARD_KERNELS_GEMM_BATCHER_HPP

#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/index.hpp>

namespace migraphx {

template <class Tensor>
constexpr auto gemm_get_batches()
{
    constexpr auto lens    = get_shape_c<Tensor>{}.lens;
    constexpr auto strides = get_shape_c<Tensor>{}.strides;
    constexpr auto new_lens = sequence(
        lens.size() - _c<2>, [&](auto... is) { return make_const_array(_c<lens[is]>...); });
    constexpr auto new_strides = sequence(
        strides.size() - _c<2>, [&](auto... is) { return make_const_array(_c<strides[is]>...); });
    return make_shape(new_lens, new_strides);
}

template <class Tensor>
constexpr auto gemm_get_matrix()
{
    constexpr auto lens    = get_shape_c<Tensor>{}.lens;
    constexpr auto strides = get_shape_c<Tensor>{}.strides;
    constexpr auto m       = lens.size() - _c<2>;
    constexpr auto n       = lens.size() - _c<1>;
    constexpr auto new_lens    = make_const_array(_c<lens[m]>, _c<lens[n]>);
    constexpr auto new_strides = make_const_array(_c<strides[m]>, _c<strides[n]>);
    return make_shape(new_lens, new_strides);
}

template <class Tensor, class T>
constexpr auto gemm_batch_slice(Tensor t, T i)
{
    constexpr auto batch  = gemm_get_batches<Tensor>();
    constexpr auto matrix = gemm_get_matrix<Tensor>();
    MIGRAPHX_ASSERT((batch.index(i) + matrix.element_space()) <= t.get_shape().element_space());
    return make_tensor_view(t.data() + batch.index(i), matrix);
}

template <class BlocksPerBatch, class T, class... Ts>
constexpr auto gemm_batch_args(index idx, BlocksPerBatch bpb, T x, Ts... xs)
{
    return [=](auto f) {
        // All tensors should have the same rank
        static_assert((true and ... and
                       (get_shape_c<T>{}.lens.size() == get_shape_c<Ts>{}.lens.size())));
        if constexpr(get_shape_c<T>{}.lens.size() > 2)
        {
            // Get the first batch since all batches should have the same number of elements
            constexpr auto batch = gemm_get_batches<T>();
            static_assert(
                (true and ... and (batch.elements() == gemm_get_batches<Ts>().elements())));
            idx.group_stride(bpb * batch.elements(), [&](auto gidx) {
                const auto batch_idx = gidx / bpb;
                f(gemm_batch_slice(x, batch_idx), gemm_batch_slice(xs, batch_idx)...);
            });
        }
        else
        {
            f(x, xs...);
        }
    };
}

} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_GEMM_BATCHER_HPP
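gemm_get_batches treats every dimension except the last two as batch dimensions, gemm_get_matrix keeps the trailing (M, N) pair, and gemm_batch_slice offsets the base pointer by batch.index(i) before viewing an M x N matrix. A host-side sketch of that offset arithmetic for a packed [B, M, N] tensor (hypothetical standalone code, not the kernel API):

#include <cassert>
#include <cstddef>

// For a packed [B, M, N] tensor, batch b starts at element offset b * M * N,
// which is what batch.index(b) computes from the batch strides.
std::size_t gemm_batch_offset(std::size_t b, std::size_t m, std::size_t n)
{
    return b * (m * n); // batch stride of a packed 3D tensor
}

int main()
{
    // Tensor [B=4, M=2, N=3]: slice 2 begins at element 2 * 2 * 3 = 12.
    assert(gemm_batch_offset(2, 2, 3) == 12);
}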
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp

@@ -28,10 +28,6 @@
 #include <hip/hip_runtime.h>
 #include <hip/hip_fp16.h>
 #include <hip/math_functions.h>
-#include <hip/hip_math_constants.h>
-#elif defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS)
-#include <hip/hip_common.h>
-#include <hip/hip_math_constants.h>
 #endif
 
 #endif // MIGRAPHX_GUARD_KERNELS_HIP_HPP
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp

@@ -130,6 +130,8 @@ struct index
         return blockDim.x;
     }
 #endif
 
+    constexpr auto ngroup() const { return nglobal() / max_nlocal(); }
+
     template <class N, class Stride>
     static constexpr auto max_stride_iterations(N n, Stride stride)
     {

@@ -231,8 +233,20 @@ struct index
     {
         for_stride<true>(local, n, nlocal(), f);
     }
+
+    template <class F, class N>
+    __device__ void group_stride(N n, F f) const
+    {
+        for_stride<false>(group, n, ngroup(), f);
+    }
 };
 
+#ifdef MIGRAPHX_NLOCAL
+#define MIGRAPHX_GLOBAL \
+    __global__ __attribute__((amdgpu_flat_work_group_size(MIGRAPHX_NLOCAL, MIGRAPHX_NLOCAL)))
+#else
+#define MIGRAPHX_GLOBAL __global__
+#endif
+
 inline __device__ __attribute__((const)) index make_index()
 {
     return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT
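The added ngroup()/group_stride pair lets a kernel iterate over workgroups the same way local_stride iterates over threads; gemm_batch_args uses it to map group gidx to batch gidx / bpb. A sketch of that mapping with assumed values (plain host C++, not device code):

#include <cstdio>

int main()
{
    // Say 2 blocks per batch (bpb) over 3 batches: 6 groups in total.
    const int bpb = 2, batches = 3;
    for(int gidx = 0; gidx < bpb * batches; gidx++)
    {
        // Same arithmetic as gemm_batch_args: each batch gets bpb groups.
        std::printf("group %d -> batch %d\n", gidx, gidx / bpb);
    }
}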
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp

@@ -138,7 +138,7 @@ MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, floor, ::hfloor)
 MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, isnan, ::__hisnan)
 MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, log, ::hlog)
 MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, rsqrt, ::hrsqrt)
-// MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sin, ::hsin)
+MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sin, ::hsin)
 MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sqrt, ::hsqrt)
 
 // Use float to compute half overload

@@ -161,8 +161,7 @@ MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod)
 // Map math functions to hip half2 functions
 // The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats
 // packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names
-// Most but not all of these math ops have operators of the same names. Ones not yet implemented
-// at this time are: exp2, exp10, log2, log10, isinf
+// Most but not all of these math ops have operators of the same names.
 MIGRAPHX_DEVICE_MATH_HALF2(abs, ::__habs2)
 MIGRAPHX_DEVICE_MATH_HALF2(ceil, ::h2ceil)
 MIGRAPHX_DEVICE_MATH_HALF2(cos, ::h2cos)

@@ -176,7 +175,7 @@ MIGRAPHX_DEVICE_MATH_HALF2(log, ::h2log)
 MIGRAPHX_DEVICE_MATH_HALF2(log10, ::h2log10)
 MIGRAPHX_DEVICE_MATH_HALF2(log2, ::h2log2)
 MIGRAPHX_DEVICE_MATH_HALF2(rsqrt, ::h2rsqrt)
-// MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin)
+MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin)
 MIGRAPHX_DEVICE_MATH_HALF2(sqrt, ::h2sqrt)
 
 template <class T, class U>

@@ -189,9 +188,8 @@ MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, max, ::max)
 MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, min, ::min)
 MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, max, ::max)
 MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, min, ::min)
-// Add overloads for half that calls the float version
-MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::fmaxf)
-MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::fminf)
+MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::__hmax)
+MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::__hmin)
 
 template <class T, MIGRAPHX_REQUIRES(not is_any_vec<T>())>
 constexpr auto max(const T& a, const T& b)

@@ -217,14 +215,6 @@ constexpr auto min(const T& a, const U& b)
     return min<common_type_t<T, U>>(a, b);
 }
 
-// Sin for half is broken on hip, so use cos instead
-template <class T, MIGRAPHX_REQUIRES(is_same<vec_type<T>, half>{})>
-constexpr T sin(T x)
-{
-    constexpr const T shift = HIP_PIO2_F;
-    return migraphx::cos(shift - x);
-}
-
 MIGRAPHX_DEVICE_MATH_VEC(abs)
 MIGRAPHX_DEVICE_MATH_VEC(acos)
 MIGRAPHX_DEVICE_MATH_VEC(acosh)
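These hunks switch half sin back to the native ::hsin/::h2sin mappings and delete the cos-based fallback, which relied on the identity sin(x) = cos(pi/2 - x). A quick host-side check of that identity at one sample point (illustration only):

#include <cassert>
#include <cmath>

int main()
{
    // The removed workaround computed sin(x) as cos(pi/2 - x); verify the
    // identity at a sample point in double precision.
    const double x = 0.7;
    assert(std::fabs(std::sin(x) - std::cos(M_PI / 2 - x)) < 1e-12);
}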
src/targets/gpu/kernels/include/migraphx/kernels/print.hpp

@@ -244,13 +244,13 @@ __device__ void print_once(Ts... xs)
 template <class... Ts>
 __device__ void println(Ts... xs)
 {
-    print_each(&coutln, xs...);
+    print_each(&cout, xs..., '\n');
 }
 
 template <class... Ts>
 __device__ void println_once(Ts... xs)
 {
-    print_each_once(&coutln, xs...);
+    print_each_once(&cout, xs..., '\n');
 }
 
 } // namespace migraphx
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp

@@ -79,20 +79,21 @@ __device__ void dpp_reduce(T& in, Op op)
 #endif
 // NOLINTNEXTLINE
-#define MIGRAPHX_DPP_REDUCE(op, prefix)                                                            \
+#define MIGRAPHX_DPP_REDUCE(op, prefix, sign)                                                      \
     __device__ inline void dpp_reduce(double& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64); } \
     __device__ inline void dpp_reduce(float& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32); }  \
     __device__ inline void dpp_reduce(half& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16); }   \
     __device__ inline void dpp_reduce(int32_t& x, op)                                             \
     {                                                                                              \
-        MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32);                                                  \
+        MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32);                                              \
     }                                                                                              \
     __device__ inline void dpp_reduce(uint32_t& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); }
 
-MIGRAPHX_DPP_REDUCE(op::sum, v_add)
-MIGRAPHX_DPP_REDUCE(op::max, v_max)
-MIGRAPHX_DPP_REDUCE(op::min, v_min)
-MIGRAPHX_DPP_REDUCE(op::product, v_mul)
+// Note: when max and min are in int32_t, signed version of instruction needs to be used.
+MIGRAPHX_DPP_REDUCE(op::sum, v_add, _u)
+MIGRAPHX_DPP_REDUCE(op::product, v_mul, _u)
+MIGRAPHX_DPP_REDUCE(op::max, v_max, _i)
+MIGRAPHX_DPP_REDUCE(op::min, v_min, _i)
 
 template <class Op, class T, class Index, class F>
 __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)

@@ -174,6 +175,25 @@ struct inner_storage_tag
 template <class T>
 using is_inner_storage = is_base_of<inner_storage_tag, remove_cv_t<remove_reference_t<T>>>;
 
+template <class Size, class F>
+struct lazy_inner_storage : inner_storage_tag
+{
+    using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
+    F f;
+    constexpr Size rsize() const { return {}; }
+    template <class U, class V>
+    constexpr auto operator()(U j, V d) const
+    {
+        return f(j, d);
+    }
+};
+
+template <class Size, class F>
+constexpr lazy_inner_storage<Size, F> make_lazy_inner_storage(Size, F f)
+{
+    return {{}, f};
+}
+
 template <class R, class F>
 struct storage_access : F
 {

@@ -278,6 +298,14 @@ struct reducer_base
         });
     }
 
+    template <class F>
+    __device__ auto lazy_inner(F f) const
+    {
+        return this->inner_sliced([=](auto n, auto&&... xs) {
+            return make_lazy_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
+        });
+    }
+
     template <class Op, class T, class Read>
     __device__ auto reduce(Op op, T init, Read read) const
     {

@@ -396,25 +424,6 @@ struct block_large
     index idx;
     Slicer slice;
 
-    template <class Size, class F>
-    struct inner_storage : inner_storage_tag
-    {
-        using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
-        F f;
-        constexpr Size rsize() const { return {}; }
-        template <class U, class V>
-        constexpr auto operator()(U j, V d) const
-        {
-            return f(j, d);
-        }
-    };
-
-    template <class Size, class F>
-    static constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
-    {
-        return {{}, {f}};
-    }
-
     template <class Op, class T, class Read, class N, class... Ts>
     __device__ auto reduce_impl(Op op, T init, Read read, N n, Ts&&... xs) const
     {

@@ -439,7 +448,7 @@ struct block_large
     template <class R, class F, class N, class... Ts>
     __device__ auto inner_impl(F f, N n, Ts&&... xs) const
     {
-        return make_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
+        return make_lazy_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
     }
 };

@@ -469,25 +478,6 @@ struct lane
     index idx;
     Slicer slice;
 
-    template <class Size, class F>
-    struct inner_storage : inner_storage_tag
-    {
-        using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
-        F f;
-        constexpr Size rsize() const { return {}; }
-        template <class U, class V>
-        constexpr auto operator()(U j, V d) const
-        {
-            return f(j, d);
-        }
-    };
-
-    template <class Size, class F>
-    static constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
-    {
-        return {{}, {f}};
-    }
-
     template <class Op, class T, class Read, class N, class U, class... Us>
     __device__ auto reduce_impl(Op op, T init, Read read, N n, U&& x, Us&&... xs) const
     {

@@ -518,7 +508,7 @@ struct lane
     template <class R, class F, class N, class... Ts>
     __device__ auto inner_impl(F f, N n, Ts&&... xs) const
     {
-        return make_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
+        return make_lazy_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
     }
 };
 
 template <class Slicer>

@@ -577,5 +567,21 @@ simple_reduce(Op op, T init, Input input, Output output, ReadInput read, WriteOu
     });
 }
 
+template <class Algo, class Reduced, class Output, class F>
+__device__ void fused_reduce(Output output, F f)
+{
+    Algo::template run<Reduced>([&](auto out_idx, auto r) {
+        auto result = f(r, out_idx);
+        if constexpr(reduce::is_inner_storage<decltype(result)>{})
+        {
+            r.inner([&](auto& y, auto x) { y = x; })(output, result);
+        }
+        else
+        {
+            r.outer([&] { output[out_idx] = implicit_conversion(result); });
+        }
+    });
+}
+
 } // namespace migraphx
 #endif // MIGRAPHX_GUARD_KERNELS_REDUCE_HPP
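lazy_inner_storage, which now backs inner_impl in both block_large and lane, is essentially a sized generator: it stores a callable and produces element j on demand rather than materializing a buffer first. A stripped-down host-side analogue of the idea (illustrative only, not the MIGraphX types):

#include <cassert>

// A sized callable: holds a generator f and computes element j at access time.
template <class F>
struct lazy_storage
{
    int size;
    F f;
    auto operator()(int j) const { return f(j); }
};

template <class F>
lazy_storage<F> make_lazy_storage(int n, F f)
{
    return {n, f};
}

int main()
{
    auto squares = make_lazy_storage(4, [](int j) { return j * j; });
    assert(squares(3) == 9); // computed at access time, nothing materialized
}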
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp

@@ -218,7 +218,15 @@ using common_type_t = typename common_type<Ts...>::type;
 #define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>
 
-constexpr unsigned long int_max(unsigned long n) { return (1u << (n * 8)) - 1; }
+constexpr unsigned long int_max(unsigned long n)
+{
+    // Note, left shift cannot be used to get the maximum value of int64_type or
+    // uint64_type because it is undefined behavior to left shift 64 bits for
+    // these types
+    if(n == sizeof(int64_t))
+        return -1;
+    return (1ul << (n * 8)) - 1;
+}
 
 template <class T,
           MIGRAPHX_REQUIRES(is_integral<T>{} or is_floating_point<T>{} or

@@ -228,9 +236,9 @@ constexpr T numeric_max()
     if constexpr(is_integral<T>{})
     {
         if constexpr(is_unsigned<T>{})
-            return int_max(sizeof(T)) * 2;
-        else
             return int_max(sizeof(T));
+        else
+            return int_max(sizeof(T)) / 2;
     }
     else if constexpr(is_same<T, double>{})
         return __DBL_MAX__;
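The int_max fix sidesteps undefined behavior: shifting a 64-bit integer left by 64 bits is UB, so the 8-byte case returns all ones directly instead of computing (1ul << 64) - 1. A self-contained sketch of the fixed function with compile-time checks (mirrors the diff; the static_asserts are added for illustration):

#include <cstdint>

constexpr unsigned long int_max(unsigned long n)
{
    // Shifting a 64-bit value by 64 is undefined behavior, so return all
    // ones directly for the 8-byte case (the -1 wraps to 0xFFFFFFFFFFFFFFFF).
    if(n == sizeof(std::int64_t))
        return -1;
    return (1ul << (n * 8)) - 1;
}

int main()
{
    static_assert(int_max(1) == 0xFFul, "one byte");
    static_assert(int_max(4) == 0xFFFFFFFFul, "four bytes");
    static_assert(int_max(8) == ~0ul, "eight bytes, no shift performed");
}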
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp

@@ -135,7 +135,7 @@ constexpr vec<vec_type<T>, N> vec_packed_at(T x, I i)
         return vec<T, N>{x};
     else
     {
-        MIGRAPHX_ASSERT((i + N) < vec_size<T>());
+        MIGRAPHX_ASSERT((i + N) <= vec_size<T>());
         vec<vec_type<T>, N> result = {0};
         for(int j = 0; j < N; j++)
         {
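The assert correction reflects that packing N lanes starting at index i touches the half-open range [i, i + N), so i + N may equal vec_size exactly for the last slice. A tiny host-side check of the boundary case (assumed numbers, illustration only):

#include <cassert>

int main()
{
    // Packing n elements starting at i reads indices [i, i + n), so the last
    // valid start has i + n == vec_size; the old `<` rejected it.
    const int vec_size = 8, n = 4;
    const int i = 4; // last valid pack start
    assert(i + n <= vec_size);
}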
src/targets/gpu/lowering.cpp

@@ -83,7 +83,8 @@ struct miopen_apply
         auto& ctx      = get_context();
         int8_x4_format = get_int8_x4_format(ctx);
         compute_fp32   = get_compute_fp32_flag();
-        offload_copy   = (mod->name() == "main") ? pass->offload_copy : false;
+        // TODO: Set Offload copy based on root modules' compile options
+        offload_copy = (mod->name() == "main") ? pass->offload_copy : false;
 
         add_generic_op("contiguous");
src/targets/gpu/mlir.cpp

@@ -30,6 +30,7 @@
 #include <mlir-c/BuiltinTypes.h>
 #include <mlir-c/Diagnostics.h>
 #include <mlir-c/Dialect/MIGraphX.h>
+#include <mlir-c/Dialect/Rock.h>
 #include <mlir-c/IntegerSet.h>
 #include <mlir-c/Pass.h>
 #include <mutex>

@@ -55,12 +56,16 @@
 #include <migraphx/permutation.hpp>
 #include <deque>
 #include <variant>
+#include <fstream>
+#include <sstream>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MLIR);
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_DB);
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_CFG);
 
 #ifdef MIGRAPHX_MLIR
 
 template <class T, class F, F f> // NOLINT

@@ -124,6 +129,8 @@ using mlir_op_printing_flags = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOpPrintingFlags,
 using mlir_region       = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRegion, mlirRegionDestroy);
 using mlir_block        = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirBlock, mlirBlockDestroy);
 using mlir_pass_manager = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirPassManager, mlirPassManagerDestroy);
+using mlir_tuning_table =
+    MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRockTuningTable, mlirRockTuningTableDestroy);
 
 std::string_view to_string_view(MlirStringRef s) { return {s.data, s.length}; }

@@ -157,10 +164,10 @@ std::string mlir_print(F f, T x)
     return ss.str();
 }
 
-const std::unordered_set<std::string>& get_xdlops_archs()
+bool has_xdlops(const std::string& target_arch)
 {
-    static std::unordered_set<std::string> supported_archs{"gfx908", "gfx90a"};
-    return supported_archs;
+    const auto device_name = trim(split_string(target_arch, ':').front());
+    return (starts_with(device_name, "gfx9") and device_name >= "gfx908");
 }
 
 struct mlir_program

@@ -190,10 +197,14 @@ struct mlir_program
             result = mlirF64TypeGet(ctx.get());
         else if(as.is_integral())
         {
-            if(as.is_signed())
-                result = mlirIntegerTypeSignedGet(ctx.get(), as.size() * 8);
-            else
-                result = mlirIntegerTypeGet(ctx.get(), as.size() * 8);
+            // Note: rocMLIR use signless integer type for tensors types. This
+            // will translate to signed implementation for current supported
+            // operations.
+            if(as.is_unsigned())
+            {
+                MIGRAPHX_THROW("Unsupported type: " + std::to_string(as.type_enum()));
+            }
+            result = mlirIntegerTypeGet(ctx.get(), as.size() * 8);
         }
         else
             MIGRAPHX_THROW("Unsupported type: " + std::to_string(as.type_enum()));

@@ -313,7 +324,8 @@ struct mlir_program
                                      std::string,
                                      value,
                                      std::vector<value>,
-                                     MlirType>;
+                                     MlirType,
+                                     MlirAttribute>;
     using named_attribute_t = std::pair<std::string_view, attribute_t>;
 
     MlirNamedAttribute name_attribute(const named_attribute_t& na) const

@@ -455,7 +467,7 @@ struct mlir_program
         auto ops = create_operation_state("func.func");
         ops.add_attributes({{"function_type", make_function_type(inputs, outputs)},
-                            {"sym_name", std::string("main")},
+                            {"sym_name", sym_name},
                             {"kernel", std::string("mixr")},
                             {"arch", target_arch}});
         ops.add_region(std::move(region));

@@ -470,13 +482,17 @@ struct mlir_program
     {
         if(ins->name() == "@return")
             return "func.return";
+        if(ins->name() == "@literal")
+        {
+            return "tosa.const";
+        }
         return "migraphx." + ins->name();
     }
 
     static value get_operator_value(const operation& op)
     {
         auto v = op.to_value();
-        if(op.name() == "convolution")
+        if(op.name() == "convolution" or op.name() == "quant_convolution")
         {
             // Adjust symetrical padding
             if(v.at("padding").size() == v.at("stride").size())

@@ -498,31 +514,53 @@ struct mlir_program
         return ins->get_shape();
     }
 
+    static std::string get_symbol_name(const module& m)
+    {
+        for(auto ins : iterator_for(m))
+        {
+            if(ins->name() == "convolution" or ins->name() == "dot")
+            {
+                return "mlir_" + ins->name();
+            }
+        }
+        return "main";
+    }
+
     void parse(const module& m)
     {
+        sym_name   = get_symbol_name(m);
         auto mbody = mlirModuleGetBody(mmodule.get());
         std::unordered_map<instruction_ref, MlirValue> ins_map;
         auto fbody = insert(mbody, m, ins_map);
         for(auto ins : iterator_for(m))
        {
             if(ins->name() == "@param")
                 continue;
+            if(ins->name() == "contiguous")
+            {
+                ins_map[ins] = ins_map[ins->inputs().at(0)];
+                continue;
+            }
             auto name = get_name(ins);
             auto ops  = create_operation_state(name);
             ops.add_attribute_value(get_operator_value(ins->get_operator()));
             if(ins->name() != "@return")
                 ops.add_results({get_shape(ins)});
-            if(ins->name() == "convolution")
+            if(ins->name() == "@literal")
+            {
+                literal r            = ins->get_literal();
+                MlirType tensor_type = make_tensor(ins->get_shape());
+                MlirAttribute mlir_value_attr =
+                    mlirDenseElementsAttrRawBufferGet(tensor_type, r.get_shape().bytes(), r.data());
+                ops.add_attributes({{"value", mlir_value_attr}});
+            }
+            if(ins->name() == "convolution" or ins->name() == "dot")
             {
                 pp =
                     problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()};
                 // check if HW supports xdlops
-                auto target_chip  = trim(split_string(target_arch, ':').front());
-                bool xdlops       = contains(get_xdlops_archs(), target_chip);
-                std::string tuned = get_tune_params(xdlops);
-                if(not tuned.empty())
-                    ops.add_attributes({{"perf_config", tuned}});
-                if(xdlops)
+                if(has_xdlops(target_arch))
                     ops.add_attributes({{"xdlopsV2", true}});
             }

@@ -542,15 +580,19 @@ struct mlir_program
     code_object_op compile() MIGRAPHX_TIDY_CONST
     {
-        mlir_pass_manager pm{mlirPassManagerCreate(ctx.get())};
+        mlir_pass_manager pm_front{mlirPassManagerCreate(ctx.get())};
+        mlir_pass_manager pm_back{mlirPassManagerCreate(ctx.get())};
         // 1st pipeline to call
-        mlirMIGraphXAddHighLevelPipeline(pm.get());
+        mlirMIGraphXAddHighLevelPipeline(pm_front.get());
+        mlirPassManagerRun(pm_front.get(), mmodule.get());
         // 2nd pipeline to call
-        mlirMIGraphXAddBackendPipeline(pm.get(), target_arch.c_str());
-        mlirPassManagerRun(pm.get(), mmodule.get());
+        get_module_tuned();
+        mlirMIGraphXAddBackendPipeline(pm_back.get(), target_arch.c_str());
+        mlirPassManagerRun(pm_back.get(), mmodule.get());
 
         code_object_op op{};
-        op.symbol_name                = "main";
+        op.symbol_name                = sym_name;
         op.code_object                = get_binary();
         std::tie(op.global, op.local) = get_launch_params();
         return op;

@@ -578,7 +620,74 @@ struct mlir_program
         MIGRAPHX_THROW("Failed to compile mlir program");
     }
 
-    std::string get_tune_params(bool xdlops) { return get_mlir_perf_for_conv(pp, xdlops); }
+    std::string get_tune_params(bool xdlops) const { return get_mlir_perf_for_conv(pp, xdlops); }
+
+    // This function appends to tuning cfg file that could be
+    // used with rocMLIR tuning scripts.
+    void dump_tuning_cfg(const char* prob_config) const
+    {
+        std::string tuning_cfg_path = string_value_of(MIGRAPHX_MLIR_TUNING_CFG{});
+        if(!tuning_cfg_path.empty())
+        {
+            std::vector<std::string> tokens = split_string(prob_config, '\t');
+            std::string prob                = tokens[1];
+            if(starts_with(prob, "conv"))
+            {
+                tuning_cfg_path += ".conv";
+            }
+            else
+            {
+                tuning_cfg_path += ".gemm";
+            }
+            std::ofstream tuning_cfg(tuning_cfg_path, std::ios::app);
+            tuning_cfg << prob << std::endl;
+        }
+    }
+
+    static mlir_tuning_table create_tuning_table()
+    {
+        mlir_tuning_table tuning_table{mlirRockTuningTableCreate()};
+        std::string tuning_db_path = string_value_of(MIGRAPHX_MLIR_TUNING_DB{});
+        if(!tuning_db_path.empty())
+        {
+            std::ifstream tuning_db_tsv(tuning_db_path);
+            if(tuning_db_tsv)
+            {
+                std::string line;
+                while(std::getline(tuning_db_tsv, line))
+                {
+                    std::vector<std::string> tokens = split_string(line, '\t');
+                    std::string arch                = tokens[0];
+                    std::string prob                = tokens[1];
+                    std::string perf                = tokens[2];
+                    std::string key                 = arch.append("\t").append(prob);
+                    mlirRockTuningUpdateTable(
+                        tuning_table.get(), key.c_str(), perf.c_str(), 1.0);
+                }
+            }
+        }
+        else
+        {
+            std::cerr << "WARNING: MLIR tuning db not found. Please set MIGRAPHX_MLIR_TUNING_DB for "
                         "optimal performance."
+                      << std::endl;
+        }
+        return tuning_table;
+    }
+
+    bool get_module_tuned() const
+    {
+        static mlir_tuning_table tuning_table = create_tuning_table();
+        if(!mlirRockTuningSetFromTable(tuning_table.get(), mmodule.get()))
+        {
+            const char* prob_config = mlirRockTuningGetKey(tuning_table.get(), mmodule.get());
+            std::stringstream key(prob_config);
+            std::cerr << "fails to set param on" << prob_config << std::endl;
+            dump_tuning_cfg(prob_config);
+            return false;
+        }
+        return true;
+    }
 
     mlir_context ctx;
     MlirLocation location;

@@ -586,6 +695,7 @@ struct mlir_program
     problem_params pp;
     std::deque<std::string> strings{};
     std::string target_arch;
+    std::string sym_name;
 };
 
 std::string dump_mlir(const module& m)

@@ -645,12 +755,13 @@ code_object_op compile_mlir(const context&, module m, const std::vector<instruct
 {
     adjust_param_shapes(m, inputs);
     const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
-    if(trace)
-        std::cout << m << std::endl;
 
     // set mutex while llvm thread support is disabled.
     static std::mutex g_mlirc_mutex; // NOLINT
     const std::lock_guard<std::mutex> lock(g_mlirc_mutex);
+    if(trace)
+        std::cout << m << std::endl;
+
     mlir_program mp;
     mp.find_target();
     mp.parse(m);
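has_xdlops replaces the hard-coded {gfx908, gfx90a} set with a prefix-plus-ordering test on the device name. A standalone sketch of the same logic using only std::string (trim, split_string, and starts_with are MIGraphX helpers, swapped here for plain std::string operations):

#include <cassert>
#include <string>

// Strip the feature suffix after ':' and accept gfx9 devices at or above
// gfx908, comparing the names lexicographically.
bool has_xdlops(const std::string& target_arch)
{
    const auto device_name = target_arch.substr(0, target_arch.find(':'));
    return device_name.compare(0, 4, "gfx9") == 0 and device_name >= "gfx908";
}

int main()
{
    assert(has_xdlops("gfx90a:sramecc+:xnack-")); // feature suffix is ignored
    assert(has_xdlops("gfx908"));
    assert(not has_xdlops("gfx906"));  // below gfx908
    assert(not has_xdlops("gfx1030")); // not a gfx9 device
}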
src/targets/gpu/rocblas.cpp

@@ -47,32 +47,17 @@ rocblas_handle_ptr create_rocblas_handle_ptr(hipStream_t s)
     return rb;
 }
 
-const std::unordered_set<std::string>& get_rocblas_fp32_archs()
-{
-    static std::unordered_set<std::string> supported_archs{"gfx908", "gfx90a"};
-    return supported_archs;
-}
-
 bool get_compute_fp32_flag()
 {
-    bool compute_fp32 = false;
-#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
     const auto device_name = trim(split_string(get_device_name(), ':').front());
-    if(contains(get_rocblas_fp32_archs(), device_name))
-        compute_fp32 = true;
-#endif
-    return compute_fp32;
+    return (starts_with(device_name, "gfx9") and device_name >= "gfx908");
 }
 
 bool get_int8_x4_format(context& ctx)
 {
-    bool int8_x4_format = true;
-#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
     rocblas_gemm_flags flag;
     rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
-    int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
-#endif
-    return int8_x4_format;
+    return flag == rocblas_gemm_flags_pack_int8x4;
 }
 
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
src/targets/gpu/target.cpp

@@ -26,7 +26,6 @@
 #include <migraphx/check_context.hpp>
 #include <migraphx/dead_code_elimination.hpp>
 #include <migraphx/eliminate_allocation.hpp>
-#include <migraphx/eliminate_common_subexpression.hpp>
 #include <migraphx/eliminate_concat.hpp>
 #include <migraphx/eliminate_contiguous.hpp>
 #include <migraphx/eliminate_data_type.hpp>

@@ -34,6 +33,7 @@
 #include <migraphx/eliminate_layout.hpp>
 #include <migraphx/eliminate_pad.hpp>
 #include <migraphx/fuse_pointwise.hpp>
+#include <migraphx/fuse_reduce.hpp>
 #include <migraphx/inline_module.hpp>
 #include <migraphx/insert_pad.hpp>
 #include <migraphx/layout_nhwc.hpp>

@@ -41,7 +41,7 @@
 #include <migraphx/normalize_ops.hpp>
 #include <migraphx/optimize_module.hpp>
 #include <migraphx/preallocate_param.hpp>
-#include <migraphx/propagate_constant.hpp>
+#include <migraphx/promote_literals.hpp>
 #include <migraphx/register_target.hpp>
 #include <migraphx/replace_allocate.hpp>
 #include <migraphx/rewrite_gelu.hpp>

@@ -49,15 +49,16 @@
 #include <migraphx/rewrite_quantization.hpp>
 #include <migraphx/rewrite_rnn.hpp>
 #include <migraphx/schedule.hpp>
-#include <migraphx/simplify_algebra.hpp>
 #include <migraphx/simplify_qdq.hpp>
 #include <migraphx/simplify_reshapes.hpp>
+#include <migraphx/split_single_dyn_dim.hpp>
 #include <migraphx/gpu/allocation_model.hpp>
 #include <migraphx/gpu/compile_miopen.hpp>
 #include <migraphx/gpu/compile_ops.hpp>
 #include <migraphx/gpu/concat_gpu_opt.hpp>
 #include <migraphx/gpu/context.hpp>
 #include <migraphx/gpu/device_name.hpp>
+#include <migraphx/gpu/fuse_ck.hpp>
 #include <migraphx/gpu/fuse_mlir.hpp>
 #include <migraphx/gpu/fuse_ops.hpp>
 #include <migraphx/gpu/prefuse_ops.hpp>

@@ -73,8 +74,11 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
-MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION)
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK)
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_FAST_GELU)
 
 struct id_pass
 {
     std::string name() const { return "id"; }

@@ -98,14 +102,17 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
     unsupported_types.erase(shape::type_t::bool_type);
     unsupported_types.erase(shape::type_t::int8_type);
     unsupported_types.erase(shape::type_t::uint8_type);
+    unsupported_types.erase(shape::type_t::int32_type);
     unsupported_types.erase(shape::type_t::tuple_type);
     // clang-format off
     return
     {
+        split_single_dyn_dim{},
+        dead_code_elimination{},
         normalize_ops{},
        dead_code_elimination{},
         simplify_qdq{},
-        rewrite_quantization{},
+        enable_pass(not mlir_enabled(), rewrite_quantization{}),
         dead_code_elimination{},
         eliminate_data_type{unsupported_types, shape::type_t::float_type},
         simplify_reshapes{},

@@ -119,20 +126,21 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         inline_module{},
         rewrite_pooling{},
         dead_code_elimination{},
-        rewrite_gelu{},
+        enable_pass(not enabled(MIGRAPHX_DISABLE_FAST_GELU{}), rewrite_gelu{}),
         optimize_module{},
         enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}),
         dead_code_elimination{},
         prefuse_ops{},
         dead_code_elimination{},
         auto_contiguous{},
-        // enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{true}),
-        simplify_reshapes{},
-        propagate_constant{},
+        optimize_module{},
+        fuse_pointwise{},
+        dead_code_elimination{},
+        enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}),
         dead_code_elimination{},
-        enable_pass(not enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
+        enable_pass(enabled(MIGRAPHX_ENABLE_CK{}), fuse_ck{}),
         dead_code_elimination{},
-        fuse_mlir{&ctx},
+        enable_pass(mlir_enabled(), fuse_mlir{&ctx}),
         dead_code_elimination{},
         lowering{&ctx, options.offload_copy},
         eliminate_contiguous{"gpu::contiguous"},

@@ -151,7 +159,9 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         dead_code_elimination{},
         adjust_allocation{gpu_allocation_model{}},
         dead_code_elimination{},
-        compile_ops{&ctx},
+        compile_ops{&ctx, options.exhaustive_tune},
+        dead_code_elimination{},
+        promote_literals{},
         dead_code_elimination{},
         write_literals{&ctx},
         schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}, not enabled(MIGRAPHX_DISABLE_SCHEDULE_PASS{})},
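The pass list now wraps optional passes in enable_pass(condition, pass) instead of listing them unconditionally (rewrite_quantization, rewrite_gelu, fuse_reduce, fuse_ck, fuse_mlir). A toy sketch of that gating pattern, where a pass is reduced to a named step and enable_pass is a hypothetical stand-in for the MIGraphX helper:

#include <iostream>
#include <string>
#include <vector>

// A "pass" reduced to a named step with an enabled flag.
struct pass
{
    std::string name;
    bool enabled = true;
};

// Wrap a pass so it only runs when the predicate holds.
pass enable_pass(bool cond, pass p)
{
    p.enabled = cond;
    return p;
}

int main()
{
    const bool mlir_enabled = false;
    std::vector<pass> passes = {
        {"simplify_qdq"},
        enable_pass(not mlir_enabled, {"rewrite_quantization"}),
        {"dead_code_elimination"},
    };
    for(const auto& p : passes)
        if(p.enabled)
            std::cout << p.name << "\n";
}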
src/targets/gpu/driver/perf.cpp → src/targets/gpu/time_op.cpp (renamed)

@@ -21,7 +21,7 @@
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
-#include <migraphx/gpu/driver/perf.hpp>
+#include <migraphx/gpu/time_op.hpp>
 #include <migraphx/context.hpp>
 #include <migraphx/generate.hpp>
 #include <migraphx/time.hpp>

@@ -30,7 +30,6 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
-namespace driver {
 
 std::vector<argument> generate_arguments(const std::vector<shape>& shapes, unsigned long seed = 0)
 {

@@ -69,7 +68,6 @@ time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n)
     return std::make_pair(host_time / n, device_time / n);
 }
 
-} // namespace driver
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
src/targets/ref/CMakeLists.txt

@@ -31,10 +31,9 @@ set_target_properties(migraphx_ref PROPERTIES EXPORT_NAME ref)
 rocm_set_soversion(migraphx_ref ${MIGRAPHX_SO_VERSION})
 
 find_path(BLAZE_INCLUDE blaze/Blaze.h)
-find_package(Threads)
 
 rocm_clang_tidy_check(migraphx_ref)
-target_link_libraries(migraphx_ref migraphx Threads::Threads)
+target_link_libraries(migraphx_ref PUBLIC migraphx)
 target_include_directories(migraphx_ref PRIVATE ${BLAZE_INCLUDE})
 target_compile_definitions(migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS)
src/targets/ref/lowering.cpp

@@ -132,109 +132,6 @@ auto visit_quantize(T&& x, Ts&&... xs)
     };
 }
 
-template <class Op>
-struct ref_convolution : auto_register_op<ref_convolution<Op>>
-{
-    ref_convolution() = default;
-
-    ref_convolution(Op pop) : op(std::move(pop)) {}
-
-    Op op;
-
-    template <class Self, class F>
-    static auto reflect(Self& self, F f)
-    {
-        return migraphx::reflect(self.op, f);
-    }
-
-    std::string name() const { return "ref::" + op.name(); }
-    shape compute_shape(const std::vector<shape>& inputs) const
-    {
-        return op.normalize_compute_shape(inputs);
-    }
-
-    argument compute(context&, shape output_shape, std::vector<argument> args) const
-    {
-        std::vector<std::size_t> padding;
-        if(op.padding_mode != op::padding_mode_t::default_)
-        {
-            auto input_lens   = args[0].get_shape().lens();
-            auto weights_lens = args[1].get_shape().lens();
-            padding           = op.padding_mode == op::same_upper
-                                    ? calc_dyn_auto_pad(
-                                          input_lens, weights_lens, op.stride, op.dilation, true)
-                                    : calc_dyn_auto_pad(
-                                          input_lens, weights_lens, op.stride, op.dilation, false);
-            output_shape = compute_padded_shape(
-                args[0].get_shape(), args[1].get_shape(), padding, op.stride, op.dilation);
-        }
-        else
-        {
-            padding = op.padding;
-            if(output_shape.dynamic())
-            {
-                output_shape =
-                    op.normalize_compute_shape({args.at(0).get_shape(), args.at(1).get_shape()});
-            }
-        }
-
-        argument result{output_shape};
-        visit_quantize(result, args[0], args[1])([&](auto output, auto input, auto weights) {
-            auto in_lens  = input.get_shape().lens();
-            auto wei_lens = weights.get_shape().lens();
-            auto wei_n    = wei_lens[0];
-            auto wei_c    = wei_lens[1];
-            std::vector<std::size_t> win_size(wei_lens.begin() + 1, wei_lens.end());
-            par_for(output_shape.elements(), [&](auto i) {
-                auto idx_o = output_shape.multi(i);
-                auto w     = idx_o[1];
-                auto n_dim = idx_o.size();
-                std::vector<std::ptrdiff_t> win_start;
-                for(std::size_t dim = 2; dim < n_dim; ++dim)
-                {
-                    auto d_2 = dim - 2;
-                    win_start.push_back(std::ptrdiff_t(idx_o[dim] * op.stride[d_2]) -
-                                        std::ptrdiff_t(padding[d_2]));
-                }
-                const auto group_id = w / (wei_n / op.group);
-                shape win_shape{output_shape.type(), win_size};
-                double acc = 0.0;
-                shape_for_each(win_shape, [&](auto idx_win) {
-                    auto k           = idx_win[0];
-                    const auto in_ch = group_id * wei_c + k;
-                    std::vector<std::ptrdiff_t> idx(idx_o.begin(), idx_o.end());
-                    idx[1] = in_ch;
-                    std::transform(idx_win.begin() + 1,
-                                   idx_win.end(),
-                                   win_start.begin(),
-                                   idx.begin() + 2,
-                                   [](std::ptrdiff_t ii, std::ptrdiff_t jj) { return ii + jj; });
-                    std::vector<std::ptrdiff_t> idx_wei(idx_o.size());
-                    idx_wei[0] = w;
-                    std::copy(idx_win.begin(), idx_win.end(), idx_wei.begin() + 1);
-                    if(std::all_of(idx.begin() + 2, idx.end(), [&](auto ii) { return ii >= 0; }) and
-                       std::equal(idx.begin(),
-                                  idx.end(),
-                                  in_lens.begin(),
-                                  in_lens.end(),
-                                  std::less<std::ptrdiff_t>{}))
-                    {
-                        acc += input(idx.begin(), idx.end()) *
-                               weights(idx_wei.begin(), idx_wei.end());
-                    }
-                });
-                output[i] = acc;
-            });
-        });
-        return result;
-    }
-};
-
 struct ref_im2col
 {
     op::im2col op;

@@ -564,11 +461,8 @@ struct ref_apply
     void init()
     {
-        apply_map["convolution"] = extend_op<ref_convolution<op::convolution>, op::convolution>();
         apply_map["dot"]       = extend_op<ref_gemm, op::dot>();
         apply_map["quant_dot"] = extend_op<ref_quant_gemm, op::quant_dot>();
-        apply_map["quant_convolution"] =
-            extend_op<ref_convolution<op::quant_convolution>, op::quant_convolution>();
         apply_map["im2col"]     = extend_op<ref_im2col, op::im2col>();
         apply_map["logsoftmax"] = extend_op<ref_softmax<op::logsoftmax>, op::logsoftmax>();
         apply_map["lrn"]        = extend_op<ref_lrn, op::lrn>();
src/tf/op_parser.cpp

@@ -46,6 +46,7 @@ std::vector<std::string> get_op_parsers()
                    op_parser_map().end(),
                    std::back_inserter(result),
                    [&](auto&& p) { return p.first; });
+    std::sort(result.begin(), result.end());
     return result;
 }
src/tf/tf.cpp

@@ -22,6 +22,7 @@
 * THE SOFTWARE.
 */
 #include <migraphx/tf/tf_parser.hpp>
+#include <migraphx/tf/op_parser.hpp>
 #include <iostream>
 #include <fstream>
 #include <unordered_map>

@@ -62,5 +63,7 @@ program parse_tf(const std::string& name, const tf_options& options)
     return std::move(parser.prog);
 }
 
+std::vector<std::string> get_tf_operators() { return tf::get_op_parsers(); }
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
src/value.cpp

@@ -28,6 +28,7 @@
 #include <migraphx/stringutils.hpp>
 #include <migraphx/value.hpp>
 #include <migraphx/optional.hpp>
+#include <migraphx/hash.hpp>
 #include <unordered_map>
 #include <utility>

@@ -519,6 +520,38 @@ std::ostream& operator<<(std::ostream& os, const value& d)
     return os;
 }
 
+template <class T>
+std::size_t value_hash(const std::string& key, const T& x)
+{
+    std::size_t h = hash_value(key);
+    hash_combine(h, x);
+    return h;
+}
+
+std::size_t value_hash(const std::string& key, std::nullptr_t) { return hash_value(key); }
+
+std::size_t value_hash(const std::string& key, const std::vector<value>& x)
+{
+    std::size_t h = hash_value(key);
+    for(const auto& v : x)
+        hash_combine(h, v);
+    return h;
+}
+
+std::size_t value_hash(const std::string& key, const value::binary& x)
+{
+    std::size_t h = hash_value(key);
+    for(const auto& v : x)
+        hash_combine(h, v);
+    return h;
+}
+
+std::size_t value::hash() const
+{
+    std::size_t h = 0;
+    this->visit_value([&](const auto& a) { h = value_hash(this->get_key(), a); });
+    return h;
+}
+
 void value::debug_print(bool show_type) const
 {
     if(show_type)
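Each value_hash overload seeds with the key's hash and folds the payload in via hash_combine. A minimal Boost-style hash_combine to show the recipe (MIGraphX's own hash.hpp may differ; this is an illustration):

#include <cstddef>
#include <functional>
#include <iostream>
#include <string>

// Boost-style combiner: mix the element's hash into the running seed.
template <class T>
void hash_combine(std::size_t& seed, const T& v)
{
    seed ^= std::hash<T>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

int main()
{
    std::size_t h = std::hash<std::string>{}("padding"); // seed from the key
    for(int v : {1, 1, 2, 2})
        hash_combine(h, v); // fold in each element, order-sensitive
    std::cout << h << "\n";
}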
test/CMakeLists.txt

@@ -110,7 +110,7 @@ function(add_test_executable TEST_NAME)
     add_test_command(${TEST_NAME} ${TEST_COMMAND})
     add_dependencies(tests ${TEST_NAME})
     add_dependencies(check ${TEST_NAME})
-    target_link_libraries(${TEST_NAME} migraphx migraphx_onnx)
+    target_link_libraries(${TEST_NAME} migraphx migraphx_onnx migraphx_ref)
     target_include_directories(${TEST_NAME} PUBLIC include)
 endfunction(add_test_executable)

@@ -134,6 +134,9 @@ if(MIGRAPHX_ENABLE_GPU)
             COST 10
             RESOURCE_LOCK gpu
         )
+        if(MIGRAPHX_USE_HIPRTC)
+            target_compile_definitions(test_gpu_${BASE_NAME} PUBLIC -DMIGRAPHX_USE_HIPRTC)
+        endif()
         target_link_libraries(test_gpu_${BASE_NAME} migraphx_gpu migraphx_kernels)
     endforeach()
 endif()

@@ -163,7 +166,7 @@ foreach(ONNX_TEST ${ONNX_TESTS})
     set(TEST_NAME test_${BASE_NAME})
     add_executable(${TEST_NAME} ${ONNX_TEST})
     rocm_clang_tidy_check(${TEST_NAME})
-    target_link_libraries(${TEST_NAME} migraphx_onnx)
+    target_link_libraries(${TEST_NAME} migraphx_onnx migraphx_ref)
     target_include_directories(${TEST_NAME} PUBLIC include)
     add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}> WORKING_DIRECTORY ${TEST_ONNX_DIR})
     add_dependencies(tests ${TEST_NAME})

@@ -187,6 +190,25 @@ if(MIGRAPHX_ENABLE_PYTHON)
     add_subdirectory(py)
 endif()
 
+# multitarget test
+if(MIGRAPHX_ENABLE_GPU AND MIGRAPHX_ENABLE_CPU AND MIGRAPHX_ENABLE_FPGA)
+    set(TEST_MULTI_TARGET_DIR ${CMAKE_CURRENT_SOURCE_DIR}/multi_target)
+    file(GLOB MULTI_TARGET_TESTS ${CONFIGURE_DEPENDS} ${TEST_MULTI_TARGET_DIR}/*.cpp)
+    foreach(MULTI_TARGET_TEST ${MULTI_TARGET_TESTS})
+        get_filename_component(BASE_NAME ${MULTI_TARGET_TEST} NAME_WE)
+        set(TEST_NAME test_${BASE_NAME})
+        add_executable(${TEST_NAME} ${MULTI_TARGET_TEST})
+        rocm_clang_tidy_check(${TEST_NAME})
+        target_link_libraries(${TEST_NAME} migraphx migraphx_onnx migraphx_tf migraphx_all_targets)
+        target_include_directories(${TEST_NAME} PUBLIC include)
+        add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}> WORKING_DIRECTORY ${TEST_MULTI_TARGET_DIR})
+        add_dependencies(tests ${TEST_NAME})
+        add_dependencies(check ${TEST_NAME})
+    endforeach()
+endif()
+
 function(test_header NAME HEADER)
     file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/header-main-include-${NAME}.cpp
         "#include <${HEADER}>\nint main() {}\n"

@@ -218,3 +240,11 @@ test_headers(migraphx/ref ${CMAKE_SOURCE_DIR}/src/targets/ref/include/migraphx/r
 if(MIGRAPHX_ENABLE_GPU)
     test_headers(migraphx/gpu ${CMAKE_SOURCE_DIR}/src/targets/gpu/include/migraphx/gpu/*.hpp)
 endif()
+
+if(MIGRAPHX_ENABLE_CPU)
+    test_headers(migraphx/cpu ${CMAKE_SOURCE_DIR}/src/targets/cpu/include/migraphx/cpu/*.hpp)
+endif()
+
+if(MIGRAPHX_ENABLE_FPGA)
+    test_headers(migraphx/fpga ${CMAKE_SOURCE_DIR}/src/targets/fpga/include/migraphx/fpga/*.hpp)
+endif()