gaoqiong / MIGraphX / Commits

Commit 5a14c0bf
Authored Oct 19, 2022 by umangyadav

Merge branch 'develop' into workspace_size

Parents: cb01e280, 5fa42993
Changes: 319

Showing 20 changed files with 214 additions and 193 deletions (+214 −193) on this page:
  src/include/migraphx/context.hpp                   +67  −1
  src/include/migraphx/execution_environment.hpp      +7  −8
  src/include/migraphx/match/layernorm.hpp             +2  −2
  src/include/migraphx/op/batch_norm_inference.hpp     +0 −70
  src/include/migraphx/op/common.hpp                   +2  −2
  src/include/migraphx/op/convolution.hpp              +4 −13
  src/include/migraphx/op/fmod.hpp                     +0  −1
  src/include/migraphx/op/mod.hpp                      +1  −1
  src/include/migraphx/op/quant_convolution.hpp        +3  −5
  src/include/migraphx/operators.hpp                   +0  −1
  src/include/migraphx/pad_calc.hpp                   +15 −11
  src/include/migraphx/program.hpp                     +3  −2
  src/include/migraphx/reflect.hpp                    +14  −4
  src/include/migraphx/rewrite_batchnorm.hpp           +0 −48
  src/include/migraphx/streamutils.hpp                +29  −5
  src/include/migraphx/value.hpp                       +6  −0
  src/module.cpp                                       +6  −2
  src/onnx/parse_batchnorm.cpp                        +50 −14
  src/onnx/parse_convolution.cpp                       +0  −2
  src/onnx/parse_deconvolution.cpp                     +5  −1
src/include/migraphx/context.hpp

@@ -66,6 +66,15 @@ any_ptr get_queue_context(T&)
 {
     return {};
 }

+template <class T>
+void wait_for_context(T&, any_ptr)
+{
+}

+template <class T>
+void finish_on_context(T&, any_ptr)
+{
+}

 #ifdef TYPE_ERASED_DECLARATION

@@ -78,6 +87,10 @@ struct context
     void from_value(const value& v);
     // (optional)
     any_ptr get_queue();
+    // (optional)
+    void wait_for(any_ptr queue);
+    // (optional)
+    void finish_on(any_ptr queue);
     //
     void finish() const;
 };

@@ -165,6 +178,18 @@ struct context
         return (*this).private_detail_te_get_handle().get_queue();
     }

+    void wait_for(any_ptr queue)
+    {
+        assert((*this).private_detail_te_handle_mem_var);
+        (*this).private_detail_te_get_handle().wait_for(queue);
+    }

+    void finish_on(any_ptr queue)
+    {
+        assert((*this).private_detail_te_handle_mem_var);
+        (*this).private_detail_te_get_handle().finish_on(queue);
+    }

     void finish() const
     {
         assert((*this).private_detail_te_handle_mem_var);

@@ -187,6 +212,8 @@ struct context
         virtual value to_value() const          = 0;
         virtual void from_value(const value& v) = 0;
         virtual any_ptr get_queue()             = 0;
+        virtual void wait_for(any_ptr queue)    = 0;
+        virtual void finish_on(any_ptr queue)   = 0;
         virtual void finish() const             = 0;
     };

@@ -231,6 +258,33 @@ struct context
         return get_queue_context(private_detail_te_self);
     }

+    template <class T>
+    static auto private_detail_te_default_wait_for(char, T&& private_detail_te_self, any_ptr queue)
+        -> decltype(private_detail_te_self.wait_for(queue))
+    {
+        private_detail_te_self.wait_for(queue);
+    }

+    template <class T>
+    static void private_detail_te_default_wait_for(float, T&& private_detail_te_self, any_ptr queue)
+    {
+        wait_for_context(private_detail_te_self, queue);
+    }

+    template <class T>
+    static auto private_detail_te_default_finish_on(char, T&& private_detail_te_self, any_ptr queue)
+        -> decltype(private_detail_te_self.finish_on(queue))
+    {
+        private_detail_te_self.finish_on(queue);
+    }

+    template <class T>
+    static void private_detail_te_default_finish_on(float, T&& private_detail_te_self, any_ptr queue)
+    {
+        finish_on_context(private_detail_te_self, queue);
+    }

     template <typename PrivateDetailTypeErasedT>
     struct private_detail_te_handle_type : private_detail_te_handle_base_type
     {

@@ -248,7 +302,7 @@ struct context
             PrivateDetailTypeErasedT value,
             typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
                                     int>::type* = nullptr) noexcept
-            : private_detail_te_value(std::move(value))
+            : private_detail_te_value(value)
         {
         }

@@ -277,6 +331,18 @@ struct context
             return private_detail_te_default_get_queue(char(0), private_detail_te_value);
         }

+        void wait_for(any_ptr queue) override
+        {
+            private_detail_te_default_wait_for(char(0), private_detail_te_value, queue);
+        }

+        void finish_on(any_ptr queue) override
+        {
+            private_detail_te_default_finish_on(char(0), private_detail_te_value, queue);
+        }

         void finish() const override { private_detail_te_value.finish(); }

         PrivateDetailTypeErasedT private_detail_te_value;
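Taken together, these hunks add two synchronization hooks, wait_for and finish_on, to the type-erased migraphx::context, with free-function fallbacks (wait_for_context / finish_on_context) for target contexts that do not provide them. The sketch below is not code from the commit; toy_context and its members are hypothetical and only illustrate which members the extended interface looks for.

// Hypothetical example (not part of this commit): a minimal target context
// that opts in to the new queue hooks. Anything it omits falls back to the
// get_queue_context / wait_for_context / finish_on_context defaults above.
#include <migraphx/context.hpp>

struct toy_context
{
    migraphx::value to_value() const { return {}; }
    void from_value(const migraphx::value&) {}
    void finish() const {}                       // required
    migraphx::any_ptr get_queue() { return {}; } // optional
    void wait_for(migraphx::any_ptr) {}          // optional
    void finish_on(migraphx::any_ptr) {}         // optional
};

// migraphx::context ctx = toy_context{};  // the type-erased wrapper dispatches to these hooks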
src/targets/gpu/include/migraphx/gpu/add.hpp → src/include/migraphx/execution_environment.hpp

@@ -21,22 +21,21 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
-#ifndef MIGRAPHX_GUARD_RTGLIB_ADD_HPP
-#define MIGRAPHX_GUARD_RTGLIB_ADD_HPP
+#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP
+#define MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP

-#include <migraphx/gpu/oper.hpp>
-#include <migraphx/gpu/device/add.hpp>
+#include <migraphx/any_ptr.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
-namespace gpu {

-struct hip_add : binary_device<hip_add, device::add>
+struct execution_environment
 {
+    any_ptr queue = any_ptr{};
+    bool async    = false;
 };

-} // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx

-#endif
+#endif /* MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP */
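The renamed header introduces execution_environment, which the new program::eval overload further down this page accepts. A hedged usage sketch, not taken from the commit; prog and params are assumed to be an existing compiled program and its parameter_map:

// Sketch only: run a program with the new execution_environment argument.
migraphx::execution_environment exec_env{};
exec_env.async = true; // request asynchronous execution on the target queue
// exec_env.queue would wrap a target queue handle (e.g. a HIP stream) in any_ptr.
auto results = prog.eval(params, exec_env); // equivalent to prog.eval(params) when defaulted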
src/include/migraphx/match/layernorm.hpp

@@ -50,8 +50,8 @@ struct layernorm_matcher
     {
         return f("div")(arg(0)(x_minus_mean()),
-                        arg(1)(skip_broadcasts(f("sqrt")(arg(0)(
-                            f("add")(either_arg(0, 1)(variance(), has_value(1e-12f))))))));
+                        arg(1)(skip_broadcasts(f("sqrt")(arg(0)(
+                            f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps"))))))));
     }

     auto matcher() const { return layernorm_onnx(); }
src/include/migraphx/op/batch_norm_inference.hpp (deleted, 100644 → 0)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_BATCH_NORM_HPP
#define MIGRAPHX_GUARD_OPERATORS_BATCH_NORM_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <cmath>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct batch_norm_inference
{
    float epsilon  = 1.0e-6f;
    float momentum = 0.9f;

    std::string name() const { return "batch_norm_inference"; }

    enum bn_infer_mode_t
    {
        per_activation,
        spatial,
    };

    bn_infer_mode_t bn_mode = spatial;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(
            f(self.epsilon, "epsilon"), f(self.momentum, "momentum"), f(self.bn_mode, "bn_mode"));
    }

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(5);
        check_shapes{inputs.data(), inputs.data() + 1, *this}.same_ndims();
        check_shapes{inputs.data() + 1, inputs.data() + inputs.size(), *this}.same_shape();
        return inputs.front();
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
src/include/migraphx/op/common.hpp

@@ -33,11 +33,11 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace op {

 // Padding mode is default_ for fixed shape padding.
 // same_lower and same_upper used for dynamic padding.
 enum padding_mode_t
 {
     default_, // NOLINT
     same,
     valid,
     same_lower,
     same_upper
 };
src/include/migraphx/op/convolution.hpp

@@ -41,9 +41,8 @@ struct convolution
     std::vector<std::size_t> stride   = {1, 1};
     std::vector<std::size_t> dilation = {1, 1};
-    int group                      = 1;
-    padding_mode_t padding_mode    = default_;
-    bool use_dynamic_same_auto_pad = false;
+    int group                   = 1;
+    padding_mode_t padding_mode = default_;

     template <class Self, class F>
     static auto reflect(Self& self, F f)

@@ -52,8 +51,7 @@ struct convolution
                     f(self.stride, "stride"),
                     f(self.dilation, "dilation"),
                     f(self.group, "group"),
-                    f(self.padding_mode, "padding_mode"),
-                    f(self.use_dynamic_same_auto_pad, "use_dynamic_same_auto_pad"));
+                    f(self.padding_mode, "padding_mode"));
     }

     std::string name() const { return "convolution"; }

@@ -93,13 +91,6 @@ struct convolution
            x_shape.lens().at(1) != (w_shape.lens().at(1) * group))
             MIGRAPHX_THROW("CONVOLUTION: mismatched channel numbers");

-        std::vector<op::padding_mode_t> dyn_pad_modes = {op::padding_mode_t::same_upper,
-                                                         op::padding_mode_t::same_lower};
-        if(use_dynamic_same_auto_pad and not contains(dyn_pad_modes, padding_mode))
-        {
-            MIGRAPHX_THROW("CONVOLUTION: use_dynamic_same_auto_pad set with invalid padding mode");
-        }
         if(x_shape.dynamic() or w_shape.dynamic())
         {
             return dynamic_compute_shape(x_shape, w_shape);

@@ -161,7 +152,7 @@ struct convolution
         dynamic_shape_push_back(w_shape);
         const size_t num_spatial_dims = x_shape.max_lens().size() - 2;
-        if(use_dynamic_same_auto_pad)
+        if(padding_mode != default_)
         {
             for(std::size_t i = 0; i < num_spatial_dims; ++i)
             {
src/include/migraphx/op/fmod.hpp

@@ -40,7 +40,6 @@ struct fmod : binary<fmod>
         a["commutative"] = false;
         return a;
     }
-    std::string point_function() const { return "fmod"; }
     auto apply() const
     {
         return [](auto x, auto y) { return std::fmod(x, y); };
src/include/migraphx/op/mod.hpp

@@ -38,9 +38,9 @@ struct mod : binary<mod>
     {
         auto a           = base_attributes();
         a["commutative"] = false;
+        a["point_op"] = "${function:fmod}((${function:remainder}(${0}, ${1})) + ${1}, ${1})";
         return a;
     }
-    std::string point_function() const { return "mod"; }
     auto apply() const
     {
         return [](auto x, auto y) { return std::fmod((std::remainder(x, y)) + y, y); };
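The "point_op" string added above spells out the same arithmetic as apply(): fmod(remainder(x, y) + y, y), which yields a result with the sign of the divisor rather than the dividend. A quick stand-alone check of that identity (not from the commit):

// Compare plain fmod with the remainder-based mod used by the operator.
#include <cmath>
#include <cstdio>

int main()
{
    double x  = -7.0;
    double y  = 3.0;
    double fm = std::fmod(x, y);                        // -1, sign follows the dividend
    double md = std::fmod(std::remainder(x, y) + y, y); //  2, sign follows the divisor
    std::printf("fmod(-7,3) = %g, mod(-7,3) = %g\n", fm, md);
    return 0;
}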
src/include/migraphx/op/quant_convolution.hpp

@@ -41,9 +41,8 @@ struct quant_convolution
     std::vector<std::size_t> stride   = {1, 1};
     std::vector<std::size_t> dilation = {1, 1};
-    padding_mode_t padding_mode    = default_;
-    int group                      = 1;
-    bool use_dynamic_same_auto_pad = false;
+    padding_mode_t padding_mode = default_;
+    int group                   = 1;

     template <class Self, class F>
     static auto reflect(Self& self, F f)

@@ -52,8 +51,7 @@ struct quant_convolution
                     f(self.stride, "stride"),
                     f(self.dilation, "dilation"),
                     f(self.padding_mode, "padding_mode"),
-                    f(self.group, "group"),
-                    f(self.use_dynamic_same_auto_pad, "use_dynamic_same_auto_pad"));
+                    f(self.group, "group"));
     }

     value attributes() const
src/include/migraphx/operators.hpp

@@ -35,7 +35,6 @@
 #include <migraphx/op/as_shape.hpp>
 #include <migraphx/op/atan.hpp>
 #include <migraphx/op/atanh.hpp>
-#include <migraphx/op/batch_norm_inference.hpp>
 #include <migraphx/op/binary.hpp>
 #include <migraphx/op/broadcast.hpp>
 #include <migraphx/op/capture.hpp>
src/include/migraphx/pad_calc.hpp

@@ -24,9 +24,10 @@
 #ifndef MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
 #define MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP

-#include <migraphx/config.hpp>
 #include <cstdint>
 #include <vector>
+#include <migraphx/config.hpp>
+#include <migraphx/shape.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

@@ -42,18 +43,21 @@ void calculate_padding(int64_t idx,
 /*!
  * Calculate the padding for auto_padding. Used for dynamic shapes
  * where the padding calculation must be done at evaluation time.
  * \param tensor_lens input tensor image shape
  * \param k_lens weights kernel shape
  * \param strides strides for the kernel
  * \param dilations dilations for the kernel
  * \param use_upper put odd padding on upper or lower side
  * \return padding in the form of {x0_begin, x1_begin, ... x0_end, x1_end, ...}
  */
-std::vector<std::size_t> calc_dyn_auto_pad(std::vector<std::size_t> tensor_lens,
-                                           std::vector<std::size_t> k_lens,
-                                           std::vector<std::size_t> strides,
-                                           std::vector<std::size_t> dilations,
-                                           bool use_upper = true);
+std::vector<std::size_t> calc_dyn_auto_pad(const std::vector<std::size_t>& input_lens,
+                                           const std::vector<std::size_t>& wei_lens,
+                                           const std::vector<std::size_t>& strides,
+                                           const std::vector<std::size_t>& dilations,
+                                           bool use_upper);

+// Used for dynamic auto padding of convolution operators since padding needs to be computed at
+// evaulation time.
+shape compute_padded_shape(const shape& input,
+                           const shape& weights,
+                           const std::vector<std::size_t>& padding,
+                           const std::vector<std::size_t>& stride,
+                           const std::vector<std::size_t>& dilation);

 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
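calc_dyn_auto_pad is only declared here. For reference, the usual ONNX-style "SAME" auto-padding arithmetic such a function has to perform per spatial dimension looks roughly like the sketch below. This is an illustration of the documented return layout {x0_begin, x1_begin, ..., x0_end, x1_end, ...}, not the MIGraphX implementation, and it assumes the length vectors hold spatial dimensions only.

#include <cstddef>
#include <vector>

// Illustrative only: per-dimension SAME auto-padding, with the extra element of
// an odd total placed on the upper (end) side when use_upper is true.
std::vector<std::size_t> same_auto_pad_sketch(const std::vector<std::size_t>& input_lens,
                                              const std::vector<std::size_t>& wei_lens,
                                              const std::vector<std::size_t>& strides,
                                              const std::vector<std::size_t>& dilations,
                                              bool use_upper)
{
    std::size_t n = input_lens.size();
    std::vector<std::size_t> pads(2 * n, 0);
    for(std::size_t i = 0; i < n; ++i)
    {
        std::size_t eff_k = (wei_lens[i] - 1) * dilations[i] + 1;          // dilated kernel extent
        std::size_t out   = (input_lens[i] + strides[i] - 1) / strides[i]; // ceil(in / stride)
        std::size_t need  = (out - 1) * strides[i] + eff_k;
        std::size_t total = need > input_lens[i] ? need - input_lens[i] : 0;
        std::size_t small = total / 2;
        std::size_t large = total - small;
        pads[i]     = use_upper ? small : large; // x_i begin
        pads[i + n] = use_upper ? large : small; // x_i end
    }
    return pads;
}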
src/include/migraphx/program.hpp

@@ -37,6 +37,7 @@
 #include <migraphx/assignment_options.hpp>
 #include <migraphx/env.hpp>
 #include <migraphx/config.hpp>
+#include <migraphx/execution_environment.hpp>
 #include <algorithm>
 #include <iostream>

@@ -76,8 +77,8 @@ struct program
     std::unordered_map<std::string, shape> get_parameter_shapes() const;

-    std::vector<argument> eval(parameter_map params) const;
+    std::vector<argument> eval(parameter_map params,
+                               execution_environment exec_env = execution_environment{}) const;

     std::size_t size() const;

     std::vector<shape> get_output_shapes() const;
src/include/migraphx/reflect.hpp

@@ -56,11 +56,11 @@ auto reflect_impl(rank<0>, T&, Selector)
 }

 template <class T>
-auto reflectable_impl(rank<1>, T&& x)
+auto reflectable_impl(rank<1>, const T& x)
     -> decltype(T::reflect(x, reflect_placeholder{}), std::true_type{});

 template <class T>
-auto reflectable_impl(rank<0>, T&&) -> decltype(std::false_type{});
+auto reflectable_impl(rank<0>, const T&) -> decltype(std::false_type{});

 template <class T>
 struct remove_rvalue_reference

@@ -111,8 +111,18 @@ auto reflect(T& x, Selector f)
 template <class T>
 auto reflect_tie(T& x)
 {
-    return reflect(x, [](auto&& y, auto&&...) {
-        return detail::wrap<decltype(y)>(y);
-    })([](auto&&... xs) { return detail::auto_tuple(xs.get()...); });
+    return reflect(x, [](auto&& y, auto&&...) {
+        // cppcheck-suppress UnnecessaryElseStatement
+        if constexpr(is_reflectable<decltype(y)>{})
+        {
+            auto t = reflect_tie(y);
+            return detail::wrap<decltype(t)>(t);
+        }
+        else
+        {
+            return detail::wrap<decltype(y)>(y);
+        }
+    })([](auto&&... xs) { return detail::auto_tuple(xs.get()...); });
 }

 template <class T, class F>
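The reflect_tie change means members that are themselves reflectable are now tied recursively instead of being wrapped as opaque references. A hypothetical illustration follows; these structs are not in MIGraphX and only show the shape of the recursion.

#include <migraphx/functional.hpp>
#include <migraphx/reflect.hpp>

// Hypothetical nested reflectable types, for illustration only.
struct inner_opts
{
    int k = 3;
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::pack(f(self.k, "k"));
    }
};

struct outer_op
{
    inner_opts opts = {};
    float scale     = 1.0f;
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::pack(f(self.opts, "opts"), f(self.scale, "scale"));
    }
};

// With the change above, migraphx::reflect_tie on an outer_op conceptually
// produces a tuple of (tuple of k, scale) rather than (inner_opts&, scale),
// so comparisons and hashes built on reflect_tie see the nested fields.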
src/include/migraphx/rewrite_batchnorm.hpp (deleted, 100644 → 0)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#define MIGRAPHX_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module;

/**
 * Rewrite batchnorm to a multiply and add.
 */
struct rewrite_batchnorm
{
    std::string name() const { return "rewrite_batchnorm"; }
    void apply(module& m) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
src/include/migraphx/streamutils.hpp

@@ -26,8 +26,11 @@
 #include <ostream>
 #include <algorithm>
 #include <migraphx/reflect.hpp>
+#include <migraphx/rank.hpp>
+#include <migraphx/requires.hpp>
 #include <migraphx/config.hpp>
+#include <vector>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

@@ -59,10 +62,22 @@ inline stream_range_container<Range> stream_range(const Range& r)
 namespace detail {

+inline void stream_write_value_impl(rank<2>, std::ostream& os, const std::string& x) { os << x; }

+template <class T>
+auto stream_write_value_impl(rank<1>, std::ostream& os, const T& x) -> decltype(os << x, void())
+{
+    os << x;
+}

+template <class T>
+void stream_write_value_impl(rank<1>, std::ostream& os, const std::vector<T>& r)
+{
+    os << "{";
+    os << stream_range(r);
+    os << "}";
+}

 template <class Range>
-auto stream_write_value_impl(rank<1>, std::ostream& os, const Range& r)
+auto stream_write_value_impl(rank<0>, std::ostream& os, const Range& r)
     -> decltype(r.begin(), r.end(), void())
 {
     os << "{";

@@ -70,17 +85,26 @@ auto stream_write_value_impl(rank<1>, std::ostream& os, const Range& r)
     os << "}";
 }

-template <class T>
+template <class T, MIGRAPHX_REQUIRES(is_reflectable<T>{})>
 void stream_write_value_impl(rank<0>, std::ostream& os, const T& x)
 {
-    os << x;
+    char delim = '{';
+    reflect_each(x, [&](auto&& y, auto name) {
+        os << delim;
+        os << name << "=";
+        stream_write_value_impl(rank<2>{}, os, y);
+        delim = ',';
+    });
+    if(delim == ',')
+        os << "}";
 }

 } // namespace detail

 template <class T>
 void stream_write_value(std::ostream& os, const T& x)
 {
-    detail::stream_write_value_impl(rank<2>{}, os, x);
+    detail::stream_write_value_impl(rank<1>{}, os, x);
 }

 } // namespace MIGRAPHX_INLINE_NS
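With the reflect-aware overload in place, a reflectable type can be streamed without writing its own operator<<. A hedged sketch of what that enables; toy_params is hypothetical and the exact output formatting depends on the overloads above.

#include <sstream>
#include <vector>
#include <migraphx/functional.hpp>
#include <migraphx/streamutils.hpp>

// Hypothetical reflectable parameter struct, for illustration only.
struct toy_params
{
    int group                       = 1;
    std::vector<std::size_t> stride = {1, 1};
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::pack(f(self.group, "group"), f(self.stride, "stride"));
    }
};

// std::ostringstream ss;
// migraphx::stream_write_value(ss, toy_params{});
// ss.str() would read roughly like "{group=1,stride={1, 1}}".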
src/include/migraphx/value.hpp

@@ -184,6 +184,12 @@ struct value
         {
         }
         explicit binary(std::size_t s) : base(s) {}

+        friend std::ostream& operator<<(std::ostream& os, const binary& obj)
+        {
+            os << "{binary_object: " << obj.size() << "}";
+            return os;
+        }
     };

     value() = default;
src/module.cpp

@@ -385,9 +385,13 @@ instruction_ref module::move_instruction(instruction_ref src, instruction_ref dst)
 instruction_ref module::move_instructions(instruction_ref src, instruction_ref dst)
 {
-    this->move_instruction(src, dst);
     for(auto ins : src->inputs())
-        this->move_instruction(ins, src);
+    {
+        if(not contains(this->impl->instructions, ins))
+            continue;
+        this->move_instructions(ins, dst);
+    }
+    this->move_instruction(src, dst);
     return src;
 }
src/onnx/parse_batchnorm.cpp

@@ -24,7 +24,7 @@
 #include <migraphx/onnx/op_parser.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/make_op.hpp>
-#include <migraphx/op/batch_norm_inference.hpp>
+#include <migraphx/instruction.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

@@ -36,28 +36,64 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
     instruction_ref parse(const op_desc& /*opd*/,
                           const onnx_parser& parser,
-                          onnx_parser::node_info info,
-                          const std::vector<instruction_ref>& args) const
+                          const onnx_parser::node_info& info,
+                          std::vector<instruction_ref> args) const
     {
-        float epsilon  = 1e-5f;
-        float momentum = 0.9f;
-        op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
+        float epsilon = 1e-5f;
         if(contains(info.attributes, "epsilon"))
         {
             epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
         }
-        if(contains(info.attributes, "momentum"))
-        {
-            momentum = parser.parse_value(info.attributes.at("momentum")).at<float>();
-        }
-        if(contains(info.attributes, "spatial"))
-        {
-            bn_mode = (parser.parse_value(info.attributes.at("spatial")).at<uint64_t>() > 0)
-                          ? op::batch_norm_inference::spatial
-                          : op::batch_norm_inference::per_activation;
-        }
-        op::batch_norm_inference op{epsilon, momentum, bn_mode};
-        return info.add_instruction(op, args);
+        auto x_lens = args[0]->get_shape().lens();
+        auto x_type = args[0]->get_shape().type();
+        if(std::any_of(args.cbegin() + 1, args.cend(), [](auto a) {
+               return a->get_shape().lens().size() != 1;
+           }))
+        {
+            MIGRAPHX_THROW("PARSE_BATCHNORM: argument scale, bias, mean, or var rank != 1");
+        }
+        auto x_rank = x_lens.size();
+        if(x_rank == 1 or x_rank == 2)
+        {
+            auto rt      = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
+            auto eps     = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
+            auto numer   = info.add_broadcastable_binary_op("sub", args[0], args[3]);
+            auto var_eps = info.add_broadcastable_binary_op("add", args[4], eps);
+            auto denom   = info.add_broadcastable_binary_op("pow", var_eps, rt);
+            auto div0    = info.add_broadcastable_binary_op("div", numer, denom);
+            auto r0      = info.add_broadcastable_binary_op("mul", div0, args[1]);
+            return info.add_broadcastable_binary_op("add", r0, args[2]);
+        }
+        else if(x_rank > 2)
+        {
+            // unsqueeze tensors of shape (C) to broadcast correctly
+            std::vector<int64_t> unsqueeze_axes(x_lens.size() - 2);
+            std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 1);
+            auto rt  = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
+            auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
+            auto scale_unsqueeze = info.add_instruction(
+                migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[1]);
+            auto bias_unsqueeze = info.add_instruction(
+                migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[2]);
+            auto mean_unsqueeze = info.add_instruction(
+                migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[3]);
+            auto var_unsqueeze = info.add_instruction(
+                migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[4]);
+            auto numer   = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze);
+            auto var_eps = info.add_broadcastable_binary_op("add", var_unsqueeze, eps);
+            auto denom   = info.add_broadcastable_binary_op("pow", var_eps, rt);
+            auto div0    = info.add_broadcastable_binary_op("div", numer, denom);
+            auto r0      = info.add_broadcastable_binary_op("mul", div0, scale_unsqueeze);
+            return info.add_broadcastable_binary_op("add", r0, bias_unsqueeze);
+        }
+        else
+        {
+            // rank == 0
+            MIGRAPHX_THROW("PARSE_BATCHNORM: rank " + std::to_string(x_lens.size()) +
+                           " input tensor, unhandled data format");
+        }
     }
 };
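The parser now lowers BatchNormalization directly to elementwise instructions instead of a dedicated batch_norm_inference op, following y = (x - mean) / (var + epsilon)^0.5 * scale + bias, with the per-channel tensors unsqueezed so they broadcast over the spatial dimensions. A quick numeric sanity check of that formula, not taken from the commit:

#include <cmath>
#include <cstdio>

int main()
{
    // One element of the decomposition the parser emits:
    // y = (x - mean) / sqrt(var + epsilon) * scale + bias
    float x = 2.0f, scale = 1.5f, bias = 0.25f, mean = 1.0f, var = 4.0f, epsilon = 1e-5f;
    float y = (x - mean) / std::sqrt(var + epsilon) * scale + bias;
    std::printf("y = %f\n", y); // ~1.0 for these values
    return 0;
}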
src/onnx/parse_convolution.cpp

@@ -125,11 +125,9 @@ struct parse_convolution : op_parser<parse_convolution>
             values["padding_mode"] = is_same_upper ? to_value(op::padding_mode_t::same_upper)
                                                    : to_value(op::padding_mode_t::same_lower);
-            values["use_dynamic_same_auto_pad"] = true;
         }
         else
         {
-            values["padding_mode"] = to_value(op::padding_mode_t::same);
             // kernel shape will be fixed, so max_lens() == min_len() for kernel lengths
             auto weight_lens = weights->get_shape().max_lens();
             std::vector<std::size_t> k_lens(weight_lens.begin() + 2, weight_lens.end());
src/onnx/parse_deconvolution.cpp

@@ -95,6 +95,8 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
             check_attr_sizes(
                 kdims, values["dilation"].size(), "PARSE_CONV_TRANSPOSE: inconsistent dilations");
         }

+        // TODO: auto padding needs to be implemented for this parser and operator
         if(contains(info.attributes, "auto_pad"))
         {
             auto s = info.attributes["auto_pad"].s();

@@ -106,7 +108,9 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
             if(s.find("SAME") != std::string::npos)
             {
-                values["padding_mode"] = to_value(op::padding_mode_t::same);
+                bool is_same_upper      = (s.find("SAME_UPPER") != std::string::npos);
+                values["padding_mode"] = is_same_upper ? to_value(op::padding_mode_t::same_upper)
+                                                       : to_value(op::padding_mode_t::same_lower);
             }
         }