Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
40fbef9b
"sgl-kernel/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "af6535e7aaf5c1e9352149f0edfde37d977cd473"
Unverified
Commit
40fbef9b
authored
Aug 05, 2023
by
Ted Themistokleous
Committed by
GitHub
Aug 05, 2023
Browse files
Merge branch 'develop' into threaded_nms
parents
d164b151
aeb9f78c
Changes
440
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
251 additions
and
121 deletions
+251
-121
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
+0
-4
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
+8
-0
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
+5
-15
src/targets/gpu/kernels/include/migraphx/kernels/print.hpp
src/targets/gpu/kernels/include/migraphx/kernels/print.hpp
+2
-2
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+8
-7
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
...gets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
+11
-3
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
+1
-1
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+18
-12
src/targets/gpu/mlir.cpp
src/targets/gpu/mlir.cpp
+157
-39
src/targets/gpu/rocblas.cpp
src/targets/gpu/rocblas.cpp
+8
-16
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+17
-8
src/targets/gpu/time_op.cpp
src/targets/gpu/time_op.cpp
+1
-3
src/targets/ref/CMakeLists.txt
src/targets/ref/CMakeLists.txt
+2
-0
src/targets/ref/include/migraphx/ref/context.hpp
src/targets/ref/include/migraphx/ref/context.hpp
+1
-0
src/targets/ref/include/migraphx/ref/lowering.hpp
src/targets/ref/include/migraphx/ref/lowering.hpp
+2
-2
src/targets/ref/include/migraphx/ref/target.hpp
src/targets/ref/include/migraphx/ref/target.hpp
+1
-1
src/targets/ref/lowering.cpp
src/targets/ref/lowering.cpp
+1
-1
src/tf/CMakeLists.txt
src/tf/CMakeLists.txt
+2
-1
src/tf/op_parser.cpp
src/tf/op_parser.cpp
+1
-0
src/tf/parse_batchnorm.cpp
src/tf/parse_batchnorm.cpp
+5
-6
No files found.
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
View file @
40fbef9b
...
@@ -28,10 +28,6 @@
...
@@ -28,10 +28,6 @@
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/hip_fp16.h>
#include <hip/math_functions.h>
#include <hip/math_functions.h>
#include <hip/hip_math_constants.h>
#elif defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS)
#include <hip/hip_common.h>
#include <hip/hip_math_constants.h>
#endif
#endif
#endif // MIGRAPHX_GUARD_KERNELS_HIP_HPP
#endif // MIGRAPHX_GUARD_KERNELS_HIP_HPP
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
View file @
40fbef9b
...
@@ -130,6 +130,8 @@ struct index
...
@@ -130,6 +130,8 @@ struct index
return
blockDim
.
x
;
return
blockDim
.
x
;
}
}
#endif
#endif
constexpr
auto
ngroup
()
const
{
return
nglobal
()
/
max_nlocal
();
}
template
<
class
N
,
class
Stride
>
template
<
class
N
,
class
Stride
>
static
constexpr
auto
max_stride_iterations
(
N
n
,
Stride
stride
)
static
constexpr
auto
max_stride_iterations
(
N
n
,
Stride
stride
)
{
{
...
@@ -231,6 +233,12 @@ struct index
...
@@ -231,6 +233,12 @@ struct index
{
{
for_stride
<
true
>
(
local
,
n
,
nlocal
(),
f
);
for_stride
<
true
>
(
local
,
n
,
nlocal
(),
f
);
}
}
template
<
class
F
,
class
N
>
__device__
void
group_stride
(
N
n
,
F
f
)
const
{
for_stride
<
false
>
(
group
,
n
,
ngroup
(),
f
);
}
};
};
#ifdef MIGRAPHX_NLOCAL
#ifdef MIGRAPHX_NLOCAL
...
...
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
View file @
40fbef9b
...
@@ -138,7 +138,7 @@ MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, floor, ::hfloor)
...
@@ -138,7 +138,7 @@ MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, floor, ::hfloor)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
isnan
,
::
__hisnan
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
isnan
,
::
__hisnan
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
log
,
::
hlog
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
log
,
::
hlog
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
rsqrt
,
::
hrsqrt
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
rsqrt
,
::
hrsqrt
)
//
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sin, ::hsin)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
sin
,
::
hsin
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
sqrt
,
::
hsqrt
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
sqrt
,
::
hsqrt
)
// Use float to compute half overload
// Use float to compute half overload
...
@@ -161,8 +161,7 @@ MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod)
...
@@ -161,8 +161,7 @@ MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod)
// Map math functions to hip half2 functions
// Map math functions to hip half2 functions
// The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats
// The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats
// packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names
// packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names
// Most but not all of these math ops have operators of the same names. Ones not yet implemented
// Most but not all of these math ops have operators of the same names.
// at this time are: exp2, exp10, log2, log10, isinf
MIGRAPHX_DEVICE_MATH_HALF2
(
abs
,
::
__habs2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
abs
,
::
__habs2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
ceil
,
::
h2ceil
)
MIGRAPHX_DEVICE_MATH_HALF2
(
ceil
,
::
h2ceil
)
MIGRAPHX_DEVICE_MATH_HALF2
(
cos
,
::
h2cos
)
MIGRAPHX_DEVICE_MATH_HALF2
(
cos
,
::
h2cos
)
...
@@ -176,7 +175,7 @@ MIGRAPHX_DEVICE_MATH_HALF2(log, ::h2log)
...
@@ -176,7 +175,7 @@ MIGRAPHX_DEVICE_MATH_HALF2(log, ::h2log)
MIGRAPHX_DEVICE_MATH_HALF2
(
log10
,
::
h2log10
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log10
,
::
h2log10
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log2
,
::
h2log2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log2
,
::
h2log2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
rsqrt
,
::
h2rsqrt
)
MIGRAPHX_DEVICE_MATH_HALF2
(
rsqrt
,
::
h2rsqrt
)
//
MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin)
MIGRAPHX_DEVICE_MATH_HALF2
(
sin
,
::
h2sin
)
MIGRAPHX_DEVICE_MATH_HALF2
(
sqrt
,
::
h2sqrt
)
MIGRAPHX_DEVICE_MATH_HALF2
(
sqrt
,
::
h2sqrt
)
template
<
class
T
,
class
U
>
template
<
class
T
,
class
U
>
...
@@ -189,9 +188,8 @@ MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, max, ::max)
...
@@ -189,9 +188,8 @@ MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, max, ::max)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
float
,
min
,
::
min
)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
float
,
min
,
::
min
)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
double
,
max
,
::
max
)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
double
,
max
,
::
max
)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
double
,
min
,
::
min
)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
double
,
min
,
::
min
)
// Add overloads for half that calls the float version
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
migraphx
::
half
,
max
,
::
__hmax
)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
migraphx
::
half
,
max
,
::
fmaxf
)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
migraphx
::
half
,
min
,
::
__hmin
)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
migraphx
::
half
,
min
,
::
fminf
)
template
<
class
T
,
MIGRAPHX_REQUIRES
(
not
is_any_vec
<
T
>())
>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
not
is_any_vec
<
T
>())
>
constexpr
auto
max
(
const
T
&
a
,
const
T
&
b
)
constexpr
auto
max
(
const
T
&
a
,
const
T
&
b
)
...
@@ -217,14 +215,6 @@ constexpr auto min(const T& a, const U& b)
...
@@ -217,14 +215,6 @@ constexpr auto min(const T& a, const U& b)
return
min
<
common_type_t
<
T
,
U
>>
(
a
,
b
);
return
min
<
common_type_t
<
T
,
U
>>
(
a
,
b
);
}
}
// Sin for half is broken on hip, so use cos instead
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_same
<
vec_type
<
T
>,
half
>
{})
>
constexpr
T
sin
(
T
x
)
{
constexpr
const
T
shift
=
HIP_PIO2_F
;
return
migraphx
::
cos
(
shift
-
x
);
}
MIGRAPHX_DEVICE_MATH_VEC
(
abs
)
MIGRAPHX_DEVICE_MATH_VEC
(
abs
)
MIGRAPHX_DEVICE_MATH_VEC
(
acos
)
MIGRAPHX_DEVICE_MATH_VEC
(
acos
)
MIGRAPHX_DEVICE_MATH_VEC
(
acosh
)
MIGRAPHX_DEVICE_MATH_VEC
(
acosh
)
...
...
src/targets/gpu/kernels/include/migraphx/kernels/print.hpp
View file @
40fbef9b
...
@@ -244,13 +244,13 @@ __device__ void print_once(Ts... xs)
...
@@ -244,13 +244,13 @@ __device__ void print_once(Ts... xs)
template
<
class
...
Ts
>
template
<
class
...
Ts
>
__device__
void
println
(
Ts
...
xs
)
__device__
void
println
(
Ts
...
xs
)
{
{
print_each
(
&
cout
ln
,
xs
...);
print_each
(
&
cout
,
xs
...
,
'\n'
);
}
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
__device__
void
println_once
(
Ts
...
xs
)
__device__
void
println_once
(
Ts
...
xs
)
{
{
print_each_once
(
&
cout
ln
,
xs
...);
print_each_once
(
&
cout
,
xs
...
,
'\n'
);
}
}
}
// namespace migraphx
}
// namespace migraphx
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
40fbef9b
...
@@ -79,20 +79,21 @@ __device__ void dpp_reduce(T& in, Op op)
...
@@ -79,20 +79,21 @@ __device__ void dpp_reduce(T& in, Op op)
#endif
#endif
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE(op, prefix
)
\
#define MIGRAPHX_DPP_REDUCE(op, prefix
, sign)
\
__device__ inline void dpp_reduce(double& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64); } \
__device__ inline void dpp_reduce(double& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64); } \
__device__ inline void dpp_reduce(float& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32); } \
__device__ inline void dpp_reduce(float& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32); } \
__device__ inline void dpp_reduce(half& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16); } \
__device__ inline void dpp_reduce(half& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16); } \
__device__ inline void dpp_reduce(int32_t& x, op) \
__device__ inline void dpp_reduce(int32_t& x, op) \
{ \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##
_u32);
\
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##
sign##32);
\
} \
} \
__device__ inline void dpp_reduce(uint32_t& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); }
__device__ inline void dpp_reduce(uint32_t& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); }
MIGRAPHX_DPP_REDUCE
(
op
::
sum
,
v_add
)
// Note: when max and min are in int32_t, signed version of instruction needs to be used.
MIGRAPHX_DPP_REDUCE
(
op
::
max
,
v_max
)
MIGRAPHX_DPP_REDUCE
(
op
::
sum
,
v_add
,
_u
)
MIGRAPHX_DPP_REDUCE
(
op
::
min
,
v_min
)
MIGRAPHX_DPP_REDUCE
(
op
::
product
,
v_mul
,
_u
)
MIGRAPHX_DPP_REDUCE
(
op
::
product
,
v_mul
)
MIGRAPHX_DPP_REDUCE
(
op
::
max
,
v_max
,
_i
)
MIGRAPHX_DPP_REDUCE
(
op
::
min
,
v_min
,
_i
)
template
<
class
Op
,
class
T
,
class
Index
,
class
F
>
template
<
class
Op
,
class
T
,
class
Index
,
class
F
>
__device__
auto
block_reduce
(
index
idx
,
Op
op
,
T
init
,
Index
n
,
F
f
)
__device__
auto
block_reduce
(
index
idx
,
Op
op
,
T
init
,
Index
n
,
F
f
)
...
@@ -570,7 +571,7 @@ template <class Algo, class Reduced, class Output, class F>
...
@@ -570,7 +571,7 @@ template <class Algo, class Reduced, class Output, class F>
__device__
void
fused_reduce
(
Output
output
,
F
f
)
__device__
void
fused_reduce
(
Output
output
,
F
f
)
{
{
Algo
::
template
run
<
Reduced
>([
&
](
auto
out_idx
,
auto
r
)
{
Algo
::
template
run
<
Reduced
>([
&
](
auto
out_idx
,
auto
r
)
{
auto
result
=
f
(
r
);
auto
result
=
f
(
r
,
out_idx
);
if
constexpr
(
reduce
::
is_inner_storage
<
decltype
(
result
)
>
{})
if
constexpr
(
reduce
::
is_inner_storage
<
decltype
(
result
)
>
{})
{
{
r
.
inner
([
&
](
auto
&
y
,
auto
x
)
{
y
=
x
;
})(
output
,
result
);
r
.
inner
([
&
](
auto
&
y
,
auto
x
)
{
y
=
x
;
})(
output
,
result
);
...
...
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
View file @
40fbef9b
...
@@ -218,7 +218,15 @@ using common_type_t = typename common_type<Ts...>::type;
...
@@ -218,7 +218,15 @@ using common_type_t = typename common_type<Ts...>::type;
#define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>
#define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>
constexpr
unsigned
long
int_max
(
unsigned
long
n
)
{
return
(
1u
<<
(
n
*
8
))
-
1
;
}
constexpr
unsigned
long
int_max
(
unsigned
long
n
)
{
// Note, left shift cannot be used to get the maximum value of int64_type or
// uint64_type because it is undefined behavior to left shift 64 bits for
// these types
if
(
n
==
sizeof
(
int64_t
))
return
-
1
;
return
(
1ul
<<
(
n
*
8
))
-
1
;
}
template
<
class
T
,
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_integral
<
T
>{}
or
is_floating_point
<
T
>
{}
or
MIGRAPHX_REQUIRES
(
is_integral
<
T
>{}
or
is_floating_point
<
T
>
{}
or
...
@@ -228,9 +236,9 @@ constexpr T numeric_max()
...
@@ -228,9 +236,9 @@ constexpr T numeric_max()
if
constexpr
(
is_integral
<
T
>
{})
if
constexpr
(
is_integral
<
T
>
{})
{
{
if
constexpr
(
is_unsigned
<
T
>
{})
if
constexpr
(
is_unsigned
<
T
>
{})
return
int_max
(
sizeof
(
T
))
*
2
;
else
return
int_max
(
sizeof
(
T
));
return
int_max
(
sizeof
(
T
));
else
return
int_max
(
sizeof
(
T
))
/
2
;
}
}
else
if
constexpr
(
is_same
<
T
,
double
>
{})
else
if
constexpr
(
is_same
<
T
,
double
>
{})
return
__DBL_MAX__
;
return
__DBL_MAX__
;
...
...
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
View file @
40fbef9b
...
@@ -135,7 +135,7 @@ constexpr vec<vec_type<T>, N> vec_packed_at(T x, I i)
...
@@ -135,7 +135,7 @@ constexpr vec<vec_type<T>, N> vec_packed_at(T x, I i)
return
vec
<
T
,
N
>
{
x
};
return
vec
<
T
,
N
>
{
x
};
else
else
{
{
MIGRAPHX_ASSERT
((
i
+
N
)
<
vec_size
<
T
>
());
MIGRAPHX_ASSERT
((
i
+
N
)
<
=
vec_size
<
T
>
());
vec
<
vec_type
<
T
>
,
N
>
result
=
{
0
};
vec
<
vec_type
<
T
>
,
N
>
result
=
{
0
};
for
(
int
j
=
0
;
j
<
N
;
j
++
)
for
(
int
j
=
0
;
j
<
N
;
j
++
)
{
{
...
...
src/targets/gpu/lowering.cpp
View file @
40fbef9b
...
@@ -22,12 +22,19 @@
...
@@ -22,12 +22,19 @@
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#include <iterator>
#include <iterator>
#include <migraphx/gpu/lowering.hpp>
#include <utility>
#include <functional>
#include <algorithm>
#include <map>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/program.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/if_op.hpp>
#include <migraphx/op/if_op.hpp>
...
@@ -35,17 +42,12 @@
...
@@ -35,17 +42,12 @@
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/program.hpp>
#include <utility>
#include <functional>
#include <algorithm>
#include <map>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -53,8 +55,9 @@ namespace gpu {
...
@@ -53,8 +55,9 @@ namespace gpu {
struct
miopen_apply
struct
miopen_apply
{
{
module
*
mod
=
nullptr
;
module
*
mod
=
nullptr
;
const
lowering
*
pass
=
nullptr
;
module_pass_manager
*
mpm
=
nullptr
;
const
lowering
*
pass
=
nullptr
;
std
::
unordered_map
<
std
::
string
,
std
::
function
<
instruction_ref
(
instruction_ref
)
>>
apply_map
{};
std
::
unordered_map
<
std
::
string
,
std
::
function
<
instruction_ref
(
instruction_ref
)
>>
apply_map
{};
instruction_ref
last
{};
instruction_ref
last
{};
bool
offload_copy
=
false
;
bool
offload_copy
=
false
;
...
@@ -83,7 +86,7 @@ struct miopen_apply
...
@@ -83,7 +86,7 @@ struct miopen_apply
auto
&
ctx
=
get_context
();
auto
&
ctx
=
get_context
();
int8_x4_format
=
get_int8_x4_format
(
ctx
);
int8_x4_format
=
get_int8_x4_format
(
ctx
);
compute_fp32
=
get_compute_fp32_flag
();
compute_fp32
=
get_compute_fp32_flag
();
offload_copy
=
(
mod
->
name
()
==
"main"
)
?
pass
->
offload_copy
:
false
;
offload_copy
=
(
mod
==
mpm
->
get_root_module
()
)
?
pass
->
offload_copy
:
false
;
add_generic_op
(
"contiguous"
);
add_generic_op
(
"contiguous"
);
...
@@ -103,7 +106,7 @@ struct miopen_apply
...
@@ -103,7 +106,7 @@ struct miopen_apply
add_extend_op
(
"topk"
);
add_extend_op
(
"topk"
);
add_convolution_op
(
"convolution"
);
add_convolution_op
(
"convolution"
);
add_convolution_op
(
"
de
convolution"
);
add_convolution_op
(
"convolution
_backwards
"
);
add_convolution_op
(
"quant_convolution"
);
add_convolution_op
(
"quant_convolution"
);
add_gemm_op
<
op
::
dot
>
(
"dot"
);
add_gemm_op
<
op
::
dot
>
(
"dot"
);
add_gemm_op
<
op
::
quant_dot
>
(
"quant_dot"
);
add_gemm_op
<
op
::
quant_dot
>
(
"quant_dot"
);
...
@@ -375,7 +378,10 @@ struct miopen_apply
...
@@ -375,7 +378,10 @@ struct miopen_apply
}
}
};
};
void
lowering
::
apply
(
module
&
m
)
const
{
miopen_apply
{
&
m
,
this
}.
apply
();
}
void
lowering
::
apply
(
module_pass_manager
&
mpm
)
const
{
miopen_apply
{
&
mpm
.
get_module
(),
&
mpm
,
this
}.
apply
();
}
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/mlir.cpp
View file @
40fbef9b
...
@@ -52,6 +52,7 @@
...
@@ -52,6 +52,7 @@
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/perfdb.hpp>
#include <migraphx/gpu/perfdb.hpp>
#include <migraphx/gpu/tuning_config.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/permutation.hpp>
#include <deque>
#include <deque>
...
@@ -121,7 +122,10 @@ struct mlir_handle
...
@@ -121,7 +122,10 @@ struct mlir_handle
#define MIGRAPHX_MANAGE_MLIR_HANDLE(T, F) migraphx::gpu::mlir_handle<T, decltype(&F), &F> // NOLINT
#define MIGRAPHX_MANAGE_MLIR_HANDLE(T, F) migraphx::gpu::mlir_handle<T, decltype(&F), &F> // NOLINT
using
mlir_context
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirContext
,
mlirContextDestroy
);
using
mlir_context
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirContext
,
mlirContextDestroy
);
using
mlir_thread_pool
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirLlvmThreadPool
,
mlirLlvmThreadPoolDestroy
);
using
mlir_dialect_registry
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirDialectRegistry
,
mlirDialectRegistryDestroy
);
using
mlir_module
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirModule
,
mlirModuleDestroy
);
using
mlir_module
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirModule
,
mlirModuleDestroy
);
using
mlir_operation
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirOperation
,
mlirOperationDestroy
);
using
mlir_operation
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirOperation
,
mlirOperationDestroy
);
using
mlir_op_printing_flags
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirOpPrintingFlags
,
using
mlir_op_printing_flags
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirOpPrintingFlags
,
...
@@ -131,6 +135,10 @@ using mlir_block = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirBlock, mlirBlockD
...
@@ -131,6 +135,10 @@ using mlir_block = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirBlock, mlirBlockD
using
mlir_pass_manager
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirPassManager
,
mlirPassManagerDestroy
);
using
mlir_pass_manager
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirPassManager
,
mlirPassManagerDestroy
);
using
mlir_tuning_table
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirRockTuningTable
,
using
mlir_tuning_table
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirRockTuningTable
,
mlirRockTuningTableDestroy
);
mlirRockTuningTableDestroy
);
using
mlir_tuning_space
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirRockTuningSpace
,
mlirRockTuningSpaceDestroy
);
using
mlir_tuning_param
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirRockTuningParam
,
mlirRockTuningParamDestroy
);
std
::
string_view
to_string_view
(
MlirStringRef
s
)
{
return
{
s
.
data
,
s
.
length
};
}
std
::
string_view
to_string_view
(
MlirStringRef
s
)
{
return
{
s
.
data
,
s
.
length
};
}
...
@@ -164,25 +172,47 @@ std::string mlir_print(F f, T x)
...
@@ -164,25 +172,47 @@ std::string mlir_print(F f, T x)
return
ss
.
str
();
return
ss
.
str
();
}
}
const
std
::
unordered_set
<
std
::
string
>
&
get_
xdlops_
arch
s
(
)
bool
has_xdlops
(
const
std
::
string
&
tar
get_arch
)
{
{
static
std
::
unordered_set
<
std
::
string
>
supported_archs
{
"gfx908"
,
"gfx90a"
}
;
const
auto
device_name
=
trim
(
split_string
(
target_arch
,
':'
).
front
())
;
return
supported_archs
;
return
(
starts_with
(
device_name
,
"gfx9"
)
and
device_name
>=
"gfx908"
)
;
}
}
struct
mlir_program
struct
mlir_program
{
{
mlir_program
()
mlir_program
()
:
ctx
(
mlirContextCreate
()),
:
ctx
(
mlirContextCreateWithRegistry
(
get_dialect_registry
().
get
(),
/*threadingEnable=*/
false
)),
location
(
mlirLocationUnknownGet
(
ctx
.
get
())),
location
(
mlirLocationUnknownGet
(
ctx
.
get
())),
mmodule
(
mlirModuleCreateEmpty
(
location
))
mmodule
(
mlirModuleCreateEmpty
(
location
))
{
{
MlirDialectRegistry
registry
=
mlirDialectRegistryCreate
();
mlirContextSetThreadPool
(
ctx
.
get
(),
get_thread_pool
().
get
());
mlirRegisterRocMLIRDialects
(
registry
);
mlirContextAppendDialectRegistry
(
ctx
.
get
(),
registry
);
mlirContextLoadAllAvailableDialects
(
ctx
.
get
());
mlirContextLoadAllAvailableDialects
(
ctx
.
get
());
mlirDialectRegistryDestroy
(
registry
);
}
mlirContextSetAllowUnregisteredDialects
(
ctx
.
get
(),
true
/*allow*/
);
static
mlir_dialect_registry
&
get_dialect_registry
()
{
static
std
::
once_flag
init_guard
;
static
mlir_dialect_registry
the_registry
;
// The MLIR registration functions (for dialects and passes) are not
// necessarily thread-safe and need to be executed exactly once
// (especially since they eventually call non-thread-safe LLVM
// initilizations).
std
::
call_once
(
init_guard
,
[
&
]()
{
the_registry
=
mlirDialectRegistryCreate
();
mlirRegisterRocMLIRDialects
(
the_registry
.
get
());
mlirRegisterRocMLIRPasses
();
});
return
the_registry
;
}
static
mlir_thread_pool
&
get_thread_pool
()
{
// To save on overhead, we create one LLVM thread pool and reuse it
// across all MLIR contexts as recommended by MLIR upstream.
// Note that this is thread-safe as of C++11.
static
mlir_thread_pool
the_pool
=
mlirLlvmThreadPoolCreate
();
return
the_pool
;
}
}
MlirType
make_type
(
shape
::
type_t
t
)
const
MlirType
make_type
(
shape
::
type_t
t
)
const
...
@@ -244,8 +274,6 @@ struct mlir_program
...
@@ -244,8 +274,6 @@ struct mlir_program
MlirAttribute
attribute
(
std
::
int64_t
i
)
const
MlirAttribute
attribute
(
std
::
int64_t
i
)
const
{
{
if
(
i
<
0
)
MIGRAPHX_THROW
(
"MLIR cant handle negative values since they are ambiguous"
);
return
mlirIntegerAttrGet
(
mlirIntegerTypeGet
(
ctx
.
get
(),
64
),
i
);
return
mlirIntegerAttrGet
(
mlirIntegerTypeGet
(
ctx
.
get
(),
64
),
i
);
}
}
MlirAttribute
attribute
(
std
::
uint64_t
i
)
const
MlirAttribute
attribute
(
std
::
uint64_t
i
)
const
...
@@ -324,7 +352,8 @@ struct mlir_program
...
@@ -324,7 +352,8 @@ struct mlir_program
std
::
string
,
std
::
string
,
value
,
value
,
std
::
vector
<
value
>
,
std
::
vector
<
value
>
,
MlirType
>
;
MlirType
,
MlirAttribute
>
;
using
named_attribute_t
=
std
::
pair
<
std
::
string_view
,
attribute_t
>
;
using
named_attribute_t
=
std
::
pair
<
std
::
string_view
,
attribute_t
>
;
MlirNamedAttribute
name_attribute
(
const
named_attribute_t
&
na
)
const
MlirNamedAttribute
name_attribute
(
const
named_attribute_t
&
na
)
const
...
@@ -365,14 +394,20 @@ struct mlir_program
...
@@ -365,14 +394,20 @@ struct mlir_program
mlir_operation_state
&
add_attributes
(
const
std
::
vector
<
named_attribute_t
>&
named_attrs
)
mlir_operation_state
&
add_attributes
(
const
std
::
vector
<
named_attribute_t
>&
named_attrs
)
{
{
auto
attributes
=
prog
->
name_attributes
(
named_attrs
);
auto
attributes
=
prog
->
name_attributes
(
named_attrs
);
mlirOperationStateAddAttributes
(
&
op_state
,
attributes
.
size
(),
attributes
.
data
());
if
(
not
attributes
.
empty
())
{
mlirOperationStateAddAttributes
(
&
op_state
,
attributes
.
size
(),
attributes
.
data
());
}
return
*
this
;
return
*
this
;
}
}
mlir_operation_state
&
add_attribute_value
(
const
value
&
v
)
mlir_operation_state
&
add_attribute_value
(
const
value
&
v
)
{
{
auto
attributes
=
prog
->
name_attributes
(
v
);
auto
attributes
=
prog
->
name_attributes
(
v
);
mlirOperationStateAddAttributes
(
&
op_state
,
attributes
.
size
(),
attributes
.
data
());
if
(
not
attributes
.
empty
())
{
mlirOperationStateAddAttributes
(
&
op_state
,
attributes
.
size
(),
attributes
.
data
());
}
return
*
this
;
return
*
this
;
}
}
...
@@ -395,13 +430,19 @@ struct mlir_program
...
@@ -395,13 +430,19 @@ struct mlir_program
return
shape
{
r
.
type
(),
r
.
lens
()};
return
shape
{
r
.
type
(),
r
.
lens
()};
});
});
auto
x
=
prog
->
make_tensors
(
reshaped
);
auto
x
=
prog
->
make_tensors
(
reshaped
);
mlirOperationStateAddResults
(
&
op_state
,
x
.
size
(),
x
.
data
());
if
(
not
x
.
empty
())
{
mlirOperationStateAddResults
(
&
op_state
,
x
.
size
(),
x
.
data
());
}
return
*
this
;
return
*
this
;
}
}
mlir_operation_state
&
add_operands
(
const
std
::
vector
<
MlirValue
>&
inputs
)
mlir_operation_state
&
add_operands
(
const
std
::
vector
<
MlirValue
>&
inputs
)
{
{
mlirOperationStateAddOperands
(
&
op_state
,
inputs
.
size
(),
inputs
.
data
());
if
(
not
inputs
.
empty
())
{
mlirOperationStateAddOperands
(
&
op_state
,
inputs
.
size
(),
inputs
.
data
());
}
return
*
this
;
return
*
this
;
}
}
...
@@ -411,7 +452,10 @@ struct mlir_program
...
@@ -411,7 +452,10 @@ struct mlir_program
std
::
transform
(
regions
.
begin
(),
regions
.
end
(),
mregions
.
begin
(),
[](
const
auto
&
r
)
{
std
::
transform
(
regions
.
begin
(),
regions
.
end
(),
mregions
.
begin
(),
[](
const
auto
&
r
)
{
return
r
.
get
();
return
r
.
get
();
});
});
mlirOperationStateAddOwnedRegions
(
&
op_state
,
mregions
.
size
(),
mregions
.
data
());
if
(
not
mregions
.
empty
())
{
mlirOperationStateAddOwnedRegions
(
&
op_state
,
mregions
.
size
(),
mregions
.
data
());
}
mlir_operation
op
(
mlirOperationCreate
(
&
op_state
));
mlir_operation
op
(
mlirOperationCreate
(
&
op_state
));
// Release memory since mlir_operation owns it
// Release memory since mlir_operation owns it
for
(
auto
&
r
:
regions
)
for
(
auto
&
r
:
regions
)
...
@@ -481,6 +525,10 @@ struct mlir_program
...
@@ -481,6 +525,10 @@ struct mlir_program
{
{
if
(
ins
->
name
()
==
"@return"
)
if
(
ins
->
name
()
==
"@return"
)
return
"func.return"
;
return
"func.return"
;
if
(
ins
->
name
()
==
"@literal"
)
{
return
"tosa.const"
;
}
return
"migraphx."
+
ins
->
name
();
return
"migraphx."
+
ins
->
name
();
}
}
...
@@ -532,19 +580,30 @@ struct mlir_program
...
@@ -532,19 +580,30 @@ struct mlir_program
{
{
if
(
ins
->
name
()
==
"@param"
)
if
(
ins
->
name
()
==
"@param"
)
continue
;
continue
;
if
(
ins
->
name
()
==
"contiguous"
)
{
ins_map
[
ins
]
=
ins_map
[
ins
->
inputs
().
at
(
0
)];
continue
;
}
auto
name
=
get_name
(
ins
);
auto
name
=
get_name
(
ins
);
auto
ops
=
create_operation_state
(
name
);
auto
ops
=
create_operation_state
(
name
);
ops
.
add_attribute_value
(
get_operator_value
(
ins
->
get_operator
()));
ops
.
add_attribute_value
(
get_operator_value
(
ins
->
get_operator
()));
if
(
ins
->
name
()
!=
"@return"
)
if
(
ins
->
name
()
!=
"@return"
)
ops
.
add_results
({
get_shape
(
ins
)});
ops
.
add_results
({
get_shape
(
ins
)});
if
(
ins
->
name
()
==
"@literal"
)
{
literal
r
=
ins
->
get_literal
();
MlirType
tensor_type
=
make_tensor
(
ins
->
get_shape
());
MlirAttribute
mlir_value_attr
=
mlirDenseElementsAttrRawBufferGet
(
tensor_type
,
r
.
get_shape
().
bytes
(),
r
.
data
());
ops
.
add_attributes
({{
"value"
,
mlir_value_attr
}});
}
if
(
ins
->
name
()
==
"convolution"
or
ins
->
name
()
==
"dot"
)
if
(
ins
->
name
()
==
"convolution"
or
ins
->
name
()
==
"dot"
)
{
{
pp
=
pp
=
problem_params
{
ins
->
get_operator
(),
to_shapes
(
ins
->
inputs
()),
ins
->
get_shape
()};
problem_params
{
ins
->
get_operator
(),
to_shapes
(
ins
->
inputs
()),
ins
->
get_shape
()};
// check if HW supports xdlops
// check if HW supports xdlops
auto
target_chip
=
trim
(
split_string
(
target_arch
,
':'
).
front
());
if
(
has_xdlops
(
target_arch
))
bool
xdlops
=
contains
(
get_xdlops_archs
(),
target_chip
);
if
(
xdlops
)
ops
.
add_attributes
({{
"xdlopsV2"
,
true
}});
ops
.
add_attributes
({{
"xdlopsV2"
,
true
}});
}
}
...
@@ -562,18 +621,30 @@ struct mlir_program
...
@@ -562,18 +621,30 @@ struct mlir_program
}
}
}
}
code_object_op
compil
e
()
MIGRAPHX_TIDY_CONST
void
run_high_level_pipelin
e
()
MIGRAPHX_TIDY_CONST
{
{
mlir_pass_manager
pm_front
{
mlirPassManagerCreate
(
ctx
.
get
())};
mlir_pass_manager
pm_front
{
mlirPassManagerCreate
(
ctx
.
get
())};
mlir_pass_manager
pm_back
{
mlirPassManagerCreate
(
ctx
.
get
())};
// 1st pipeline to call
mlirMIGraphXAddHighLevelPipeline
(
pm_front
.
get
());
mlirMIGraphXAddHighLevelPipeline
(
pm_front
.
get
());
mlirPassManagerRun
(
pm_front
.
get
(),
mmodule
.
get
());
mlirPassManagerRunOnOp
(
pm_front
.
get
(),
mlirModuleGetOperation
(
mmodule
.
get
()));
}
// 2nd pipeline to call
void
run_backend_pipeline
()
MIGRAPHX_TIDY_CONST
get_module_tuned
();
{
mlir_pass_manager
pm_back
{
mlirPassManagerCreate
(
ctx
.
get
())};
mlirMIGraphXAddBackendPipeline
(
pm_back
.
get
(),
target_arch
.
c_str
());
mlirMIGraphXAddBackendPipeline
(
pm_back
.
get
(),
target_arch
.
c_str
());
mlirPassManagerRun
(
pm_back
.
get
(),
mmodule
.
get
());
mlirPassManagerRunOnOp
(
pm_back
.
get
(),
mlirModuleGetOperation
(
mmodule
.
get
()));
}
code_object_op
compile
(
const
value
&
solution
)
MIGRAPHX_TIDY_CONST
{
// 1st pipeline to call
run_high_level_pipeline
();
if
(
solution
.
is_null
())
get_module_tuned
();
else
set_tuning
(
solution
);
// 2nd pipeline to call
run_backend_pipeline
();
code_object_op
op
{};
code_object_op
op
{};
op
.
symbol_name
=
sym_name
;
op
.
symbol_name
=
sym_name
;
...
@@ -604,6 +675,33 @@ struct mlir_program
...
@@ -604,6 +675,33 @@ struct mlir_program
MIGRAPHX_THROW
(
"Failed to compile mlir program"
);
MIGRAPHX_THROW
(
"Failed to compile mlir program"
);
}
}
void
set_tuning
(
const
value
&
v
)
{
auto
str
=
v
.
to
<
std
::
string
>
();
// We need to make a copy of the buffer since mlirRockTuningSetFromStr may modify the string
std
::
vector
<
char
>
buffer
(
str
.
begin
(),
str
.
end
());
buffer
.
push_back
(
0
);
if
(
not
mlirRockTuningSetFromStr
(
mmodule
.
get
(),
buffer
.
data
()))
MIGRAPHX_THROW
(
"Failed setting tuning key: "
+
str
);
}
tuning_config
get_tuning_config
()
MIGRAPHX_TIDY_CONST
{
tuning_config
tc
;
run_high_level_pipeline
();
mlir_tuning_space
params
{
mlirRockTuningSpaceCreate
(
mmodule
.
get
())};
for
(
auto
i
:
range
(
mlirRockTuningGetNumParamsFull
(
params
.
get
())))
{
mlir_tuning_param
param
{
mlirRockTuningParamCreate
()};
if
(
not
mlirRockTuningParamGet
(
params
.
get
(),
i
,
param
.
get
()))
MIGRAPHX_THROW
(
"Incorrect mlir tuning parameter: "
+
std
::
to_string
(
i
));
tc
.
solutions
.
push_back
(
std
::
string
{
mlirRockTuningGetParamStr
(
param
.
get
())});
}
mlir_tuning_table
tuning_table
{
mlirRockTuningTableCreate
()};
tc
.
problem
=
std
::
string
{
mlirRockTuningGetKey
(
tuning_table
.
get
(),
mmodule
.
get
())};
return
tc
;
}
std
::
string
get_tune_params
(
bool
xdlops
)
const
{
return
get_mlir_perf_for_conv
(
pp
,
xdlops
);
}
std
::
string
get_tune_params
(
bool
xdlops
)
const
{
return
get_mlir_perf_for_conv
(
pp
,
xdlops
);
}
// This function appends to tuning cfg file that could be
// This function appends to tuning cfg file that could be
...
@@ -662,6 +760,11 @@ struct mlir_program
...
@@ -662,6 +760,11 @@ struct mlir_program
bool
get_module_tuned
()
const
bool
get_module_tuned
()
const
{
{
static
mlir_tuning_table
tuning_table
=
create_tuning_table
();
static
mlir_tuning_table
tuning_table
=
create_tuning_table
();
// The tuning table as currently implemented is currently not
// thread safe. This will be fixed in the future. For now,
// stick a mutex around all tuning table interaction.
static
std
::
mutex
lock
;
std
::
lock_guard
<
std
::
mutex
>
guard
(
lock
);
if
(
!
mlirRockTuningSetFromTable
(
tuning_table
.
get
(),
mmodule
.
get
()))
if
(
!
mlirRockTuningSetFromTable
(
tuning_table
.
get
(),
mmodule
.
get
()))
{
{
const
char
*
prob_config
=
mlirRockTuningGetKey
(
tuning_table
.
get
(),
mmodule
.
get
());
const
char
*
prob_config
=
mlirRockTuningGetKey
(
tuning_table
.
get
(),
mmodule
.
get
());
...
@@ -690,14 +793,14 @@ std::string dump_mlir(const module& m)
...
@@ -690,14 +793,14 @@ std::string dump_mlir(const module& m)
return
mlir_print
(
&
mlirOperationPrint
,
mod_op
);
return
mlir_print
(
&
mlirOperationPrint
,
mod_op
);
}
}
void
adjust_param_shapes
(
module
&
m
,
const
std
::
vector
<
instruction_ref
>&
inputs
)
void
adjust_param_shapes
(
module
&
m
,
const
std
::
vector
<
shape
>&
inputs
)
{
{
auto
names
=
m
.
get_parameter_names
();
auto
names
=
m
.
get_parameter_names
();
std
::
sort
(
names
.
begin
(),
names
.
end
());
std
::
sort
(
names
.
begin
(),
names
.
end
());
for
(
auto
i
:
range
(
names
.
size
()))
for
(
auto
i
:
range
(
names
.
size
()))
{
{
const
auto
&
name
=
names
[
i
];
const
auto
&
name
=
names
[
i
];
const
auto
&
input
=
inputs
[
i
]
->
get_shape
()
;
const
auto
&
input
=
inputs
[
i
];
auto
param
=
m
.
get_parameter
(
name
);
auto
param
=
m
.
get_parameter
(
name
);
if
(
input
.
standard
())
if
(
input
.
standard
())
continue
;
continue
;
...
@@ -735,24 +838,26 @@ void adjust_param_shapes(module& m, const std::vector<instruction_ref>& inputs)
...
@@ -735,24 +838,26 @@ void adjust_param_shapes(module& m, const std::vector<instruction_ref>& inputs)
}
}
}
}
code_object_op
compile_mlir
(
const
context
&
,
module
m
,
const
std
::
vector
<
instruction_ref
>&
inputs
)
code_object_op
compile_mlir
(
const
context
&
,
module
m
,
const
std
::
vector
<
instruction_ref
>&
inputs
,
const
value
&
solution
)
{
{
adjust_param_shapes
(
m
,
inputs
);
adjust_param_shapes
(
m
,
to_shapes
(
inputs
)
)
;
const
bool
trace
=
enabled
(
MIGRAPHX_TRACE_MLIR
{});
const
bool
trace
=
enabled
(
MIGRAPHX_TRACE_MLIR
{});
if
(
trace
)
if
(
trace
)
std
::
cout
<<
m
<<
std
::
endl
;
std
::
cout
<<
m
<<
std
::
endl
;
// set mutex while llvm thread support is disabled.
static
std
::
mutex
g_mlirc_mutex
;
// NOLINT
const
std
::
lock_guard
<
std
::
mutex
>
lock
(
g_mlirc_mutex
);
mlir_program
mp
;
mlir_program
mp
;
mp
.
find_target
();
mp
.
find_target
();
mp
.
parse
(
m
);
mp
.
parse
(
m
);
auto
mod_op
=
mlirModuleGetOperation
(
mp
.
mmodule
.
get
());
auto
mod_op
=
mlirModuleGetOperation
(
mp
.
mmodule
.
get
());
if
(
trace
)
if
(
trace
)
std
::
cout
<<
mlir_print
(
&
mlirOperationPrint
,
mod_op
)
<<
std
::
endl
;
std
::
cout
<<
mlir_print
(
&
mlirOperationPrint
,
mod_op
)
<<
std
::
endl
;
auto
co
=
mp
.
compile
();
auto
co
=
mp
.
compile
(
solution
);
co
.
output
=
m
.
get_output_shapes
().
front
();
co
.
expected_inputs
=
to_shapes
(
inputs
);
co
.
output
=
m
.
get_output_shapes
().
front
();
return
co
;
return
co
;
}
}
...
@@ -772,6 +877,16 @@ instruction_ref insert_mlir(module& m,
...
@@ -772,6 +877,16 @@ instruction_ref insert_mlir(module& m,
return
m
.
insert_instruction
(
ins
,
co
,
refs
);
return
m
.
insert_instruction
(
ins
,
co
,
refs
);
}
}
tuning_config
get_tuning_config_mlir
(
module
m
,
const
std
::
vector
<
shape
>&
inputs
)
{
adjust_param_shapes
(
m
,
inputs
);
mlir_program
mp
;
mp
.
find_target
();
mp
.
parse
(
m
);
return
mp
.
get_tuning_config
();
}
#else
#else
std
::
string
dump_mlir
(
const
module
&
)
{
return
{};
}
std
::
string
dump_mlir
(
const
module
&
)
{
return
{};
}
...
@@ -783,11 +898,11 @@ void use(T&)
...
@@ -783,11 +898,11 @@ void use(T&)
// Disabling clang-tidy warning on non-real useage.
// Disabling clang-tidy warning on non-real useage.
// NOLINTBEGIN(performance-unnecessary-value-param)
// NOLINTBEGIN(performance-unnecessary-value-param)
code_object_op
compile_mlir
(
const
context
&
,
module
,
const
std
::
vector
<
instruction_ref
>&
)
code_object_op
compile_mlir
(
const
context
&
,
module
,
const
std
::
vector
<
instruction_ref
>&
,
const
value
&
)
{
{
return
{};
return
{};
}
}
// NOLINTEND(performance-unnecessary-value-param)
instruction_ref
instruction_ref
// cppcheck-suppress funcArgNamesDifferent
// cppcheck-suppress funcArgNamesDifferent
...
@@ -797,6 +912,9 @@ insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<ins
...
@@ -797,6 +912,9 @@ insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<ins
return
m
.
end
();
return
m
.
end
();
}
}
tuning_config
get_tuning_config_mlir
(
module
,
const
std
::
vector
<
shape
>&
)
{
return
{};
}
// NOLINTEND(performance-unnecessary-value-param)
#endif
#endif
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/rocblas.cpp
View file @
40fbef9b
...
@@ -47,32 +47,24 @@ rocblas_handle_ptr create_rocblas_handle_ptr(hipStream_t s)
...
@@ -47,32 +47,24 @@ rocblas_handle_ptr create_rocblas_handle_ptr(hipStream_t s)
return
rb
;
return
rb
;
}
}
const
std
::
unordered_set
<
std
::
string
>&
get_rocblas_fp32_archs
()
{
static
std
::
unordered_set
<
std
::
string
>
supported_archs
{
"gfx908"
,
"gfx90a"
};
return
supported_archs
;
}
bool
get_compute_fp32_flag
()
bool
get_compute_fp32_flag
()
{
{
bool
compute_fp32
=
false
;
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
const
auto
device_name
=
trim
(
split_string
(
get_device_name
(),
':'
).
front
());
const
auto
device_name
=
trim
(
split_string
(
get_device_name
(),
':'
).
front
());
if
(
contains
(
get_rocblas_fp32_archs
(),
device_name
))
return
(
starts_with
(
device_name
,
"gfx9"
)
and
device_name
>=
"gfx908"
);
compute_fp32
=
true
;
#endif
return
compute_fp32
;
}
}
bool
get_int8_x4_format
(
context
&
ctx
)
bool
get_int8_x4_format
(
context
&
ctx
)
{
{
bool
int8_x4_format
=
true
;
#if ROCBLAS_VERSION_MAJOR >= 3
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
(
void
)(
ctx
);
return
false
;
#else
// int8x4 packed format is only available starting from rocblas-v2.38 and it is deprecated in
// v3.0 and will be removed in v4.0
rocblas_gemm_flags
flag
;
rocblas_gemm_flags
flag
;
rocblas_query_int8_layout_flag
(
ctx
.
get_stream
().
get_rocblas
(),
&
flag
);
rocblas_query_int8_layout_flag
(
ctx
.
get_stream
().
get_rocblas
(),
&
flag
);
int8_x4_format
=
(
flag
==
rocblas_gemm_flags_pack_int8x4
)
;
return
flag
==
rocblas_gemm_flags_pack_int8x4
;
#endif
#endif
return
int8_x4_format
;
}
}
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/target.cpp
View file @
40fbef9b
...
@@ -57,6 +57,7 @@
...
@@ -57,6 +57,7 @@
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
...
@@ -72,9 +73,12 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -72,9 +73,12 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
gpu
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_SCHEDULE_PASS
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_SCHEDULE_PASS
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_POINTWISE_FUSION
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_REDUCE_FUSION
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_REDUCE_FUSION
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_NHWC
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_NHWC
)
#ifndef _WIN32
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_CK
)
#endif
struct
id_pass
struct
id_pass
{
{
std
::
string
name
()
const
{
return
"id"
;
}
std
::
string
name
()
const
{
return
"id"
;
}
...
@@ -98,16 +102,17 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -98,16 +102,17 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types
.
erase
(
shape
::
type_t
::
bool_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
bool_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
int8_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
int8_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
uint8_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
uint8_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
int32_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
tuple_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
tuple_type
);
// clang-format off
// clang-format off
return
return
{
{
enable_pass
(
options
.
split_single_dyn_dim
,
split_single_dyn_dim
{}
)
,
split_single_dyn_dim
{},
enable_pass
(
options
.
split_single_dyn_dim
,
dead_code_elimination
{}
)
,
dead_code_elimination
{},
normalize_ops
{},
normalize_ops
{},
dead_code_elimination
{},
dead_code_elimination
{},
simplify_qdq
{},
simplify_qdq
{},
rewrite_quantization
{},
enable_pass
(
not
mlir_enabled
(),
rewrite_quantization
{}
)
,
dead_code_elimination
{},
dead_code_elimination
{},
eliminate_data_type
{
unsupported_types
,
shape
::
type_t
::
float_type
},
eliminate_data_type
{
unsupported_types
,
shape
::
type_t
::
float_type
},
simplify_reshapes
{},
simplify_reshapes
{},
...
@@ -121,7 +126,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -121,7 +126,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
inline_module
{},
inline_module
{},
rewrite_pooling
{},
rewrite_pooling
{},
dead_code_elimination
{},
dead_code_elimination
{},
rewrite_gelu
{},
enable_pass
(
options
.
fast_math
,
rewrite_gelu
{}
)
,
optimize_module
{},
optimize_module
{},
enable_pass
(
enabled
(
MIGRAPHX_ENABLE_NHWC
{}),
layout_nhwc
{}),
enable_pass
(
enabled
(
MIGRAPHX_ENABLE_NHWC
{}),
layout_nhwc
{}),
dead_code_elimination
{},
dead_code_elimination
{},
...
@@ -129,11 +134,15 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -129,11 +134,15 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination
{},
dead_code_elimination
{},
auto_contiguous
{},
auto_contiguous
{},
optimize_module
{},
optimize_module
{},
enable_pass
(
not
enabled
(
MIGRAPHX_DISABLE_POINTWISE_FUSION
{}),
fuse_pointwise
{}
)
,
fuse_pointwise
{},
dead_code_elimination
{},
dead_code_elimination
{},
enable_pass
(
not
enabled
(
MIGRAPHX_DISABLE_REDUCE_FUSION
{}),
fuse_reduce
{}),
enable_pass
(
not
enabled
(
MIGRAPHX_DISABLE_REDUCE_FUSION
{}),
fuse_reduce
{}),
dead_code_elimination
{},
dead_code_elimination
{},
fuse_mlir
{
&
ctx
},
#ifndef _WIN32
enable_pass
(
enabled
(
MIGRAPHX_ENABLE_CK
{}),
fuse_ck
{}),
#endif
dead_code_elimination
{},
enable_pass
(
mlir_enabled
(),
fuse_mlir
{
&
ctx
}),
dead_code_elimination
{},
dead_code_elimination
{},
lowering
{
&
ctx
,
options
.
offload_copy
},
lowering
{
&
ctx
,
options
.
offload_copy
},
eliminate_contiguous
{
"gpu::contiguous"
},
eliminate_contiguous
{
"gpu::contiguous"
},
...
@@ -150,7 +159,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -150,7 +159,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination
{},
dead_code_elimination
{},
adjust_allocation
{
gpu_allocation_model
{}},
adjust_allocation
{
gpu_allocation_model
{}},
dead_code_elimination
{},
dead_code_elimination
{},
compile_ops
{
&
ctx
},
compile_ops
{
&
ctx
,
options
.
exhaustive_tune
},
dead_code_elimination
{},
dead_code_elimination
{},
promote_literals
{},
promote_literals
{},
dead_code_elimination
{},
dead_code_elimination
{},
...
...
src/targets/gpu/
driver/perf
.cpp
→
src/targets/gpu/
time_op
.cpp
View file @
40fbef9b
...
@@ -21,7 +21,7 @@
...
@@ -21,7 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#include <migraphx/gpu/
driver/perf
.hpp>
#include <migraphx/gpu/
time_op
.hpp>
#include <migraphx/context.hpp>
#include <migraphx/context.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/time.hpp>
#include <migraphx/time.hpp>
...
@@ -30,7 +30,6 @@
...
@@ -30,7 +30,6 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
namespace
driver
{
std
::
vector
<
argument
>
generate_arguments
(
const
std
::
vector
<
shape
>&
shapes
,
unsigned
long
seed
=
0
)
std
::
vector
<
argument
>
generate_arguments
(
const
std
::
vector
<
shape
>&
shapes
,
unsigned
long
seed
=
0
)
{
{
...
@@ -69,7 +68,6 @@ time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n)
...
@@ -69,7 +68,6 @@ time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n)
return
std
::
make_pair
(
host_time
/
n
,
device_time
/
n
);
return
std
::
make_pair
(
host_time
/
n
,
device_time
/
n
);
}
}
}
// namespace driver
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
src/targets/ref/CMakeLists.txt
View file @
40fbef9b
...
@@ -37,6 +37,8 @@ target_link_libraries(migraphx_ref PUBLIC migraphx)
...
@@ -37,6 +37,8 @@ target_link_libraries(migraphx_ref PUBLIC migraphx)
target_include_directories
(
migraphx_ref PRIVATE
${
BLAZE_INCLUDE
}
)
target_include_directories
(
migraphx_ref PRIVATE
${
BLAZE_INCLUDE
}
)
target_compile_definitions
(
migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS
)
target_compile_definitions
(
migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS
)
migraphx_generate_export_header
(
migraphx_ref
)
rocm_install_targets
(
rocm_install_targets
(
TARGETS migraphx_ref
TARGETS migraphx_ref
INCLUDE
INCLUDE
...
...
src/targets/ref/include/migraphx/ref/context.hpp
View file @
40fbef9b
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_RTGLIB_CONTEXT_HPP
#define MIGRAPHX_GUARD_RTGLIB_CONTEXT_HPP
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/ref/export.h>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/targets/ref/include/migraphx/ref/lowering.hpp
View file @
40fbef9b
...
@@ -24,14 +24,14 @@
...
@@ -24,14 +24,14 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_CPU_LOWERING_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_CPU_LOWERING_HPP
#define MIGRAPHX_GUARD_RTGLIB_CPU_LOWERING_HPP
#define MIGRAPHX_GUARD_RTGLIB_CPU_LOWERING_HPP
#include <migraphx/ref/context.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
ref
{
namespace
ref
{
struct
lowering
struct
MIGRAPHX_REF_EXPORT
lowering
{
{
std
::
string
name
()
const
{
return
"ref::lowering"
;
}
std
::
string
name
()
const
{
return
"ref::lowering"
;
}
void
apply
(
module
&
m
)
const
;
void
apply
(
module
&
m
)
const
;
...
...
src/targets/ref/include/migraphx/ref/target.hpp
View file @
40fbef9b
...
@@ -35,7 +35,7 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -35,7 +35,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
pass
;
struct
pass
;
namespace
ref
{
namespace
ref
{
struct
target
struct
MIGRAPHX_REF_EXPORT
target
{
{
std
::
string
name
()
const
;
std
::
string
name
()
const
;
std
::
vector
<
pass
>
get_passes
(
migraphx
::
context
&
ctx
,
const
compile_options
&
)
const
;
std
::
vector
<
pass
>
get_passes
(
migraphx
::
context
&
ctx
,
const
compile_options
&
)
const
;
...
...
src/targets/ref/lowering.cpp
View file @
40fbef9b
...
@@ -27,7 +27,7 @@
...
@@ -27,7 +27,7 @@
#include <migraphx/dfor.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/
de
convolution.hpp>
#include <migraphx/op/convolution
_backwards
.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/quant_dot.hpp>
...
...
src/tf/CMakeLists.txt
View file @
40fbef9b
...
@@ -42,8 +42,9 @@ target_compile_options(tf-proto PRIVATE -w)
...
@@ -42,8 +42,9 @@ target_compile_options(tf-proto PRIVATE -w)
target_link_libraries
(
tf-proto PRIVATE
${
PROTOBUF_LIBRARY
}
)
target_link_libraries
(
tf-proto PRIVATE
${
PROTOBUF_LIBRARY
}
)
set_target_properties
(
tf-proto PROPERTIES POSITION_INDEPENDENT_CODE On
)
set_target_properties
(
tf-proto PROPERTIES POSITION_INDEPENDENT_CODE On
)
file
(
GLOB TF_SRCS
${
CONFIGURE_DEPENDS
}
*.cpp
)
file
(
GLOB TF_SRCS CONFIGURE_DEPENDS *.cpp
)
add_library
(
migraphx_tf
${
TF_SRCS
}
)
add_library
(
migraphx_tf
${
TF_SRCS
}
)
migraphx_generate_export_header
(
migraphx_tf
)
target_include_directories
(
migraphx_tf PRIVATE include
)
target_include_directories
(
migraphx_tf PRIVATE include
)
set_target_properties
(
migraphx_tf PROPERTIES EXPORT_NAME tf
)
set_target_properties
(
migraphx_tf PROPERTIES EXPORT_NAME tf
)
rocm_set_soversion
(
migraphx_tf
${
MIGRAPHX_SO_VERSION
}
)
rocm_set_soversion
(
migraphx_tf
${
MIGRAPHX_SO_VERSION
}
)
...
...
src/tf/op_parser.cpp
View file @
40fbef9b
...
@@ -46,6 +46,7 @@ std::vector<std::string> get_op_parsers()
...
@@ -46,6 +46,7 @@ std::vector<std::string> get_op_parsers()
op_parser_map
().
end
(),
op_parser_map
().
end
(),
std
::
back_inserter
(
result
),
std
::
back_inserter
(
result
),
[
&
](
auto
&&
p
)
{
return
p
.
first
;
});
[
&
](
auto
&&
p
)
{
return
p
.
first
;
});
std
::
sort
(
result
.
begin
(),
result
.
end
());
return
result
;
return
result
;
}
}
...
...
src/tf/parse_batchnorm.cpp
View file @
40fbef9b
...
@@ -52,7 +52,6 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
...
@@ -52,7 +52,6 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
auto
x_type
=
args
[
0
]
->
get_shape
().
type
();
auto
x_type
=
args
[
0
]
->
get_shape
().
type
();
// unsqueeze tensors of shape (C) to broadcast correctly
// unsqueeze tensors of shape (C) to broadcast correctly
auto
rt
=
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
x_type
},
{
0.5
}});
auto
eps
=
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
x_type
},
{
epsilon
}});
auto
eps
=
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
x_type
},
{
epsilon
}});
auto
scale_unsqueeze
=
auto
scale_unsqueeze
=
...
@@ -64,11 +63,11 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
...
@@ -64,11 +63,11 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
auto
var_unsqueeze
=
auto
var_unsqueeze
=
info
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
,
2
}}}),
args
[
4
]);
info
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
,
2
}}}),
args
[
4
]);
auto
numer
=
info
.
add_broadcastable_binary_op
(
"sub"
,
args
[
0
],
mean_unsqueeze
);
auto
x_sub_mean
=
info
.
add_broadcastable_binary_op
(
"sub"
,
args
[
0
],
mean_unsqueeze
);
auto
var_eps
=
info
.
add_broadcastable_binary_op
(
"add"
,
var_unsqueeze
,
eps
);
auto
var_eps
=
info
.
add_broadcastable_binary_op
(
"add"
,
var_unsqueeze
,
eps
);
auto
denom
=
info
.
add_
broadcastable_binary_op
(
"pow"
,
var_eps
,
rt
);
auto
rsqrt
=
info
.
add_
instruction
(
make_op
(
"rsqrt"
)
,
var_eps
);
auto
div0
=
info
.
add_broadcastable_binary_op
(
"
div"
,
numer
,
denom
);
auto
mul0
=
info
.
add_broadcastable_binary_op
(
"
mul"
,
scale_unsqueeze
,
rsqrt
);
auto
r0
=
info
.
add_broadcastable_binary_op
(
"mul"
,
div0
,
scale_unsqueeze
);
auto
r0
=
info
.
add_broadcastable_binary_op
(
"mul"
,
x_sub_mean
,
mul0
);
return
info
.
add_broadcastable_binary_op
(
"add"
,
r0
,
bias_unsqueeze
);
return
info
.
add_broadcastable_binary_op
(
"add"
,
r0
,
bias_unsqueeze
);
}
}
};
};
...
...
Prev
1
…
12
13
14
15
16
17
18
19
20
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment