Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
40fbef9b
Unverified
Commit
40fbef9b
authored
Aug 05, 2023
by
Ted Themistokleous
Committed by
GitHub
Aug 05, 2023
Browse files
Merge branch 'develop' into threaded_nms
parents
d164b151
aeb9f78c
Changes
440
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
251 additions
and
121 deletions
+251
-121
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
+0
-4
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
+8
-0
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
+5
-15
src/targets/gpu/kernels/include/migraphx/kernels/print.hpp
src/targets/gpu/kernels/include/migraphx/kernels/print.hpp
+2
-2
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+8
-7
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
...gets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
+11
-3
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
+1
-1
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+18
-12
src/targets/gpu/mlir.cpp
src/targets/gpu/mlir.cpp
+157
-39
src/targets/gpu/rocblas.cpp
src/targets/gpu/rocblas.cpp
+8
-16
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+17
-8
src/targets/gpu/time_op.cpp
src/targets/gpu/time_op.cpp
+1
-3
src/targets/ref/CMakeLists.txt
src/targets/ref/CMakeLists.txt
+2
-0
src/targets/ref/include/migraphx/ref/context.hpp
src/targets/ref/include/migraphx/ref/context.hpp
+1
-0
src/targets/ref/include/migraphx/ref/lowering.hpp
src/targets/ref/include/migraphx/ref/lowering.hpp
+2
-2
src/targets/ref/include/migraphx/ref/target.hpp
src/targets/ref/include/migraphx/ref/target.hpp
+1
-1
src/targets/ref/lowering.cpp
src/targets/ref/lowering.cpp
+1
-1
src/tf/CMakeLists.txt
src/tf/CMakeLists.txt
+2
-1
src/tf/op_parser.cpp
src/tf/op_parser.cpp
+1
-0
src/tf/parse_batchnorm.cpp
src/tf/parse_batchnorm.cpp
+5
-6
No files found.
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
View file @
40fbef9b
...
...
@@ -28,10 +28,6 @@
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/math_functions.h>
#include <hip/hip_math_constants.h>
#elif defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS)
#include <hip/hip_common.h>
#include <hip/hip_math_constants.h>
#endif
#endif // MIGRAPHX_GUARD_KERNELS_HIP_HPP
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
View file @
40fbef9b
...
...
@@ -130,6 +130,8 @@ struct index
return
blockDim
.
x
;
}
#endif
// Quotient of nglobal() by max_nlocal(); used as the stride when iterating
// per group (see group_stride).
constexpr auto ngroup() const
{
    const auto total   = nglobal();
    const auto per_grp = max_nlocal();
    return total / per_grp;
}
template
<
class
N
,
class
Stride
>
static
constexpr
auto
max_stride_iterations
(
N
n
,
Stride
stride
)
{
...
...
@@ -231,6 +233,12 @@ struct index
{
for_stride
<
true
>
(
local
,
n
,
nlocal
(),
f
);
}
// Forwards to for_stride<false>, starting at this index's `group` value,
// bounded by `n`, stepping by ngroup(), and invoking `f`.
template <class F, class N>
__device__ void group_stride(N n, F f) const
{
    const auto step = ngroup();
    for_stride<false>(group, n, step, f);
}
};
#ifdef MIGRAPHX_NLOCAL
...
...
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
View file @
40fbef9b
...
...
@@ -138,7 +138,7 @@ MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, floor, ::hfloor)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
isnan
,
::
__hisnan
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
log
,
::
hlog
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
rsqrt
,
::
hrsqrt
)
//
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sin, ::hsin)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
sin
,
::
hsin
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
sqrt
,
::
hsqrt
)
// Use float to compute half overload
...
...
@@ -161,8 +161,7 @@ MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod)
// Map math functions to hip half2 functions
// The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats
// packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names
// Most but not all of these math ops have operators of the same names. Ones not yet implemented
// at this time are: exp2, exp10, log2, log10, isinf
// Most but not all of these math ops have operators of the same names.
MIGRAPHX_DEVICE_MATH_HALF2
(
abs
,
::
__habs2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
ceil
,
::
h2ceil
)
MIGRAPHX_DEVICE_MATH_HALF2
(
cos
,
::
h2cos
)
...
...
@@ -176,7 +175,7 @@ MIGRAPHX_DEVICE_MATH_HALF2(log, ::h2log)
MIGRAPHX_DEVICE_MATH_HALF2
(
log10
,
::
h2log10
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log2
,
::
h2log2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
rsqrt
,
::
h2rsqrt
)
//
MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin)
MIGRAPHX_DEVICE_MATH_HALF2
(
sin
,
::
h2sin
)
MIGRAPHX_DEVICE_MATH_HALF2
(
sqrt
,
::
h2sqrt
)
template
<
class
T
,
class
U
>
...
...
@@ -189,9 +188,8 @@ MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, max, ::max)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
float
,
min
,
::
min
)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
double
,
max
,
::
max
)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
double
,
min
,
::
min
)
// Add overloads for half that calls the float version
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
migraphx
::
half
,
max
,
::
fmaxf
)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
migraphx
::
half
,
min
,
::
fminf
)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
migraphx
::
half
,
max
,
::
__hmax
)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
migraphx
::
half
,
min
,
::
__hmin
)
template
<
class
T
,
MIGRAPHX_REQUIRES
(
not
is_any_vec
<
T
>())
>
constexpr
auto
max
(
const
T
&
a
,
const
T
&
b
)
...
...
@@ -217,14 +215,6 @@ constexpr auto min(const T& a, const U& b)
return
min
<
common_type_t
<
T
,
U
>>
(
a
,
b
);
}
// Sin for half is broken on hip, so use cos instead
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_same
<
vec_type
<
T
>,
half
>
{})
>
constexpr
T
sin
(
T
x
)
{
constexpr
const
T
shift
=
HIP_PIO2_F
;
return
migraphx
::
cos
(
shift
-
x
);
}
MIGRAPHX_DEVICE_MATH_VEC
(
abs
)
MIGRAPHX_DEVICE_MATH_VEC
(
acos
)
MIGRAPHX_DEVICE_MATH_VEC
(
acosh
)
...
...
src/targets/gpu/kernels/include/migraphx/kernels/print.hpp
View file @
40fbef9b
...
...
@@ -244,13 +244,13 @@ __device__ void print_once(Ts... xs)
template
<
class
...
Ts
>
__device__
void
println
(
Ts
...
xs
)
{
print_each
(
&
cout
ln
,
xs
...);
print_each
(
&
cout
,
xs
...
,
'\n'
);
}
template
<
class
...
Ts
>
__device__
void
println_once
(
Ts
...
xs
)
{
print_each_once
(
&
cout
ln
,
xs
...);
print_each_once
(
&
cout
,
xs
...
,
'\n'
);
}
}
// namespace migraphx
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
40fbef9b
...
...
@@ -79,20 +79,21 @@ __device__ void dpp_reduce(T& in, Op op)
#endif
// NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE(op, prefix
)
\
#define MIGRAPHX_DPP_REDUCE(op, prefix
, sign)
\
__device__ inline void dpp_reduce(double& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64); } \
__device__ inline void dpp_reduce(float& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32); } \
__device__ inline void dpp_reduce(half& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16); } \
__device__ inline void dpp_reduce(int32_t& x, op) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##
_u32);
\
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##
sign##32);
\
} \
__device__ inline void dpp_reduce(uint32_t& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); }
MIGRAPHX_DPP_REDUCE
(
op
::
sum
,
v_add
)
MIGRAPHX_DPP_REDUCE
(
op
::
max
,
v_max
)
MIGRAPHX_DPP_REDUCE
(
op
::
min
,
v_min
)
MIGRAPHX_DPP_REDUCE
(
op
::
product
,
v_mul
)
// Note: when max and min are in int32_t, signed version of instruction needs to be used.
MIGRAPHX_DPP_REDUCE
(
op
::
sum
,
v_add
,
_u
)
MIGRAPHX_DPP_REDUCE
(
op
::
product
,
v_mul
,
_u
)
MIGRAPHX_DPP_REDUCE
(
op
::
max
,
v_max
,
_i
)
MIGRAPHX_DPP_REDUCE
(
op
::
min
,
v_min
,
_i
)
template
<
class
Op
,
class
T
,
class
Index
,
class
F
>
__device__
auto
block_reduce
(
index
idx
,
Op
op
,
T
init
,
Index
n
,
F
f
)
...
...
@@ -570,7 +571,7 @@ template <class Algo, class Reduced, class Output, class F>
__device__
void
fused_reduce
(
Output
output
,
F
f
)
{
Algo
::
template
run
<
Reduced
>([
&
](
auto
out_idx
,
auto
r
)
{
auto
result
=
f
(
r
);
auto
result
=
f
(
r
,
out_idx
);
if
constexpr
(
reduce
::
is_inner_storage
<
decltype
(
result
)
>
{})
{
r
.
inner
([
&
](
auto
&
y
,
auto
x
)
{
y
=
x
;
})(
output
,
result
);
...
...
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
View file @
40fbef9b
...
...
@@ -218,7 +218,15 @@ using common_type_t = typename common_type<Ts...>::type;
#define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>
// Returns the largest value representable by an unsigned integer that is
// `n` bytes wide, i.e. 2^(8*n) - 1.
//
// Note, a left shift cannot be used to get the maximum value of a type that
// is as wide as (or wider than) `unsigned long`, because shifting by the full
// width of the operand is undefined behavior. Those widths return all bits
// set instead. The guard is expressed in terms of sizeof(unsigned long) so it
// also holds on platforms where `unsigned long` is only 32 bits wide.
constexpr unsigned long int_max(unsigned long n)
{
    if(n >= sizeof(unsigned long))
        return ~0ul;
    return (1ul << (n * 8)) - 1;
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_integral
<
T
>{}
or
is_floating_point
<
T
>
{}
or
...
...
@@ -228,9 +236,9 @@ constexpr T numeric_max()
if
constexpr
(
is_integral
<
T
>
{})
{
if
constexpr
(
is_unsigned
<
T
>
{})
return
int_max
(
sizeof
(
T
))
*
2
;
else
return
int_max
(
sizeof
(
T
));
else
return
int_max
(
sizeof
(
T
))
/
2
;
}
else
if
constexpr
(
is_same
<
T
,
double
>
{})
return
__DBL_MAX__
;
...
...
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
View file @
40fbef9b
...
...
@@ -135,7 +135,7 @@ constexpr vec<vec_type<T>, N> vec_packed_at(T x, I i)
return
vec
<
T
,
N
>
{
x
};
else
{
MIGRAPHX_ASSERT
((
i
+
N
)
<
vec_size
<
T
>
());
MIGRAPHX_ASSERT
((
i
+
N
)
<
=
vec_size
<
T
>
());
vec
<
vec_type
<
T
>
,
N
>
result
=
{
0
};
for
(
int
j
=
0
;
j
<
N
;
j
++
)
{
...
...
src/targets/gpu/lowering.cpp
View file @
40fbef9b
...
...
@@ -22,12 +22,19 @@
* THE SOFTWARE.
*/
#include <iterator>
#include <migraphx/gpu/lowering.hpp>
#include <utility>
#include <functional>
#include <algorithm>
#include <map>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/program.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/if_op.hpp>
...
...
@@ -35,17 +42,12 @@
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/program.hpp>
#include <utility>
#include <functional>
#include <algorithm>
#include <map>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -53,8 +55,9 @@ namespace gpu {
struct
miopen_apply
{
module
*
mod
=
nullptr
;
const
lowering
*
pass
=
nullptr
;
module
*
mod
=
nullptr
;
module_pass_manager
*
mpm
=
nullptr
;
const
lowering
*
pass
=
nullptr
;
std
::
unordered_map
<
std
::
string
,
std
::
function
<
instruction_ref
(
instruction_ref
)
>>
apply_map
{};
instruction_ref
last
{};
bool
offload_copy
=
false
;
...
...
@@ -83,7 +86,7 @@ struct miopen_apply
auto
&
ctx
=
get_context
();
int8_x4_format
=
get_int8_x4_format
(
ctx
);
compute_fp32
=
get_compute_fp32_flag
();
offload_copy
=
(
mod
->
name
()
==
"main"
)
?
pass
->
offload_copy
:
false
;
offload_copy
=
(
mod
==
mpm
->
get_root_module
()
)
?
pass
->
offload_copy
:
false
;
add_generic_op
(
"contiguous"
);
...
...
@@ -103,7 +106,7 @@ struct miopen_apply
add_extend_op
(
"topk"
);
add_convolution_op
(
"convolution"
);
add_convolution_op
(
"
de
convolution"
);
add_convolution_op
(
"convolution
_backwards
"
);
add_convolution_op
(
"quant_convolution"
);
add_gemm_op
<
op
::
dot
>
(
"dot"
);
add_gemm_op
<
op
::
quant_dot
>
(
"quant_dot"
);
...
...
@@ -375,7 +378,10 @@ struct miopen_apply
}
};
void
lowering
::
apply
(
module
&
m
)
const
{
miopen_apply
{
&
m
,
this
}.
apply
();
}
void
lowering
::
apply
(
module_pass_manager
&
mpm
)
const
{
miopen_apply
{
&
mpm
.
get_module
(),
&
mpm
,
this
}.
apply
();
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/mlir.cpp
View file @
40fbef9b
...
...
@@ -52,6 +52,7 @@
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/perfdb.hpp>
#include <migraphx/gpu/tuning_config.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/permutation.hpp>
#include <deque>
...
...
@@ -121,7 +122,10 @@ struct mlir_handle
#define MIGRAPHX_MANAGE_MLIR_HANDLE(T, F) migraphx::gpu::mlir_handle<T, decltype(&F), &F> // NOLINT
using
mlir_context
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirContext
,
mlirContextDestroy
);
using
mlir_context
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirContext
,
mlirContextDestroy
);
using
mlir_thread_pool
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirLlvmThreadPool
,
mlirLlvmThreadPoolDestroy
);
using
mlir_dialect_registry
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirDialectRegistry
,
mlirDialectRegistryDestroy
);
using
mlir_module
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirModule
,
mlirModuleDestroy
);
using
mlir_operation
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirOperation
,
mlirOperationDestroy
);
using
mlir_op_printing_flags
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirOpPrintingFlags
,
...
...
@@ -131,6 +135,10 @@ using mlir_block = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirBlock, mlirBlockD
using
mlir_pass_manager
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirPassManager
,
mlirPassManagerDestroy
);
using
mlir_tuning_table
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirRockTuningTable
,
mlirRockTuningTableDestroy
);
using
mlir_tuning_space
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirRockTuningSpace
,
mlirRockTuningSpaceDestroy
);
using
mlir_tuning_param
=
MIGRAPHX_MANAGE_MLIR_HANDLE
(
MlirRockTuningParam
,
mlirRockTuningParamDestroy
);
// Views the characters described by an MlirStringRef (its `data` and
// `length` members) as a std::string_view.
std::string_view to_string_view(MlirStringRef s)
{
    return std::string_view(s.data, s.length);
}
...
...
@@ -164,25 +172,47 @@ std::string mlir_print(F f, T x)
return
ss
.
str
();
}
const
std
::
unordered_set
<
std
::
string
>
&
get_
xdlops_
arch
s
(
)
bool
has_xdlops
(
const
std
::
string
&
tar
get_arch
)
{
static
std
::
unordered_set
<
std
::
string
>
supported_archs
{
"gfx908"
,
"gfx90a"
}
;
return
supported_archs
;
const
auto
device_name
=
trim
(
split_string
(
target_arch
,
':'
).
front
())
;
return
(
starts_with
(
device_name
,
"gfx9"
)
and
device_name
>=
"gfx908"
)
;
}
struct
mlir_program
{
mlir_program
()
:
ctx
(
mlirContextCreate
()),
:
ctx
(
mlirContextCreateWithRegistry
(
get_dialect_registry
().
get
(),
/*threadingEnable=*/
false
)),
location
(
mlirLocationUnknownGet
(
ctx
.
get
())),
mmodule
(
mlirModuleCreateEmpty
(
location
))
{
MlirDialectRegistry
registry
=
mlirDialectRegistryCreate
();
mlirRegisterRocMLIRDialects
(
registry
);
mlirContextAppendDialectRegistry
(
ctx
.
get
(),
registry
);
mlirContextSetThreadPool
(
ctx
.
get
(),
get_thread_pool
().
get
());
mlirContextLoadAllAvailableDialects
(
ctx
.
get
());
mlirDialectRegistryDestroy
(
registry
);
mlirContextSetAllowUnregisteredDialects
(
ctx
.
get
(),
true
/*allow*/
);
}
// Returns the single dialect registry used when creating MLIR contexts,
// populating it the first time this function is called.
static mlir_dialect_registry& get_dialect_registry()
{
    static mlir_dialect_registry the_registry;
    static std::once_flag init_guard;
    // The MLIR registration functions (for dialects and passes) are not
    // necessarily thread-safe and need to be executed exactly once
    // (especially since they eventually call non-thread-safe LLVM
    // initializations).
    const auto register_everything = [] {
        the_registry = mlirDialectRegistryCreate();
        mlirRegisterRocMLIRDialects(the_registry.get());
        mlirRegisterRocMLIRPasses();
    };
    std::call_once(init_guard, register_everything);
    return the_registry;
}
// Returns the LLVM thread pool that is reused across all MLIR contexts.
static mlir_thread_pool& get_thread_pool()
{
    // One pool is created on the first call and handed out afterwards to
    // save on overhead, as recommended by MLIR upstream. Initialization of
    // a function-local static is thread-safe as of C++11.
    static mlir_thread_pool shared_pool = mlirLlvmThreadPoolCreate();
    return shared_pool;
}
MlirType
make_type
(
shape
::
type_t
t
)
const
...
...
@@ -244,8 +274,6 @@ struct mlir_program
MlirAttribute
attribute
(
std
::
int64_t
i
)
const
{
if
(
i
<
0
)
MIGRAPHX_THROW
(
"MLIR cant handle negative values since they are ambiguous"
);
return
mlirIntegerAttrGet
(
mlirIntegerTypeGet
(
ctx
.
get
(),
64
),
i
);
}
MlirAttribute
attribute
(
std
::
uint64_t
i
)
const
...
...
@@ -324,7 +352,8 @@ struct mlir_program
std
::
string
,
value
,
std
::
vector
<
value
>
,
MlirType
>
;
MlirType
,
MlirAttribute
>
;
using
named_attribute_t
=
std
::
pair
<
std
::
string_view
,
attribute_t
>
;
MlirNamedAttribute
name_attribute
(
const
named_attribute_t
&
na
)
const
...
...
@@ -365,14 +394,20 @@ struct mlir_program
mlir_operation_state
&
add_attributes
(
const
std
::
vector
<
named_attribute_t
>&
named_attrs
)
{
auto
attributes
=
prog
->
name_attributes
(
named_attrs
);
mlirOperationStateAddAttributes
(
&
op_state
,
attributes
.
size
(),
attributes
.
data
());
if
(
not
attributes
.
empty
())
{
mlirOperationStateAddAttributes
(
&
op_state
,
attributes
.
size
(),
attributes
.
data
());
}
return
*
this
;
}
mlir_operation_state
&
add_attribute_value
(
const
value
&
v
)
{
auto
attributes
=
prog
->
name_attributes
(
v
);
mlirOperationStateAddAttributes
(
&
op_state
,
attributes
.
size
(),
attributes
.
data
());
if
(
not
attributes
.
empty
())
{
mlirOperationStateAddAttributes
(
&
op_state
,
attributes
.
size
(),
attributes
.
data
());
}
return
*
this
;
}
...
...
@@ -395,13 +430,19 @@ struct mlir_program
return
shape
{
r
.
type
(),
r
.
lens
()};
});
auto
x
=
prog
->
make_tensors
(
reshaped
);
mlirOperationStateAddResults
(
&
op_state
,
x
.
size
(),
x
.
data
());
if
(
not
x
.
empty
())
{
mlirOperationStateAddResults
(
&
op_state
,
x
.
size
(),
x
.
data
());
}
return
*
this
;
}
mlir_operation_state
&
add_operands
(
const
std
::
vector
<
MlirValue
>&
inputs
)
{
mlirOperationStateAddOperands
(
&
op_state
,
inputs
.
size
(),
inputs
.
data
());
if
(
not
inputs
.
empty
())
{
mlirOperationStateAddOperands
(
&
op_state
,
inputs
.
size
(),
inputs
.
data
());
}
return
*
this
;
}
...
...
@@ -411,7 +452,10 @@ struct mlir_program
std
::
transform
(
regions
.
begin
(),
regions
.
end
(),
mregions
.
begin
(),
[](
const
auto
&
r
)
{
return
r
.
get
();
});
mlirOperationStateAddOwnedRegions
(
&
op_state
,
mregions
.
size
(),
mregions
.
data
());
if
(
not
mregions
.
empty
())
{
mlirOperationStateAddOwnedRegions
(
&
op_state
,
mregions
.
size
(),
mregions
.
data
());
}
mlir_operation
op
(
mlirOperationCreate
(
&
op_state
));
// Release memory since mlir_operation owns it
for
(
auto
&
r
:
regions
)
...
...
@@ -481,6 +525,10 @@ struct mlir_program
{
if
(
ins
->
name
()
==
"@return"
)
return
"func.return"
;
if
(
ins
->
name
()
==
"@literal"
)
{
return
"tosa.const"
;
}
return
"migraphx."
+
ins
->
name
();
}
...
...
@@ -532,19 +580,30 @@ struct mlir_program
{
if
(
ins
->
name
()
==
"@param"
)
continue
;
if
(
ins
->
name
()
==
"contiguous"
)
{
ins_map
[
ins
]
=
ins_map
[
ins
->
inputs
().
at
(
0
)];
continue
;
}
auto
name
=
get_name
(
ins
);
auto
ops
=
create_operation_state
(
name
);
ops
.
add_attribute_value
(
get_operator_value
(
ins
->
get_operator
()));
if
(
ins
->
name
()
!=
"@return"
)
ops
.
add_results
({
get_shape
(
ins
)});
if
(
ins
->
name
()
==
"@literal"
)
{
literal
r
=
ins
->
get_literal
();
MlirType
tensor_type
=
make_tensor
(
ins
->
get_shape
());
MlirAttribute
mlir_value_attr
=
mlirDenseElementsAttrRawBufferGet
(
tensor_type
,
r
.
get_shape
().
bytes
(),
r
.
data
());
ops
.
add_attributes
({{
"value"
,
mlir_value_attr
}});
}
if
(
ins
->
name
()
==
"convolution"
or
ins
->
name
()
==
"dot"
)
{
pp
=
problem_params
{
ins
->
get_operator
(),
to_shapes
(
ins
->
inputs
()),
ins
->
get_shape
()};
// check if HW supports xdlops
auto
target_chip
=
trim
(
split_string
(
target_arch
,
':'
).
front
());
bool
xdlops
=
contains
(
get_xdlops_archs
(),
target_chip
);
if
(
xdlops
)
if
(
has_xdlops
(
target_arch
))
ops
.
add_attributes
({{
"xdlopsV2"
,
true
}});
}
...
...
@@ -562,18 +621,30 @@ struct mlir_program
}
}
code_object_op
compil
e
()
MIGRAPHX_TIDY_CONST
void
run_high_level_pipelin
e
()
MIGRAPHX_TIDY_CONST
{
mlir_pass_manager
pm_front
{
mlirPassManagerCreate
(
ctx
.
get
())};
mlir_pass_manager
pm_back
{
mlirPassManagerCreate
(
ctx
.
get
())};
// 1st pipeline to call
mlirMIGraphXAddHighLevelPipeline
(
pm_front
.
get
());
mlirPassManagerRun
(
pm_front
.
get
(),
mmodule
.
get
());
mlirPassManagerRunOnOp
(
pm_front
.
get
(),
mlirModuleGetOperation
(
mmodule
.
get
()));
}
// 2nd pipeline to call
get_module_tuned
();
void
run_backend_pipeline
()
MIGRAPHX_TIDY_CONST
{
mlir_pass_manager
pm_back
{
mlirPassManagerCreate
(
ctx
.
get
())};
mlirMIGraphXAddBackendPipeline
(
pm_back
.
get
(),
target_arch
.
c_str
());
mlirPassManagerRun
(
pm_back
.
get
(),
mmodule
.
get
());
mlirPassManagerRunOnOp
(
pm_back
.
get
(),
mlirModuleGetOperation
(
mmodule
.
get
()));
}
code_object_op
compile
(
const
value
&
solution
)
MIGRAPHX_TIDY_CONST
{
// 1st pipeline to call
run_high_level_pipeline
();
if
(
solution
.
is_null
())
get_module_tuned
();
else
set_tuning
(
solution
);
// 2nd pipeline to call
run_backend_pipeline
();
code_object_op
op
{};
op
.
symbol_name
=
sym_name
;
...
...
@@ -604,6 +675,33 @@ struct mlir_program
MIGRAPHX_THROW
(
"Failed to compile mlir program"
);
}
// Applies the tuning solution stored in `v` (read as a string) to the MLIR
// module. Throws via MIGRAPHX_THROW when mlirRockTuningSetFromStr reports
// failure.
void set_tuning(const value& v)
{
    const auto key = v.to<std::string>();
    // mlirRockTuningSetFromStr may modify the string it is given, so hand it
    // a private, null-terminated copy instead of the original.
    std::vector<char> scratch(key.begin(), key.end());
    scratch.push_back('\0');
    if(not mlirRockTuningSetFromStr(mmodule.get(), scratch.data()))
    {
        MIGRAPHX_THROW("Failed setting tuning key: " + key);
    }
}
// Runs the high-level pipeline and then builds a tuning_config:
//  - `solutions` receives the parameter string of every entry reported by
//    mlirRockTuningGetNumParamsFull for this module's tuning space;
//  - `problem` receives the key returned by mlirRockTuningGetKey.
// Throws via MIGRAPHX_THROW if a parameter cannot be fetched.
tuning_config get_tuning_config() MIGRAPHX_TIDY_CONST
{
    tuning_config config;
    run_high_level_pipeline();

    mlir_tuning_space space{mlirRockTuningSpaceCreate(mmodule.get())};
    const auto num_params = mlirRockTuningGetNumParamsFull(space.get());
    for(auto idx : range(num_params))
    {
        mlir_tuning_param candidate{mlirRockTuningParamCreate()};
        if(not mlirRockTuningParamGet(space.get(), idx, candidate.get()))
        {
            MIGRAPHX_THROW("Incorrect mlir tuning parameter: " + std::to_string(idx));
        }
        config.solutions.push_back(std::string{mlirRockTuningGetParamStr(candidate.get())});
    }

    // A fresh table is created only to query the key for this module.
    mlir_tuning_table table{mlirRockTuningTableCreate()};
    config.problem = std::string{mlirRockTuningGetKey(table.get(), mmodule.get())};
    return config;
}
// Looks up tuning parameters for the stored problem description `pp` by
// delegating to get_mlir_perf_for_conv, passing along the `xdlops` flag.
std::string get_tune_params(bool xdlops) const
{
    std::string perf_config = get_mlir_perf_for_conv(pp, xdlops);
    return perf_config;
}
// This function appends to tuning cfg file that could be
...
...
@@ -662,6 +760,11 @@ struct mlir_program
bool
get_module_tuned
()
const
{
static
mlir_tuning_table
tuning_table
=
create_tuning_table
();
// The tuning table, as currently implemented, is not thread safe.
// This will be fixed in the future. For now,
// stick a mutex around all tuning table interaction.
static
std
::
mutex
lock
;
std
::
lock_guard
<
std
::
mutex
>
guard
(
lock
);
if
(
!
mlirRockTuningSetFromTable
(
tuning_table
.
get
(),
mmodule
.
get
()))
{
const
char
*
prob_config
=
mlirRockTuningGetKey
(
tuning_table
.
get
(),
mmodule
.
get
());
...
...
@@ -690,14 +793,14 @@ std::string dump_mlir(const module& m)
return
mlir_print
(
&
mlirOperationPrint
,
mod_op
);
}
void
adjust_param_shapes
(
module
&
m
,
const
std
::
vector
<
instruction_ref
>&
inputs
)
void
adjust_param_shapes
(
module
&
m
,
const
std
::
vector
<
shape
>&
inputs
)
{
auto
names
=
m
.
get_parameter_names
();
std
::
sort
(
names
.
begin
(),
names
.
end
());
for
(
auto
i
:
range
(
names
.
size
()))
{
const
auto
&
name
=
names
[
i
];
const
auto
&
input
=
inputs
[
i
]
->
get_shape
()
;
const
auto
&
input
=
inputs
[
i
];
auto
param
=
m
.
get_parameter
(
name
);
if
(
input
.
standard
())
continue
;
...
...
@@ -735,24 +838,26 @@ void adjust_param_shapes(module& m, const std::vector<instruction_ref>& inputs)
}
}
code_object_op
compile_mlir
(
const
context
&
,
module
m
,
const
std
::
vector
<
instruction_ref
>&
inputs
)
code_object_op
compile_mlir
(
const
context
&
,
module
m
,
const
std
::
vector
<
instruction_ref
>&
inputs
,
const
value
&
solution
)
{
adjust_param_shapes
(
m
,
inputs
);
adjust_param_shapes
(
m
,
to_shapes
(
inputs
)
)
;
const
bool
trace
=
enabled
(
MIGRAPHX_TRACE_MLIR
{});
if
(
trace
)
std
::
cout
<<
m
<<
std
::
endl
;
// set mutex while llvm thread support is disabled.
static
std
::
mutex
g_mlirc_mutex
;
// NOLINT
const
std
::
lock_guard
<
std
::
mutex
>
lock
(
g_mlirc_mutex
);
mlir_program
mp
;
mp
.
find_target
();
mp
.
parse
(
m
);
auto
mod_op
=
mlirModuleGetOperation
(
mp
.
mmodule
.
get
());
if
(
trace
)
std
::
cout
<<
mlir_print
(
&
mlirOperationPrint
,
mod_op
)
<<
std
::
endl
;
auto
co
=
mp
.
compile
();
co
.
output
=
m
.
get_output_shapes
().
front
();
auto
co
=
mp
.
compile
(
solution
);
co
.
expected_inputs
=
to_shapes
(
inputs
);
co
.
output
=
m
.
get_output_shapes
().
front
();
return
co
;
}
...
...
@@ -772,6 +877,16 @@ instruction_ref insert_mlir(module& m,
return
m
.
insert_instruction
(
ins
,
co
,
refs
);
}
// Produces the MLIR tuning configuration for module `m`.
// `m` is taken by value because adjust_param_shapes modifies it using
// `inputs` before it is parsed into a fresh mlir_program.
tuning_config get_tuning_config_mlir(module m, const std::vector<shape>& inputs)
{
    adjust_param_shapes(m, inputs);

    mlir_program program;
    program.find_target();
    program.parse(m);

    return program.get_tuning_config();
}
#else
// Stub in the #else branch of the MLIR implementation: ignores the module
// and yields an empty string.
std::string dump_mlir(const module&)
{
    return std::string{};
}
...
...
@@ -783,11 +898,11 @@ void use(T&)
// Disabling clang-tidy warning on non-real usage.
// NOLINTBEGIN(performance-unnecessary-value-param)
code_object_op
compile_mlir
(
const
context
&
,
module
,
const
std
::
vector
<
instruction_ref
>&
)
code_object_op
compile_mlir
(
const
context
&
,
module
,
const
std
::
vector
<
instruction_ref
>&
,
const
value
&
)
{
return
{};
}
// NOLINTEND(performance-unnecessary-value-param)
instruction_ref
// cppcheck-suppress funcArgNamesDifferent
...
...
@@ -797,6 +912,9 @@ insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<ins
return
m
.
end
();
}
// Stub in the #else branch of the MLIR implementation: ignores both
// arguments and yields a value-initialized tuning_config.
tuning_config get_tuning_config_mlir(module, const std::vector<shape>&)
{
    return tuning_config{};
}
// NOLINTEND(performance-unnecessary-value-param)
#endif
}
// namespace gpu
...
...
src/targets/gpu/rocblas.cpp
View file @
40fbef9b
...
...
@@ -47,32 +47,24 @@ rocblas_handle_ptr create_rocblas_handle_ptr(hipStream_t s)
return
rb
;
}
// Returns the set of architecture names ("gfx908" and "gfx90a") consulted by
// get_compute_fp32_flag. The set is built once and the same object is
// returned on every call.
const std::unordered_set<std::string>& get_rocblas_fp32_archs()
{
    using arch_set = std::unordered_set<std::string>;
    static const arch_set fp32_archs = {"gfx908", "gfx90a"};
    return fp32_archs;
}
bool
get_compute_fp32_flag
()
{
bool
compute_fp32
=
false
;
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
const
auto
device_name
=
trim
(
split_string
(
get_device_name
(),
':'
).
front
());
if
(
contains
(
get_rocblas_fp32_archs
(),
device_name
))
compute_fp32
=
true
;
#endif
return
compute_fp32
;
return
(
starts_with
(
device_name
,
"gfx9"
)
and
device_name
>=
"gfx908"
);
}
bool
get_int8_x4_format
(
context
&
ctx
)
{
bool
int8_x4_format
=
true
;
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
#if ROCBLAS_VERSION_MAJOR >= 3
(
void
)(
ctx
);
return
false
;
#else
// int8x4 packed format is only available starting from rocblas-v2.38 and it is deprecated in
// v3.0 and will be removed in v4.0
rocblas_gemm_flags
flag
;
rocblas_query_int8_layout_flag
(
ctx
.
get_stream
().
get_rocblas
(),
&
flag
);
int8_x4_format
=
(
flag
==
rocblas_gemm_flags_pack_int8x4
)
;
return
flag
==
rocblas_gemm_flags_pack_int8x4
;
#endif
return
int8_x4_format
;
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/target.cpp
View file @
40fbef9b
...
...
@@ -57,6 +57,7 @@
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
...
...
@@ -72,9 +73,12 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_SCHEDULE_PASS
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_POINTWISE_FUSION
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_REDUCE_FUSION
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_NHWC
)
#ifndef _WIN32
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_CK
)
#endif
struct
id_pass
{
std
::
string
name
()
const
{
return
"id"
;
}
...
...
@@ -98,16 +102,17 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types
.
erase
(
shape
::
type_t
::
bool_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
int8_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
uint8_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
int32_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
tuple_type
);
// clang-format off
return
{
enable_pass
(
options
.
split_single_dyn_dim
,
split_single_dyn_dim
{}
)
,
enable_pass
(
options
.
split_single_dyn_dim
,
dead_code_elimination
{}
)
,
split_single_dyn_dim
{},
dead_code_elimination
{},
normalize_ops
{},
dead_code_elimination
{},
simplify_qdq
{},
rewrite_quantization
{},
enable_pass
(
not
mlir_enabled
(),
rewrite_quantization
{}
)
,
dead_code_elimination
{},
eliminate_data_type
{
unsupported_types
,
shape
::
type_t
::
float_type
},
simplify_reshapes
{},
...
...
@@ -121,7 +126,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
inline_module
{},
rewrite_pooling
{},
dead_code_elimination
{},
rewrite_gelu
{},
enable_pass
(
options
.
fast_math
,
rewrite_gelu
{}
)
,
optimize_module
{},
enable_pass
(
enabled
(
MIGRAPHX_ENABLE_NHWC
{}),
layout_nhwc
{}),
dead_code_elimination
{},
...
...
@@ -129,11 +134,15 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination
{},
auto_contiguous
{},
optimize_module
{},
enable_pass
(
not
enabled
(
MIGRAPHX_DISABLE_POINTWISE_FUSION
{}),
fuse_pointwise
{}
)
,
fuse_pointwise
{},
dead_code_elimination
{},
enable_pass
(
not
enabled
(
MIGRAPHX_DISABLE_REDUCE_FUSION
{}),
fuse_reduce
{}),
dead_code_elimination
{},
fuse_mlir
{
&
ctx
},
#ifndef _WIN32
enable_pass
(
enabled
(
MIGRAPHX_ENABLE_CK
{}),
fuse_ck
{}),
#endif
dead_code_elimination
{},
enable_pass
(
mlir_enabled
(),
fuse_mlir
{
&
ctx
}),
dead_code_elimination
{},
lowering
{
&
ctx
,
options
.
offload_copy
},
eliminate_contiguous
{
"gpu::contiguous"
},
...
...
@@ -150,7 +159,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination
{},
adjust_allocation
{
gpu_allocation_model
{}},
dead_code_elimination
{},
compile_ops
{
&
ctx
},
compile_ops
{
&
ctx
,
options
.
exhaustive_tune
},
dead_code_elimination
{},
promote_literals
{},
dead_code_elimination
{},
...
...
src/targets/gpu/
driver/perf
.cpp
→
src/targets/gpu/
time_op
.cpp
View file @
40fbef9b
...
...
@@ -21,7 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/
driver/perf
.hpp>
#include <migraphx/gpu/
time_op
.hpp>
#include <migraphx/context.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/time.hpp>
...
...
@@ -30,7 +30,6 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
driver
{
std
::
vector
<
argument
>
generate_arguments
(
const
std
::
vector
<
shape
>&
shapes
,
unsigned
long
seed
=
0
)
{
...
...
@@ -69,7 +68,6 @@ time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n)
return
std
::
make_pair
(
host_time
/
n
,
device_time
/
n
);
}
}
// namespace driver
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/ref/CMakeLists.txt
View file @
40fbef9b
...
...
@@ -37,6 +37,8 @@ target_link_libraries(migraphx_ref PUBLIC migraphx)
target_include_directories
(
migraphx_ref PRIVATE
${
BLAZE_INCLUDE
}
)
target_compile_definitions
(
migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS
)
migraphx_generate_export_header
(
migraphx_ref
)
rocm_install_targets
(
TARGETS migraphx_ref
INCLUDE
...
...
src/targets/ref/include/migraphx/ref/context.hpp
View file @
40fbef9b
...
...
@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_RTGLIB_CONTEXT_HPP
#include <migraphx/config.hpp>
#include <migraphx/ref/export.h>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/targets/ref/include/migraphx/ref/lowering.hpp
View file @
40fbef9b
...
...
@@ -24,14 +24,14 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_CPU_LOWERING_HPP
#define MIGRAPHX_GUARD_RTGLIB_CPU_LOWERING_HPP
#include <migraphx/ref/context.hpp>
#include <migraphx/program.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
ref
{
struct
lowering
struct
MIGRAPHX_REF_EXPORT
lowering
{
std
::
string
name
()
const
{
return
"ref::lowering"
;
}
void
apply
(
module
&
m
)
const
;
...
...
src/targets/ref/include/migraphx/ref/target.hpp
View file @
40fbef9b
...
...
@@ -35,7 +35,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
pass
;
namespace
ref
{
struct
target
struct
MIGRAPHX_REF_EXPORT
target
{
std
::
string
name
()
const
;
std
::
vector
<
pass
>
get_passes
(
migraphx
::
context
&
ctx
,
const
compile_options
&
)
const
;
...
...
src/targets/ref/lowering.cpp
View file @
40fbef9b
...
...
@@ -27,7 +27,7 @@
#include <migraphx/dfor.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/
de
convolution.hpp>
#include <migraphx/op/convolution
_backwards
.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/quant_dot.hpp>
...
...
src/tf/CMakeLists.txt
View file @
40fbef9b
...
...
@@ -42,8 +42,9 @@ target_compile_options(tf-proto PRIVATE -w)
target_link_libraries
(
tf-proto PRIVATE
${
PROTOBUF_LIBRARY
}
)
set_target_properties
(
tf-proto PROPERTIES POSITION_INDEPENDENT_CODE On
)
file
(
GLOB TF_SRCS
${
CONFIGURE_DEPENDS
}
*.cpp
)
file
(
GLOB TF_SRCS CONFIGURE_DEPENDS *.cpp
)
add_library
(
migraphx_tf
${
TF_SRCS
}
)
migraphx_generate_export_header
(
migraphx_tf
)
target_include_directories
(
migraphx_tf PRIVATE include
)
set_target_properties
(
migraphx_tf PROPERTIES EXPORT_NAME tf
)
rocm_set_soversion
(
migraphx_tf
${
MIGRAPHX_SO_VERSION
}
)
...
...
src/tf/op_parser.cpp
View file @
40fbef9b
...
...
@@ -46,6 +46,7 @@ std::vector<std::string> get_op_parsers()
op_parser_map
().
end
(),
std
::
back_inserter
(
result
),
[
&
](
auto
&&
p
)
{
return
p
.
first
;
});
std
::
sort
(
result
.
begin
(),
result
.
end
());
return
result
;
}
...
...
src/tf/parse_batchnorm.cpp
View file @
40fbef9b
...
...
@@ -52,7 +52,6 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
auto
x_type
=
args
[
0
]
->
get_shape
().
type
();
// unsqueeze tensors of shape (C) to broadcast correctly
auto
rt
=
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
x_type
},
{
0.5
}});
auto
eps
=
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
x_type
},
{
epsilon
}});
auto
scale_unsqueeze
=
...
...
@@ -64,11 +63,11 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
auto
var_unsqueeze
=
info
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
,
2
}}}),
args
[
4
]);
auto
numer
=
info
.
add_broadcastable_binary_op
(
"sub"
,
args
[
0
],
mean_unsqueeze
);
auto
var_eps
=
info
.
add_broadcastable_binary_op
(
"add"
,
var_unsqueeze
,
eps
);
auto
denom
=
info
.
add_
broadcastable_binary_op
(
"pow"
,
var_eps
,
rt
);
auto
div0
=
info
.
add_broadcastable_binary_op
(
"
div"
,
numer
,
denom
);
auto
r0
=
info
.
add_broadcastable_binary_op
(
"mul"
,
div0
,
scale_unsqueeze
);
auto
x_sub_mean
=
info
.
add_broadcastable_binary_op
(
"sub"
,
args
[
0
],
mean_unsqueeze
);
auto
var_eps
=
info
.
add_broadcastable_binary_op
(
"add"
,
var_unsqueeze
,
eps
);
auto
rsqrt
=
info
.
add_
instruction
(
make_op
(
"rsqrt"
)
,
var_eps
);
auto
mul0
=
info
.
add_broadcastable_binary_op
(
"
mul"
,
scale_unsqueeze
,
rsqrt
);
auto
r0
=
info
.
add_broadcastable_binary_op
(
"mul"
,
x_sub_mean
,
mul0
);
return
info
.
add_broadcastable_binary_op
(
"add"
,
r0
,
bias_unsqueeze
);
}
};
...
...
Prev
1
…
12
13
14
15
16
17
18
19
20
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment