Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
cd4ab535
Commit
cd4ab535
authored
Jun 20, 2023
by
Khalique Ahmed
Browse files
manual merge
parents
3891ee58
a0fa3742
Changes
279
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
286 additions
and
100 deletions
+286
-100
src/include/migraphx/hash.hpp
src/include/migraphx/hash.hpp
+47
-0
src/include/migraphx/instruction.hpp
src/include/migraphx/instruction.hpp
+5
-1
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+87
-34
src/include/migraphx/module.hpp
src/include/migraphx/module.hpp
+2
-0
src/include/migraphx/msgpack.hpp
src/include/migraphx/msgpack.hpp
+2
-0
src/include/migraphx/onnx.hpp
src/include/migraphx/onnx.hpp
+1
-1
src/include/migraphx/op/argmax.hpp
src/include/migraphx/op/argmax.hpp
+1
-1
src/include/migraphx/op/broadcast.hpp
src/include/migraphx/op/broadcast.hpp
+10
-4
src/include/migraphx/op/concat.hpp
src/include/migraphx/op/concat.hpp
+1
-1
src/include/migraphx/op/contiguous.hpp
src/include/migraphx/op/contiguous.hpp
+1
-1
src/include/migraphx/op/convert.hpp
src/include/migraphx/op/convert.hpp
+11
-1
src/include/migraphx/op/convolution.hpp
src/include/migraphx/op/convolution.hpp
+50
-7
src/include/migraphx/op/dequantizelinear.hpp
src/include/migraphx/op/dequantizelinear.hpp
+14
-1
src/include/migraphx/op/flatten.hpp
src/include/migraphx/op/flatten.hpp
+4
-8
src/include/migraphx/op/gathernd.hpp
src/include/migraphx/op/gathernd.hpp
+1
-1
src/include/migraphx/op/multibroadcast.hpp
src/include/migraphx/op/multibroadcast.hpp
+14
-10
src/include/migraphx/op/nonmaxsuppression.hpp
src/include/migraphx/op/nonmaxsuppression.hpp
+2
-2
src/include/migraphx/op/pointwise.hpp
src/include/migraphx/op/pointwise.hpp
+5
-4
src/include/migraphx/op/pooling.hpp
src/include/migraphx/op/pooling.hpp
+15
-23
src/include/migraphx/op/quant_convolution.hpp
src/include/migraphx/op/quant_convolution.hpp
+13
-0
No files found.
src/include/migraphx/hash.hpp
0 → 100644
View file @
cd4ab535
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
#include <migraphx/config.hpp>
#include <functional>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
T
>
std
::
size_t
hash_value
(
const
T
&
v
)
{
return
std
::
hash
<
T
>
{}(
v
);
}
template
<
class
T
>
void
hash_combine
(
std
::
size_t
&
seed
,
const
T
&
v
)
{
seed
^=
hash_value
(
v
)
+
0x9e3779b9
+
(
seed
<<
6u
)
+
(
seed
>>
2u
);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
src/include/migraphx/instruction.hpp
View file @
cd4ab535
...
@@ -136,6 +136,9 @@ struct instruction
...
@@ -136,6 +136,9 @@ struct instruction
operation
normalized_operator
()
const
;
operation
normalized_operator
()
const
;
std
::
size_t
get_target_id
()
const
;
void
set_target_id
(
std
::
size_t
tid
);
void
debug_print
()
const
;
void
debug_print
()
const
;
static
void
print
(
std
::
ostream
&
os
,
static
void
print
(
std
::
ostream
&
os
,
...
@@ -172,7 +175,8 @@ struct instruction
...
@@ -172,7 +175,8 @@ struct instruction
std
::
vector
<
instruction_ref
>
arguments
;
std
::
vector
<
instruction_ref
>
arguments
;
std
::
vector
<
module_ref
>
module_args
;
std
::
vector
<
module_ref
>
module_args
;
literal
lit
;
literal
lit
;
bool
normalized
=
false
;
bool
normalized
=
false
;
std
::
size_t
target_id
=
0
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/matcher.hpp
View file @
cd4ab535
...
@@ -35,6 +35,10 @@
...
@@ -35,6 +35,10 @@
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#ifndef MIGRAPHX_USE_TYPE_ERASED_MATCHERS
#define MIGRAPHX_USE_TYPE_ERASED_MATCHERS 0
#endif
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -103,6 +107,13 @@ struct predicate_matcher
...
@@ -103,6 +107,13 @@ struct predicate_matcher
}
}
};
};
/// Convert a predicate function into a matcher
template
<
class
P
>
predicate_matcher
<
P
>
make_predicate_matcher
(
P
p
)
{
return
{
p
};
}
/// Convert a function into a matcher
/// Convert a function into a matcher
template
<
class
F
>
template
<
class
F
>
struct
function_matcher
struct
function_matcher
...
@@ -124,14 +135,14 @@ template <class M>
...
@@ -124,14 +135,14 @@ template <class M>
auto
bind_match
(
M
m
,
std
::
string
name
)
auto
bind_match
(
M
m
,
std
::
string
name
)
{
{
return
make_function_matcher
(
return
make_function_matcher
(
[
=
,
name
=
std
::
move
(
name
)](
matcher_context
&
ctx
,
[
=
,
m_
name
=
std
::
move
(
name
)](
matcher_context
&
ctx
,
instruction_ref
ins
)
->
optional
<
instruction_ref
>
{
instruction_ref
ins
)
->
optional
<
instruction_ref
>
{
auto
result
=
m
.
match
(
ctx
,
ins
);
auto
result
=
m
.
match
(
ctx
,
ins
);
if
(
result
)
if
(
result
)
{
{
if
(
not
ctx
.
has_instruction
(
ins
))
if
(
not
ctx
.
has_instruction
(
ins
))
return
nullopt
;
return
nullopt
;
ctx
.
instructions
[
name
]
=
ins
;
ctx
.
instructions
[
m_
name
]
=
ins
;
}
}
return
result
;
return
result
;
});
});
...
@@ -183,14 +194,26 @@ struct id_matcher
...
@@ -183,14 +194,26 @@ struct id_matcher
template
<
class
M
>
template
<
class
M
>
struct
basic_matcher
;
struct
basic_matcher
;
struct
any_matcher
;
template
<
class
M
>
struct
type_erased_matcher
{
#if MIGRAPHX_USE_TYPE_ERASED_MATCHERS
using
type
=
any_matcher
;
#else
using
type
=
basic_matcher
<
M
>
;
#endif
};
template
<
class
M
>
template
<
class
M
>
basic
_matcher
<
M
>
make_basic_matcher
(
M
m
);
typename
type_erased
_matcher
<
M
>
::
type
make_basic_matcher
(
M
m
);
template
<
class
F
>
template
<
class
F
>
basic_matcher
<
function_matcher
<
F
>>
make_basic_fun_matcher
(
F
f
);
auto
make_basic_fun_matcher
(
F
f
);
template
<
class
P
>
template
<
class
P
>
basic_matcher
<
predicate_matcher
<
P
>>
make_basic_pred_matcher
(
P
p
);
auto
make_basic_pred_matcher
(
P
p
);
/// The basic matcher provides the all_of composability of the matcher
/// The basic matcher provides the all_of composability of the matcher
template
<
class
M
>
template
<
class
M
>
...
@@ -222,38 +245,38 @@ struct basic_matcher
...
@@ -222,38 +245,38 @@ struct basic_matcher
auto
match
(
matcher_context
&
ctx
,
instruction_ref
ins
)
const
{
return
m
.
match
(
ctx
,
ins
);
}
auto
match
(
matcher_context
&
ctx
,
instruction_ref
ins
)
const
{
return
m
.
match
(
ctx
,
ins
);
}
};
};
/// Create a typed-erased matcher
using
any_matcher_base
=
basic_matcher
<
function_matcher
<
std
::
function
<
optional
<
instruction_ref
>
(
matcher_context
&
,
instruction_ref
)
>>>
;
struct
any_matcher
:
any_matcher_base
{
template
<
class
M
>
any_matcher
(
M
mm
)
:
any_matcher_base
({[
=
](
auto
&
ctx
,
auto
ins
)
{
return
mm
.
match
(
ctx
,
ins
);
}})
{
}
};
/// Create a basic matcher from a matcher
/// Create a basic matcher from a matcher
template
<
class
M
>
template
<
class
M
>
basic
_matcher
<
M
>
make_basic_matcher
(
M
m
)
typename
type_erased
_matcher
<
M
>
::
type
make_basic_matcher
(
M
m
)
{
{
return
{
m
};
return
{
m
};
}
}
/// Create a basic matcher from a function
/// Create a basic matcher from a function
template
<
class
F
>
template
<
class
F
>
basic_matcher
<
function_matcher
<
F
>>
make_basic_fun_matcher
(
F
f
)
auto
make_basic_fun_matcher
(
F
f
)
{
{
return
{{
f
}}
;
return
make_basic_matcher
(
make_function_matcher
(
f
))
;
}
}
/// Create a basic matcher from a predicate function
/// Create a basic matcher from a predicate function
template
<
class
P
>
template
<
class
P
>
basic_matcher
<
predicate_matcher
<
P
>>
make_basic_pred_matcher
(
P
p
)
auto
make_basic_pred_matcher
(
P
p
)
{
{
return
{{
p
}}
;
return
make_basic_matcher
(
make_predicate_matcher
(
p
))
;
}
}
/// Create a typed-erased matcher
using
any_matcher_base
=
basic_matcher
<
function_matcher
<
std
::
function
<
optional
<
instruction_ref
>
(
matcher_context
&
,
instruction_ref
)
>>>
;
struct
any_matcher
:
any_matcher_base
{
template
<
class
M
>
any_matcher
(
M
mm
)
:
any_matcher_base
({[
=
](
auto
&
ctx
,
auto
ins
)
{
return
mm
.
match
(
ctx
,
ins
);
}})
{
}
};
/// This macro takes care of the boilerplate for defining a matcher
/// This macro takes care of the boilerplate for defining a matcher
#define MIGRAPHX_BASIC_MATCHER(name, ...) \
#define MIGRAPHX_BASIC_MATCHER(name, ...) \
struct name##_m \
struct name##_m \
...
@@ -347,31 +370,49 @@ match::matcher_result find_match(module& modl, M&& m)
...
@@ -347,31 +370,49 @@ match::matcher_result find_match(module& modl, M&& m)
}
}
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_MATCHES
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_MATCHES
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_VALIDATE_MATCHES
)
/// Find matches for an instruction in the module
/// Find matches for an instruction in the module
for per section of matchers
template
<
class
Mod
,
class
...
Ms
>
template
<
class
Mod
,
class
...
Ms
>
void
find_matches
(
Mod
&
mod
,
instruction_ref
ins
,
Ms
&&
...
ms
)
void
find_matches
(
size_t
trace_pass
,
Mod
&
mod
,
instruction_ref
ins
,
Ms
&&
...
ms
)
{
{
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
const
#endif
#endif
int
trace
=
value_of
(
MIGRAPHX_TRACE_MATCHES
{});
int
trace
=
value_of
(
MIGRAPHX_TRACE_MATCHES
{});
bool
match
=
false
;
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
#endif
bool
validate
=
enabled
(
MIGRAPHX_VALIDATE_MATCHES
{});
bool
match
=
false
;
each_args
(
each_args
(
[
&
](
auto
&&
m
)
{
[
&
](
auto
&&
m
)
{
if
(
match
)
if
(
match
)
return
;
return
;
if
(
trace
>
1
)
if
(
trace
>
1
or
trace_pass
>
1
)
std
::
cout
<<
"Match: "
<<
get_type_name
(
m
)
<<
std
::
endl
;
std
::
cout
<<
"Match: "
<<
get_type_name
(
m
)
<<
std
::
endl
;
auto
r
=
match_instruction
(
get_module
(
mod
),
ins
,
m
.
matcher
());
auto
r
=
match_instruction
(
get_module
(
mod
),
ins
,
m
.
matcher
());
if
(
r
.
result
==
get_module
(
mod
).
end
())
if
(
r
.
result
==
get_module
(
mod
).
end
())
return
;
return
;
if
(
trace
>
0
)
if
(
trace
>
0
or
trace_pass
>
0
)
{
{
std
::
cout
<<
"Matched by "
<<
get_type_name
(
m
)
<<
std
::
endl
;
std
::
cout
<<
"Matched by "
<<
get_type_name
(
m
)
<<
std
::
endl
;
get_module
(
mod
).
debug_print
(
ins
);
get_module
(
mod
).
debug_print
(
ins
);
}
}
// If its already invalid dont validate it again
bool
invalidated
=
validate
and
get_module
(
mod
).
validate
()
!=
get_module
(
mod
).
end
();
m
.
apply
(
mod
,
r
);
m
.
apply
(
mod
,
r
);
if
(
validate
and
not
invalidated
)
{
auto
invalid
=
get_module
(
mod
).
validate
();
if
(
invalid
!=
get_module
(
mod
).
end
())
{
std
::
cout
<<
"Invalid program from match: "
<<
get_type_name
(
m
)
<<
std
::
endl
;
std
::
cout
<<
"Invalid instructions: "
<<
std
::
endl
;
get_module
(
mod
).
debug_print
(
invalid
->
inputs
());
get_module
(
mod
).
debug_print
(
invalid
);
}
}
match
=
true
;
match
=
true
;
},
},
ms
...);
ms
...);
...
@@ -383,7 +424,17 @@ void find_matches(Mod& mod, Ms&&... ms)
...
@@ -383,7 +424,17 @@ void find_matches(Mod& mod, Ms&&... ms)
{
{
for
(
auto
ins
:
iterator_for
(
get_module
(
mod
)))
for
(
auto
ins
:
iterator_for
(
get_module
(
mod
)))
{
{
find_matches
(
mod
,
ins
,
ms
...);
find_matches
(
0
,
mod
,
ins
,
ms
...);
}
}
/// Find matches in a pass
template
<
class
Mod
,
class
...
Ms
>
void
find_matches
(
size_t
trace_pass
,
Mod
&
mod
,
Ms
&&
...
ms
)
{
for
(
auto
ins
:
iterator_for
(
get_module
(
mod
)))
{
find_matches
(
trace_pass
,
mod
,
ins
,
ms
...);
}
}
}
}
...
@@ -520,6 +571,8 @@ MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins)
...
@@ -520,6 +571,8 @@ MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins)
{
{
return
not
ins
->
get_shape
().
standard
();
return
not
ins
->
get_shape
().
standard
();
}
}
MIGRAPHX_PRED_MATCHER
(
dynamic_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
dynamic
();
}
MIGRAPHX_PRED_MATCHER
(
static_shape
,
instruction_ref
ins
)
{
return
not
ins
->
get_shape
().
dynamic
();
}
MIGRAPHX_PRED_MATCHER
(
broadcast_shape
,
instruction_ref
ins
)
MIGRAPHX_PRED_MATCHER
(
broadcast_shape
,
instruction_ref
ins
)
{
{
return
ins
->
get_shape
().
broadcasted
();
return
ins
->
get_shape
().
broadcasted
();
...
@@ -612,9 +665,9 @@ auto skip_output(Ms... ms)
...
@@ -612,9 +665,9 @@ auto skip_output(Ms... ms)
inline
auto
var
(
std
::
string
s
)
inline
auto
var
(
std
::
string
s
)
{
{
return
make_basic_fun_matcher
(
return
make_basic_fun_matcher
(
[
=
,
s
=
std
::
move
(
s
)](
const
matcher_context
&
ctx
,
[
=
,
m_
s
=
std
::
move
(
s
)](
const
matcher_context
&
ctx
,
instruction_ref
)
->
optional
<
instruction_ref
>
{
instruction_ref
)
->
optional
<
instruction_ref
>
{
auto
it
=
ctx
.
instructions
.
find
(
s
);
auto
it
=
ctx
.
instructions
.
find
(
m_
s
);
if
(
it
==
ctx
.
instructions
.
end
())
if
(
it
==
ctx
.
instructions
.
end
())
return
nullopt
;
return
nullopt
;
return
it
->
second
;
return
it
->
second
;
...
@@ -624,7 +677,7 @@ inline auto var(std::string s)
...
@@ -624,7 +677,7 @@ inline auto var(std::string s)
inline
auto
name
(
std
::
string
s
)
inline
auto
name
(
std
::
string
s
)
{
{
return
make_basic_pred_matcher
(
return
make_basic_pred_matcher
(
[
=
,
s
=
std
::
move
(
s
)](
instruction_ref
ins
)
{
return
ins
->
name
()
==
s
;
});
[
=
,
m_
s
=
std
::
move
(
s
)](
instruction_ref
ins
)
{
return
ins
->
name
()
==
m_
s
;
});
}
}
inline
auto
name_contains
(
const
std
::
string
&
name
)
inline
auto
name_contains
(
const
std
::
string
&
name
)
...
@@ -635,8 +688,8 @@ inline auto name_contains(const std::string& name)
...
@@ -635,8 +688,8 @@ inline auto name_contains(const std::string& name)
inline
auto
name
(
std
::
unordered_set
<
std
::
string
>
names
)
inline
auto
name
(
std
::
unordered_set
<
std
::
string
>
names
)
{
{
return
make_basic_pred_matcher
([
=
,
names
=
std
::
move
(
names
)](
instruction_ref
ins
)
{
return
make_basic_pred_matcher
([
=
,
m_
names
=
std
::
move
(
names
)](
instruction_ref
ins
)
{
return
names
.
count
(
ins
->
name
())
>
0
;
return
m_
names
.
count
(
ins
->
name
())
>
0
;
});
});
}
}
...
...
src/include/migraphx/module.hpp
View file @
cd4ab535
...
@@ -178,6 +178,8 @@ struct module
...
@@ -178,6 +178,8 @@ struct module
bool
has_instruction
(
instruction_ref
ins
)
const
;
bool
has_instruction
(
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
get_returns
()
const
;
std
::
size_t
size
()
const
;
std
::
size_t
size
()
const
;
instruction_ref
begin
()
const
;
instruction_ref
begin
()
const
;
instruction_ref
end
()
const
;
instruction_ref
end
()
const
;
...
...
src/include/migraphx/msgpack.hpp
View file @
cd4ab535
...
@@ -26,10 +26,12 @@
...
@@ -26,10 +26,12 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <functional>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
to_msgpack
(
const
value
&
v
,
std
::
function
<
void
(
const
char
*
,
std
::
size_t
)
>
writer
);
std
::
vector
<
char
>
to_msgpack
(
const
value
&
v
);
std
::
vector
<
char
>
to_msgpack
(
const
value
&
v
);
value
from_msgpack
(
const
std
::
vector
<
char
>&
buffer
);
value
from_msgpack
(
const
std
::
vector
<
char
>&
buffer
);
value
from_msgpack
(
const
char
*
buffer
,
std
::
size_t
size
);
value
from_msgpack
(
const
char
*
buffer
,
std
::
size_t
size
);
...
...
src/include/migraphx/onnx.hpp
View file @
cd4ab535
...
@@ -37,7 +37,7 @@ struct onnx_options
...
@@ -37,7 +37,7 @@ struct onnx_options
std
::
size_t
default_dim_value
=
0
;
std
::
size_t
default_dim_value
=
0
;
/// Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value set
/// Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value set
/// parser throws)
/// parser throws)
shape
::
dynamic_dimension
default_dyn_dim_value
=
{
1
,
1
,
0
};
shape
::
dynamic_dimension
default_dyn_dim_value
=
{
1
,
1
};
/// Explicitly specify the dims of an input
/// Explicitly specify the dims of an input
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
=
{};
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
=
{};
/// Explicitly specify dynamic dims of an input (if both map_input_dims and map_dyn_input_dims
/// Explicitly specify dynamic dims of an input (if both map_input_dims and map_dyn_input_dims
...
...
src/include/migraphx/op/argmax.hpp
View file @
cd4ab535
...
@@ -62,7 +62,7 @@ struct argmax
...
@@ -62,7 +62,7 @@ struct argmax
if
(
s0
.
dynamic
())
if
(
s0
.
dynamic
())
{
{
auto
dyn_dims
=
s0
.
dyn_dims
();
auto
dyn_dims
=
s0
.
dyn_dims
();
dyn_dims
[
axis
]
=
{
1
,
1
,
0
};
dyn_dims
[
axis
]
=
{
1
,
1
};
return
{
shape
::
int64_type
,
dyn_dims
};
return
{
shape
::
int64_type
,
dyn_dims
};
}
}
else
else
...
...
src/include/migraphx/op/broadcast.hpp
View file @
cd4ab535
...
@@ -37,10 +37,13 @@ namespace op {
...
@@ -37,10 +37,13 @@ namespace op {
* 1 input version:
* 1 input version:
* Broadcasts a tensor from the original shape to the broadcast_lens by setting the stride of
* Broadcasts a tensor from the original shape to the broadcast_lens by setting the stride of
* broadcasted dimensions to zero. `axis` attribute for a 1D input shape is the output dimension
* broadcasted dimensions to zero. `axis` attribute for a 1D input shape is the output dimension
* that stays the same. ex: broadcasting shape [1024] -> [4, 1024, 3] has axis = 1 For higher rank
* that stays the same.
* input shapes, axis is an offset parameter for the broadcasting. Such that this operator would
* ex: broadcasting shape [1024] -> [4, 1024, 3] has axis = 1.
* work in the opposite direction of NumPy broadcasting. ex: broadcasting shape [2, 2] -> [2, 2, 3]
*
* with axis = 0
* For higher rank input shapes, axis is an offset parameter for the broadcasting.
* Such that this operator would work in the opposite direction of NumPy broadcasting
* (left-most to rightwards element-wise comparison)
* ex: broadcasting shape [2, 2] -> [2, 2, 3] with axis = 0
*
*
* 2 input version:
* 2 input version:
* Broadcast the first input 1D shape into the second input shape based on the axis parameter.
* Broadcast the first input 1D shape into the second input shape based on the axis parameter.
...
@@ -68,6 +71,9 @@ struct broadcast
...
@@ -68,6 +71,9 @@ struct broadcast
{
{
// the ONNX broadcast op is deprecated now, so not handling the negative
// the ONNX broadcast op is deprecated now, so not handling the negative
// value of axis anymore
// value of axis anymore
if
(
s0
.
dynamic
())
MIGRAPHX_THROW
(
"BROADCAST: Single dynamic input shape not supported. Use two inputs."
);
if
(
axis
>=
broadcast_lens
.
size
())
if
(
axis
>=
broadcast_lens
.
size
())
{
{
MIGRAPHX_THROW
(
"BROADCAST : axis "
+
migraphx
::
to_string
(
axis
)
+
MIGRAPHX_THROW
(
"BROADCAST : axis "
+
migraphx
::
to_string
(
axis
)
+
...
...
src/include/migraphx/op/concat.hpp
View file @
cd4ab535
...
@@ -134,7 +134,7 @@ struct concat
...
@@ -134,7 +134,7 @@ struct concat
}
}
auto
new_dims
=
inputs
[
0
].
dyn_dims
();
auto
new_dims
=
inputs
[
0
].
dyn_dims
();
new_dims
[
axis
]
=
migraphx
::
shape
::
dynamic_dimension
{
new_min
,
new_max
,
0
};
new_dims
[
axis
]
=
migraphx
::
shape
::
dynamic_dimension
{
new_min
,
new_max
};
return
{
inputs
[
0
].
type
(),
new_dims
};
return
{
inputs
[
0
].
type
(),
new_dims
};
}
}
else
else
...
...
src/include/migraphx/op/contiguous.hpp
View file @
cd4ab535
...
@@ -48,7 +48,7 @@ struct contiguous
...
@@ -48,7 +48,7 @@ struct contiguous
{
{
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
s0
=
inputs
.
front
();
auto
s0
=
inputs
.
front
();
if
(
s0
.
dynamic
()
or
s0
.
standard
()
)
if
(
s0
.
dynamic
())
{
{
return
s0
;
return
s0
;
}
}
...
...
src/include/migraphx/op/convert.hpp
View file @
cd4ab535
...
@@ -66,7 +66,17 @@ struct convert : unary<convert>
...
@@ -66,7 +66,17 @@ struct convert : unary<convert>
auto
type
=
target_type
;
auto
type
=
target_type
;
return
[
type
](
auto
x
)
{
return
[
type
](
auto
x
)
{
auto
y
=
x
;
auto
y
=
x
;
shape
::
visit
(
type
,
[
&
](
auto
as
)
{
y
=
std
::
min
(
std
::
max
(
as
(
x
),
as
.
min
()),
as
.
max
());
});
shape
::
visit
(
type
,
[
&
](
auto
as
)
{
// clamping value between target_type's max and min doesn't work for NaNs,
if
(
std
::
isnan
(
x
))
{
y
=
as
.
nan
();
}
else
{
y
=
std
::
min
(
std
::
max
(
as
(
x
),
as
.
min
()),
as
.
max
());
}
});
return
y
;
return
y
;
};
};
}
}
...
...
src/include/migraphx/op/convolution.hpp
View file @
cd4ab535
...
@@ -24,9 +24,12 @@
...
@@ -24,9 +24,12 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_CONVOLUTION_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_CONVOLUTION_HPP
#define MIGRAPHX_GUARD_OPERATORS_CONVOLUTION_HPP
#define MIGRAPHX_GUARD_OPERATORS_CONVOLUTION_HPP
#include <migraphx/argument.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/convolution.hpp>
#include <migraphx/pad_calc.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <cmath>
#include <cmath>
#include <utility>
#include <utility>
...
@@ -35,6 +38,10 @@ namespace migraphx {
...
@@ -35,6 +38,10 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
namespace
op
{
/**
* Convolution operator. Does not support optimal dimensions for spatial dimensions. Returns empty
* optimals.
*/
struct
convolution
struct
convolution
{
{
std
::
vector
<
std
::
size_t
>
padding
=
{
0
,
0
};
std
::
vector
<
std
::
size_t
>
padding
=
{
0
,
0
};
...
@@ -145,7 +152,7 @@ struct convolution
...
@@ -145,7 +152,7 @@ struct convolution
else
else
{
{
auto
l
=
input_shape
.
lens
().
at
(
0
);
auto
l
=
input_shape
.
lens
().
at
(
0
);
output_dyn_dims
.
push_back
({
l
,
l
,
0
});
output_dyn_dims
.
push_back
({
l
,
l
});
}
}
};
};
...
@@ -162,25 +169,30 @@ struct convolution
...
@@ -162,25 +169,30 @@ struct convolution
if
(
x_shape
.
dynamic
())
if
(
x_shape
.
dynamic
())
{
{
auto
x
=
x_shape
.
dyn_dims
()[
i
+
2
];
auto
x
=
x_shape
.
dyn_dims
()[
i
+
2
];
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
std
::
set
<
std
::
size_t
>
optimals
{};
ceil_div
(
x
.
min
,
s
),
ceil_div
(
x
.
max
,
s
),
ceil_div
(
x
.
opt
,
s
)});
std
::
transform
(
x
.
optimals
.
begin
(),
x
.
optimals
.
end
(),
std
::
inserter
(
optimals
,
optimals
.
begin
()),
[
&
](
auto
o
)
{
return
ceil_div
(
o
,
s
);
});
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
ceil_div
(
x
.
min
,
s
),
ceil_div
(
x
.
max
,
s
),
optimals
});
}
}
else
else
{
{
auto
od
=
ceil_div
(
x_shape
.
lens
()[
i
+
2
],
s
);
auto
od
=
ceil_div
(
x_shape
.
lens
()[
i
+
2
],
s
);
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
od
,
od
,
0
});
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
od
,
od
});
}
}
}
}
}
}
else
else
{
{
// Does not compute for optimals
auto
min_spatial_dims
=
calc_conv_lens
(
x_shape
.
min_lens
(),
w_shape
.
max_lens
());
auto
min_spatial_dims
=
calc_conv_lens
(
x_shape
.
min_lens
(),
w_shape
.
max_lens
());
auto
max_spatial_dims
=
calc_conv_lens
(
x_shape
.
max_lens
(),
w_shape
.
min_lens
());
auto
max_spatial_dims
=
calc_conv_lens
(
x_shape
.
max_lens
(),
w_shape
.
min_lens
());
auto
opt_spatial_dims
=
calc_conv_lens
(
x_shape
.
opt_lens
(),
w_shape
.
opt_lens
());
for
(
size_t
i
=
0
;
i
<
num_spatial_dims
;
++
i
)
for
(
size_t
i
=
0
;
i
<
num_spatial_dims
;
++
i
)
{
{
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
output_dyn_dims
.
push_back
(
min_spatial_dims
[
i
],
max_spatial_dims
[
i
],
opt_spatial_dims
[
i
]
});
shape
::
dynamic_dimension
{
min_spatial_dims
[
i
],
max_spatial_dims
[
i
],
{}
});
}
}
}
}
return
shape
{
x_shape
.
type
(),
output_dyn_dims
};
return
shape
{
x_shape
.
type
(),
output_dyn_dims
};
...
@@ -201,6 +213,37 @@ struct convolution
...
@@ -201,6 +213,37 @@ struct convolution
check_attribute_size
();
check_attribute_size
();
return
stride
.
size
();
return
stride
.
size
();
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
std
::
vector
<
std
::
size_t
>
new_padding
;
if
(
padding_mode
!=
op
::
padding_mode_t
::
default_
)
{
auto
input_lens
=
args
[
0
].
get_shape
().
lens
();
auto
weights_lens
=
args
[
1
].
get_shape
().
lens
();
new_padding
=
padding_mode
==
op
::
same_upper
?
calc_dyn_auto_pad
(
input_lens
,
weights_lens
,
stride
,
dilation
,
true
)
:
calc_dyn_auto_pad
(
input_lens
,
weights_lens
,
stride
,
dilation
,
false
);
output_shape
=
compute_padded_shape
(
args
[
0
].
get_shape
(),
args
[
1
].
get_shape
(),
new_padding
,
stride
,
dilation
);
}
else
{
new_padding
=
padding
;
if
(
output_shape
.
dynamic
())
{
output_shape
=
normalize_compute_shape
({
args
.
at
(
0
).
get_shape
(),
args
.
at
(
1
).
get_shape
()});
}
}
argument
result
{
output_shape
};
visit_all
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input
,
auto
weights
)
{
migraphx
::
convolution
(
output
,
input
,
weights
,
new_padding
,
stride
,
group
);
});
return
result
;
}
};
};
}
// namespace op
}
// namespace op
...
...
src/include/migraphx/op/dequantizelinear.hpp
View file @
cd4ab535
...
@@ -37,10 +37,23 @@ namespace op {
...
@@ -37,10 +37,23 @@ namespace op {
struct
dequantizelinear
struct
dequantizelinear
{
{
value
attributes
()
const
{
// Note: point_op attribute is not used in this op. Instead, in
// gpu compilation pipeline, rewrite_quantization will be invoked
// from generate_pointwise() to rewrite this op.
return
{{
"pointwise"
,
true
}};
}
std
::
string
name
()
const
{
return
"dequantizelinear"
;
}
std
::
string
name
()
const
{
return
"dequantizelinear"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
same_dims
();
check_shapes
{
inputs
,
*
this
}.
same_dims
().
has
(
2
,
3
);
if
(
inputs
.
size
()
==
3
and
inputs
[
0
].
type
()
!=
inputs
[
2
].
type
())
{
MIGRAPHX_THROW
(
"DEQUANTIZELINEAR: Zero point and input should be the same type."
);
}
return
{
inputs
[
1
].
type
(),
inputs
[
0
].
lens
(),
inputs
[
0
].
strides
()};
return
{
inputs
[
1
].
type
(),
inputs
[
0
].
lens
(),
inputs
[
0
].
strides
()};
}
}
...
...
src/include/migraphx/op/flatten.hpp
View file @
cd4ab535
...
@@ -29,6 +29,7 @@
...
@@ -29,6 +29,7 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -59,27 +60,22 @@ struct flatten
...
@@ -59,27 +60,22 @@ struct flatten
auto
s
=
inputs
[
0
];
auto
s
=
inputs
[
0
];
if
(
s
.
dynamic
())
if
(
s
.
dynamic
())
{
{
// Doesn't handle optimals
auto
min_lens
=
s
.
min_lens
();
auto
min_lens
=
s
.
min_lens
();
auto
max_lens
=
s
.
max_lens
();
auto
max_lens
=
s
.
max_lens
();
auto
opt_lens
=
s
.
opt_lens
();
// If any of the opt values is 0, output opt will be 0
// If any of the opt values is 0, output opt will be 0
shape
::
dynamic_dimension
x
=
{
shape
::
dynamic_dimension
x
=
{
std
::
accumulate
(
std
::
accumulate
(
min_lens
.
begin
(),
min_lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
min_lens
.
begin
(),
min_lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
std
::
accumulate
(
std
::
accumulate
(
max_lens
.
begin
(),
max_lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
max_lens
.
begin
(),
max_lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
std
::
accumulate
(
opt_lens
.
begin
(),
{}};
opt_lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{})};
shape
::
dynamic_dimension
y
=
{
shape
::
dynamic_dimension
y
=
{
std
::
accumulate
(
std
::
accumulate
(
min_lens
.
begin
()
+
axis
,
min_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
min_lens
.
begin
()
+
axis
,
min_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
std
::
accumulate
(
std
::
accumulate
(
max_lens
.
begin
()
+
axis
,
max_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
max_lens
.
begin
()
+
axis
,
max_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
std
::
accumulate
(
{}};
opt_lens
.
begin
()
+
axis
,
opt_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
};
return
{
s
.
type
(),
{
x
,
y
}};
return
{
s
.
type
(),
{
x
,
y
}};
}
}
else
else
...
...
src/include/migraphx/op/gathernd.hpp
View file @
cd4ab535
This diff is collapsed.
Click to expand it.
src/include/migraphx/op/multibroadcast.hpp
View file @
cd4ab535
This diff is collapsed.
Click to expand it.
src/include/migraphx/op/nonmaxsuppression.hpp
View file @
cd4ab535
This diff is collapsed.
Click to expand it.
src/include/migraphx/op/pointwise.hpp
View file @
cd4ab535
This diff is collapsed.
Click to expand it.
src/include/migraphx/op/pooling.hpp
View file @
cd4ab535
This diff is collapsed.
Click to expand it.
src/include/migraphx/op/quant_convolution.hpp
View file @
cd4ab535
This diff is collapsed.
Click to expand it.
Prev
1
2
3
4
5
6
7
8
9
…
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment