Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
cd4ab535
"vscode:/vscode.git/clone" did not exist on "3cbc6a5c01120b59cf38378d5f18836bcd90f6c9"
Commit
cd4ab535
authored
Jun 20, 2023
by
Khalique Ahmed
Browse files
manual merge
parents
3891ee58
a0fa3742
Changes
279
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
286 additions
and
100 deletions
+286
-100
src/include/migraphx/hash.hpp
src/include/migraphx/hash.hpp
+47
-0
src/include/migraphx/instruction.hpp
src/include/migraphx/instruction.hpp
+5
-1
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+87
-34
src/include/migraphx/module.hpp
src/include/migraphx/module.hpp
+2
-0
src/include/migraphx/msgpack.hpp
src/include/migraphx/msgpack.hpp
+2
-0
src/include/migraphx/onnx.hpp
src/include/migraphx/onnx.hpp
+1
-1
src/include/migraphx/op/argmax.hpp
src/include/migraphx/op/argmax.hpp
+1
-1
src/include/migraphx/op/broadcast.hpp
src/include/migraphx/op/broadcast.hpp
+10
-4
src/include/migraphx/op/concat.hpp
src/include/migraphx/op/concat.hpp
+1
-1
src/include/migraphx/op/contiguous.hpp
src/include/migraphx/op/contiguous.hpp
+1
-1
src/include/migraphx/op/convert.hpp
src/include/migraphx/op/convert.hpp
+11
-1
src/include/migraphx/op/convolution.hpp
src/include/migraphx/op/convolution.hpp
+50
-7
src/include/migraphx/op/dequantizelinear.hpp
src/include/migraphx/op/dequantizelinear.hpp
+14
-1
src/include/migraphx/op/flatten.hpp
src/include/migraphx/op/flatten.hpp
+4
-8
src/include/migraphx/op/gathernd.hpp
src/include/migraphx/op/gathernd.hpp
+1
-1
src/include/migraphx/op/multibroadcast.hpp
src/include/migraphx/op/multibroadcast.hpp
+14
-10
src/include/migraphx/op/nonmaxsuppression.hpp
src/include/migraphx/op/nonmaxsuppression.hpp
+2
-2
src/include/migraphx/op/pointwise.hpp
src/include/migraphx/op/pointwise.hpp
+5
-4
src/include/migraphx/op/pooling.hpp
src/include/migraphx/op/pooling.hpp
+15
-23
src/include/migraphx/op/quant_convolution.hpp
src/include/migraphx/op/quant_convolution.hpp
+13
-0
No files found.
src/include/migraphx/hash.hpp
0 → 100644
View file @
cd4ab535
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
#include <migraphx/config.hpp>
#include <functional>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/// Compute the hash of a value by delegating to std::hash<T>.
/// Returns the same result as std::hash<T>{}(v) for any hashable T.
template <class T>
std::size_t hash_value(const T& v)
{
    std::hash<T> hasher{};
    return hasher(v);
}
/// Fold the hash of v into seed (boost-style hash_combine).
/// The constant 0x9e3779b9 is the 32-bit golden-ratio mixing constant;
/// the shifts spread entropy from the existing seed across bit positions.
template <class T>
void hash_combine(std::size_t& seed, const T& v)
{
    const std::size_t h = hash_value(v);
    seed ^= h + 0x9e3779b9 + (seed << 6u) + (seed >> 2u);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
src/include/migraphx/instruction.hpp
View file @
cd4ab535
...
@@ -136,6 +136,9 @@ struct instruction
...
@@ -136,6 +136,9 @@ struct instruction
operation
normalized_operator
()
const
;
operation
normalized_operator
()
const
;
std
::
size_t
get_target_id
()
const
;
void
set_target_id
(
std
::
size_t
tid
);
void
debug_print
()
const
;
void
debug_print
()
const
;
static
void
print
(
std
::
ostream
&
os
,
static
void
print
(
std
::
ostream
&
os
,
...
@@ -172,7 +175,8 @@ struct instruction
...
@@ -172,7 +175,8 @@ struct instruction
std
::
vector
<
instruction_ref
>
arguments
;
std
::
vector
<
instruction_ref
>
arguments
;
std
::
vector
<
module_ref
>
module_args
;
std
::
vector
<
module_ref
>
module_args
;
literal
lit
;
literal
lit
;
bool
normalized
=
false
;
bool
normalized
=
false
;
std
::
size_t
target_id
=
0
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/matcher.hpp
View file @
cd4ab535
...
@@ -35,6 +35,10 @@
...
@@ -35,6 +35,10 @@
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#ifndef MIGRAPHX_USE_TYPE_ERASED_MATCHERS
#define MIGRAPHX_USE_TYPE_ERASED_MATCHERS 0
#endif
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -103,6 +107,13 @@ struct predicate_matcher
...
@@ -103,6 +107,13 @@ struct predicate_matcher
}
}
};
};
/// Convert a predicate function into a matcher
template
<
class
P
>
predicate_matcher
<
P
>
make_predicate_matcher
(
P
p
)
{
return
{
p
};
}
/// Convert a function into a matcher
/// Convert a function into a matcher
template
<
class
F
>
template
<
class
F
>
struct
function_matcher
struct
function_matcher
...
@@ -124,14 +135,14 @@ template <class M>
...
@@ -124,14 +135,14 @@ template <class M>
auto
bind_match
(
M
m
,
std
::
string
name
)
auto
bind_match
(
M
m
,
std
::
string
name
)
{
{
return
make_function_matcher
(
return
make_function_matcher
(
[
=
,
name
=
std
::
move
(
name
)](
matcher_context
&
ctx
,
[
=
,
m_
name
=
std
::
move
(
name
)](
matcher_context
&
ctx
,
instruction_ref
ins
)
->
optional
<
instruction_ref
>
{
instruction_ref
ins
)
->
optional
<
instruction_ref
>
{
auto
result
=
m
.
match
(
ctx
,
ins
);
auto
result
=
m
.
match
(
ctx
,
ins
);
if
(
result
)
if
(
result
)
{
{
if
(
not
ctx
.
has_instruction
(
ins
))
if
(
not
ctx
.
has_instruction
(
ins
))
return
nullopt
;
return
nullopt
;
ctx
.
instructions
[
name
]
=
ins
;
ctx
.
instructions
[
m_
name
]
=
ins
;
}
}
return
result
;
return
result
;
});
});
...
@@ -183,14 +194,26 @@ struct id_matcher
...
@@ -183,14 +194,26 @@ struct id_matcher
template
<
class
M
>
template
<
class
M
>
struct
basic_matcher
;
struct
basic_matcher
;
struct
any_matcher
;
template
<
class
M
>
struct
type_erased_matcher
{
#if MIGRAPHX_USE_TYPE_ERASED_MATCHERS
using
type
=
any_matcher
;
#else
using
type
=
basic_matcher
<
M
>
;
#endif
};
template
<
class
M
>
template
<
class
M
>
basic
_matcher
<
M
>
make_basic_matcher
(
M
m
);
typename
type_erased
_matcher
<
M
>
::
type
make_basic_matcher
(
M
m
);
template
<
class
F
>
template
<
class
F
>
basic_matcher
<
function_matcher
<
F
>>
make_basic_fun_matcher
(
F
f
);
auto
make_basic_fun_matcher
(
F
f
);
template
<
class
P
>
template
<
class
P
>
basic_matcher
<
predicate_matcher
<
P
>>
make_basic_pred_matcher
(
P
p
);
auto
make_basic_pred_matcher
(
P
p
);
/// The basic matcher provides the all_of composability of the matcher
/// The basic matcher provides the all_of composability of the matcher
template
<
class
M
>
template
<
class
M
>
...
@@ -222,38 +245,38 @@ struct basic_matcher
...
@@ -222,38 +245,38 @@ struct basic_matcher
auto
match
(
matcher_context
&
ctx
,
instruction_ref
ins
)
const
{
return
m
.
match
(
ctx
,
ins
);
}
auto
match
(
matcher_context
&
ctx
,
instruction_ref
ins
)
const
{
return
m
.
match
(
ctx
,
ins
);
}
};
};
/// Create a typed-erased matcher
using
any_matcher_base
=
basic_matcher
<
function_matcher
<
std
::
function
<
optional
<
instruction_ref
>
(
matcher_context
&
,
instruction_ref
)
>>>
;
struct
any_matcher
:
any_matcher_base
{
template
<
class
M
>
any_matcher
(
M
mm
)
:
any_matcher_base
({[
=
](
auto
&
ctx
,
auto
ins
)
{
return
mm
.
match
(
ctx
,
ins
);
}})
{
}
};
/// Create a basic matcher from a matcher
/// Create a basic matcher from a matcher
template
<
class
M
>
template
<
class
M
>
basic
_matcher
<
M
>
make_basic_matcher
(
M
m
)
typename
type_erased
_matcher
<
M
>
::
type
make_basic_matcher
(
M
m
)
{
{
return
{
m
};
return
{
m
};
}
}
/// Create a basic matcher from a function
/// Create a basic matcher from a function
template
<
class
F
>
template
<
class
F
>
basic_matcher
<
function_matcher
<
F
>>
make_basic_fun_matcher
(
F
f
)
auto
make_basic_fun_matcher
(
F
f
)
{
{
return
{{
f
}}
;
return
make_basic_matcher
(
make_function_matcher
(
f
))
;
}
}
/// Create a basic matcher from a predicate function
/// Create a basic matcher from a predicate function
template
<
class
P
>
template
<
class
P
>
basic_matcher
<
predicate_matcher
<
P
>>
make_basic_pred_matcher
(
P
p
)
auto
make_basic_pred_matcher
(
P
p
)
{
{
return
{{
p
}}
;
return
make_basic_matcher
(
make_predicate_matcher
(
p
))
;
}
}
/// Create a typed-erased matcher
using
any_matcher_base
=
basic_matcher
<
function_matcher
<
std
::
function
<
optional
<
instruction_ref
>
(
matcher_context
&
,
instruction_ref
)
>>>
;
struct
any_matcher
:
any_matcher_base
{
template
<
class
M
>
any_matcher
(
M
mm
)
:
any_matcher_base
({[
=
](
auto
&
ctx
,
auto
ins
)
{
return
mm
.
match
(
ctx
,
ins
);
}})
{
}
};
/// This macro takes care of the boilerplate for defining a matcher
/// This macro takes care of the boilerplate for defining a matcher
#define MIGRAPHX_BASIC_MATCHER(name, ...) \
#define MIGRAPHX_BASIC_MATCHER(name, ...) \
struct name##_m \
struct name##_m \
...
@@ -347,31 +370,49 @@ match::matcher_result find_match(module& modl, M&& m)
...
@@ -347,31 +370,49 @@ match::matcher_result find_match(module& modl, M&& m)
}
}
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_MATCHES
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_MATCHES
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_VALIDATE_MATCHES
)
/// Find matches for an instruction in the module
/// Find matches for an instruction in the module
for per section of matchers
template
<
class
Mod
,
class
...
Ms
>
template
<
class
Mod
,
class
...
Ms
>
void
find_matches
(
Mod
&
mod
,
instruction_ref
ins
,
Ms
&&
...
ms
)
void
find_matches
(
size_t
trace_pass
,
Mod
&
mod
,
instruction_ref
ins
,
Ms
&&
...
ms
)
{
{
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
const
#endif
#endif
int
trace
=
value_of
(
MIGRAPHX_TRACE_MATCHES
{});
int
trace
=
value_of
(
MIGRAPHX_TRACE_MATCHES
{});
bool
match
=
false
;
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
#endif
bool
validate
=
enabled
(
MIGRAPHX_VALIDATE_MATCHES
{});
bool
match
=
false
;
each_args
(
each_args
(
[
&
](
auto
&&
m
)
{
[
&
](
auto
&&
m
)
{
if
(
match
)
if
(
match
)
return
;
return
;
if
(
trace
>
1
)
if
(
trace
>
1
or
trace_pass
>
1
)
std
::
cout
<<
"Match: "
<<
get_type_name
(
m
)
<<
std
::
endl
;
std
::
cout
<<
"Match: "
<<
get_type_name
(
m
)
<<
std
::
endl
;
auto
r
=
match_instruction
(
get_module
(
mod
),
ins
,
m
.
matcher
());
auto
r
=
match_instruction
(
get_module
(
mod
),
ins
,
m
.
matcher
());
if
(
r
.
result
==
get_module
(
mod
).
end
())
if
(
r
.
result
==
get_module
(
mod
).
end
())
return
;
return
;
if
(
trace
>
0
)
if
(
trace
>
0
or
trace_pass
>
0
)
{
{
std
::
cout
<<
"Matched by "
<<
get_type_name
(
m
)
<<
std
::
endl
;
std
::
cout
<<
"Matched by "
<<
get_type_name
(
m
)
<<
std
::
endl
;
get_module
(
mod
).
debug_print
(
ins
);
get_module
(
mod
).
debug_print
(
ins
);
}
}
// If its already invalid dont validate it again
bool
invalidated
=
validate
and
get_module
(
mod
).
validate
()
!=
get_module
(
mod
).
end
();
m
.
apply
(
mod
,
r
);
m
.
apply
(
mod
,
r
);
if
(
validate
and
not
invalidated
)
{
auto
invalid
=
get_module
(
mod
).
validate
();
if
(
invalid
!=
get_module
(
mod
).
end
())
{
std
::
cout
<<
"Invalid program from match: "
<<
get_type_name
(
m
)
<<
std
::
endl
;
std
::
cout
<<
"Invalid instructions: "
<<
std
::
endl
;
get_module
(
mod
).
debug_print
(
invalid
->
inputs
());
get_module
(
mod
).
debug_print
(
invalid
);
}
}
match
=
true
;
match
=
true
;
},
},
ms
...);
ms
...);
...
@@ -383,7 +424,17 @@ void find_matches(Mod& mod, Ms&&... ms)
...
@@ -383,7 +424,17 @@ void find_matches(Mod& mod, Ms&&... ms)
{
{
for
(
auto
ins
:
iterator_for
(
get_module
(
mod
)))
for
(
auto
ins
:
iterator_for
(
get_module
(
mod
)))
{
{
find_matches
(
mod
,
ins
,
ms
...);
find_matches
(
0
,
mod
,
ins
,
ms
...);
}
}
/// Find matches in a pass
template
<
class
Mod
,
class
...
Ms
>
void
find_matches
(
size_t
trace_pass
,
Mod
&
mod
,
Ms
&&
...
ms
)
{
for
(
auto
ins
:
iterator_for
(
get_module
(
mod
)))
{
find_matches
(
trace_pass
,
mod
,
ins
,
ms
...);
}
}
}
}
...
@@ -520,6 +571,8 @@ MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins)
...
@@ -520,6 +571,8 @@ MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins)
{
{
return
not
ins
->
get_shape
().
standard
();
return
not
ins
->
get_shape
().
standard
();
}
}
MIGRAPHX_PRED_MATCHER
(
dynamic_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
dynamic
();
}
MIGRAPHX_PRED_MATCHER
(
static_shape
,
instruction_ref
ins
)
{
return
not
ins
->
get_shape
().
dynamic
();
}
MIGRAPHX_PRED_MATCHER
(
broadcast_shape
,
instruction_ref
ins
)
MIGRAPHX_PRED_MATCHER
(
broadcast_shape
,
instruction_ref
ins
)
{
{
return
ins
->
get_shape
().
broadcasted
();
return
ins
->
get_shape
().
broadcasted
();
...
@@ -612,9 +665,9 @@ auto skip_output(Ms... ms)
...
@@ -612,9 +665,9 @@ auto skip_output(Ms... ms)
inline
auto
var
(
std
::
string
s
)
inline
auto
var
(
std
::
string
s
)
{
{
return
make_basic_fun_matcher
(
return
make_basic_fun_matcher
(
[
=
,
s
=
std
::
move
(
s
)](
const
matcher_context
&
ctx
,
[
=
,
m_
s
=
std
::
move
(
s
)](
const
matcher_context
&
ctx
,
instruction_ref
)
->
optional
<
instruction_ref
>
{
instruction_ref
)
->
optional
<
instruction_ref
>
{
auto
it
=
ctx
.
instructions
.
find
(
s
);
auto
it
=
ctx
.
instructions
.
find
(
m_
s
);
if
(
it
==
ctx
.
instructions
.
end
())
if
(
it
==
ctx
.
instructions
.
end
())
return
nullopt
;
return
nullopt
;
return
it
->
second
;
return
it
->
second
;
...
@@ -624,7 +677,7 @@ inline auto var(std::string s)
...
@@ -624,7 +677,7 @@ inline auto var(std::string s)
inline
auto
name
(
std
::
string
s
)
inline
auto
name
(
std
::
string
s
)
{
{
return
make_basic_pred_matcher
(
return
make_basic_pred_matcher
(
[
=
,
s
=
std
::
move
(
s
)](
instruction_ref
ins
)
{
return
ins
->
name
()
==
s
;
});
[
=
,
m_
s
=
std
::
move
(
s
)](
instruction_ref
ins
)
{
return
ins
->
name
()
==
m_
s
;
});
}
}
inline
auto
name_contains
(
const
std
::
string
&
name
)
inline
auto
name_contains
(
const
std
::
string
&
name
)
...
@@ -635,8 +688,8 @@ inline auto name_contains(const std::string& name)
...
@@ -635,8 +688,8 @@ inline auto name_contains(const std::string& name)
inline
auto
name
(
std
::
unordered_set
<
std
::
string
>
names
)
inline
auto
name
(
std
::
unordered_set
<
std
::
string
>
names
)
{
{
return
make_basic_pred_matcher
([
=
,
names
=
std
::
move
(
names
)](
instruction_ref
ins
)
{
return
make_basic_pred_matcher
([
=
,
m_
names
=
std
::
move
(
names
)](
instruction_ref
ins
)
{
return
names
.
count
(
ins
->
name
())
>
0
;
return
m_
names
.
count
(
ins
->
name
())
>
0
;
});
});
}
}
...
...
src/include/migraphx/module.hpp
View file @
cd4ab535
...
@@ -178,6 +178,8 @@ struct module
...
@@ -178,6 +178,8 @@ struct module
bool
has_instruction
(
instruction_ref
ins
)
const
;
bool
has_instruction
(
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
get_returns
()
const
;
std
::
size_t
size
()
const
;
std
::
size_t
size
()
const
;
instruction_ref
begin
()
const
;
instruction_ref
begin
()
const
;
instruction_ref
end
()
const
;
instruction_ref
end
()
const
;
...
...
src/include/migraphx/msgpack.hpp
View file @
cd4ab535
...
@@ -26,10 +26,12 @@
...
@@ -26,10 +26,12 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <functional>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
to_msgpack
(
const
value
&
v
,
std
::
function
<
void
(
const
char
*
,
std
::
size_t
)
>
writer
);
std
::
vector
<
char
>
to_msgpack
(
const
value
&
v
);
std
::
vector
<
char
>
to_msgpack
(
const
value
&
v
);
value
from_msgpack
(
const
std
::
vector
<
char
>&
buffer
);
value
from_msgpack
(
const
std
::
vector
<
char
>&
buffer
);
value
from_msgpack
(
const
char
*
buffer
,
std
::
size_t
size
);
value
from_msgpack
(
const
char
*
buffer
,
std
::
size_t
size
);
...
...
src/include/migraphx/onnx.hpp
View file @
cd4ab535
...
@@ -37,7 +37,7 @@ struct onnx_options
...
@@ -37,7 +37,7 @@ struct onnx_options
std
::
size_t
default_dim_value
=
0
;
std
::
size_t
default_dim_value
=
0
;
/// Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value set
/// Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value set
/// parser throws)
/// parser throws)
shape
::
dynamic_dimension
default_dyn_dim_value
=
{
1
,
1
,
0
};
shape
::
dynamic_dimension
default_dyn_dim_value
=
{
1
,
1
};
/// Explicitly specify the dims of an input
/// Explicitly specify the dims of an input
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
=
{};
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
=
{};
/// Explicitly specify dynamic dims of an input (if both map_input_dims and map_dyn_input_dims
/// Explicitly specify dynamic dims of an input (if both map_input_dims and map_dyn_input_dims
...
...
src/include/migraphx/op/argmax.hpp
View file @
cd4ab535
...
@@ -62,7 +62,7 @@ struct argmax
...
@@ -62,7 +62,7 @@ struct argmax
if
(
s0
.
dynamic
())
if
(
s0
.
dynamic
())
{
{
auto
dyn_dims
=
s0
.
dyn_dims
();
auto
dyn_dims
=
s0
.
dyn_dims
();
dyn_dims
[
axis
]
=
{
1
,
1
,
0
};
dyn_dims
[
axis
]
=
{
1
,
1
};
return
{
shape
::
int64_type
,
dyn_dims
};
return
{
shape
::
int64_type
,
dyn_dims
};
}
}
else
else
...
...
src/include/migraphx/op/broadcast.hpp
View file @
cd4ab535
...
@@ -37,10 +37,13 @@ namespace op {
...
@@ -37,10 +37,13 @@ namespace op {
* 1 input version:
* 1 input version:
* Broadcasts a tensor from the original shape to the broadcast_lens by setting the stride of
* Broadcasts a tensor from the original shape to the broadcast_lens by setting the stride of
* broadcasted dimensions to zero. `axis` attribute for a 1D input shape is the output dimension
* broadcasted dimensions to zero. `axis` attribute for a 1D input shape is the output dimension
* that stays the same. ex: broadcasting shape [1024] -> [4, 1024, 3] has axis = 1 For higher rank
* that stays the same.
* input shapes, axis is an offset parameter for the broadcasting. Such that this operator would
* ex: broadcasting shape [1024] -> [4, 1024, 3] has axis = 1.
* work in the opposite direction of NumPy broadcasting. ex: broadcasting shape [2, 2] -> [2, 2, 3]
*
* with axis = 0
* For higher rank input shapes, axis is an offset parameter for the broadcasting.
* Such that this operator would work in the opposite direction of NumPy broadcasting
* (left-most to rightwards element-wise comparison)
* ex: broadcasting shape [2, 2] -> [2, 2, 3] with axis = 0
*
*
* 2 input version:
* 2 input version:
* Broadcast the first input 1D shape into the second input shape based on the axis parameter.
* Broadcast the first input 1D shape into the second input shape based on the axis parameter.
...
@@ -68,6 +71,9 @@ struct broadcast
...
@@ -68,6 +71,9 @@ struct broadcast
{
{
// the ONNX broadcast op is deprecated now, so not handling the negative
// the ONNX broadcast op is deprecated now, so not handling the negative
// value of axis anymore
// value of axis anymore
if
(
s0
.
dynamic
())
MIGRAPHX_THROW
(
"BROADCAST: Single dynamic input shape not supported. Use two inputs."
);
if
(
axis
>=
broadcast_lens
.
size
())
if
(
axis
>=
broadcast_lens
.
size
())
{
{
MIGRAPHX_THROW
(
"BROADCAST : axis "
+
migraphx
::
to_string
(
axis
)
+
MIGRAPHX_THROW
(
"BROADCAST : axis "
+
migraphx
::
to_string
(
axis
)
+
...
...
src/include/migraphx/op/concat.hpp
View file @
cd4ab535
...
@@ -134,7 +134,7 @@ struct concat
...
@@ -134,7 +134,7 @@ struct concat
}
}
auto
new_dims
=
inputs
[
0
].
dyn_dims
();
auto
new_dims
=
inputs
[
0
].
dyn_dims
();
new_dims
[
axis
]
=
migraphx
::
shape
::
dynamic_dimension
{
new_min
,
new_max
,
0
};
new_dims
[
axis
]
=
migraphx
::
shape
::
dynamic_dimension
{
new_min
,
new_max
};
return
{
inputs
[
0
].
type
(),
new_dims
};
return
{
inputs
[
0
].
type
(),
new_dims
};
}
}
else
else
...
...
src/include/migraphx/op/contiguous.hpp
View file @
cd4ab535
...
@@ -48,7 +48,7 @@ struct contiguous
...
@@ -48,7 +48,7 @@ struct contiguous
{
{
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
s0
=
inputs
.
front
();
auto
s0
=
inputs
.
front
();
if
(
s0
.
dynamic
()
or
s0
.
standard
()
)
if
(
s0
.
dynamic
())
{
{
return
s0
;
return
s0
;
}
}
...
...
src/include/migraphx/op/convert.hpp
View file @
cd4ab535
...
@@ -66,7 +66,17 @@ struct convert : unary<convert>
...
@@ -66,7 +66,17 @@ struct convert : unary<convert>
auto
type
=
target_type
;
auto
type
=
target_type
;
return
[
type
](
auto
x
)
{
return
[
type
](
auto
x
)
{
auto
y
=
x
;
auto
y
=
x
;
shape
::
visit
(
type
,
[
&
](
auto
as
)
{
y
=
std
::
min
(
std
::
max
(
as
(
x
),
as
.
min
()),
as
.
max
());
});
shape
::
visit
(
type
,
[
&
](
auto
as
)
{
// clamping value between target_type's max and min doesn't work for NaNs,
if
(
std
::
isnan
(
x
))
{
y
=
as
.
nan
();
}
else
{
y
=
std
::
min
(
std
::
max
(
as
(
x
),
as
.
min
()),
as
.
max
());
}
});
return
y
;
return
y
;
};
};
}
}
...
...
src/include/migraphx/op/convolution.hpp
View file @
cd4ab535
...
@@ -24,9 +24,12 @@
...
@@ -24,9 +24,12 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_CONVOLUTION_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_CONVOLUTION_HPP
#define MIGRAPHX_GUARD_OPERATORS_CONVOLUTION_HPP
#define MIGRAPHX_GUARD_OPERATORS_CONVOLUTION_HPP
#include <migraphx/argument.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/convolution.hpp>
#include <migraphx/pad_calc.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <cmath>
#include <cmath>
#include <utility>
#include <utility>
...
@@ -35,6 +38,10 @@ namespace migraphx {
...
@@ -35,6 +38,10 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
namespace
op
{
/**
* Convolution operator. Does not support optimal dimensions for spatial dimensions. Returns empty
* optimals.
*/
struct
convolution
struct
convolution
{
{
std
::
vector
<
std
::
size_t
>
padding
=
{
0
,
0
};
std
::
vector
<
std
::
size_t
>
padding
=
{
0
,
0
};
...
@@ -145,7 +152,7 @@ struct convolution
...
@@ -145,7 +152,7 @@ struct convolution
else
else
{
{
auto
l
=
input_shape
.
lens
().
at
(
0
);
auto
l
=
input_shape
.
lens
().
at
(
0
);
output_dyn_dims
.
push_back
({
l
,
l
,
0
});
output_dyn_dims
.
push_back
({
l
,
l
});
}
}
};
};
...
@@ -162,25 +169,30 @@ struct convolution
...
@@ -162,25 +169,30 @@ struct convolution
if
(
x_shape
.
dynamic
())
if
(
x_shape
.
dynamic
())
{
{
auto
x
=
x_shape
.
dyn_dims
()[
i
+
2
];
auto
x
=
x_shape
.
dyn_dims
()[
i
+
2
];
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
std
::
set
<
std
::
size_t
>
optimals
{};
ceil_div
(
x
.
min
,
s
),
ceil_div
(
x
.
max
,
s
),
ceil_div
(
x
.
opt
,
s
)});
std
::
transform
(
x
.
optimals
.
begin
(),
x
.
optimals
.
end
(),
std
::
inserter
(
optimals
,
optimals
.
begin
()),
[
&
](
auto
o
)
{
return
ceil_div
(
o
,
s
);
});
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
ceil_div
(
x
.
min
,
s
),
ceil_div
(
x
.
max
,
s
),
optimals
});
}
}
else
else
{
{
auto
od
=
ceil_div
(
x_shape
.
lens
()[
i
+
2
],
s
);
auto
od
=
ceil_div
(
x_shape
.
lens
()[
i
+
2
],
s
);
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
od
,
od
,
0
});
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
od
,
od
});
}
}
}
}
}
}
else
else
{
{
// Does not compute for optimals
auto
min_spatial_dims
=
calc_conv_lens
(
x_shape
.
min_lens
(),
w_shape
.
max_lens
());
auto
min_spatial_dims
=
calc_conv_lens
(
x_shape
.
min_lens
(),
w_shape
.
max_lens
());
auto
max_spatial_dims
=
calc_conv_lens
(
x_shape
.
max_lens
(),
w_shape
.
min_lens
());
auto
max_spatial_dims
=
calc_conv_lens
(
x_shape
.
max_lens
(),
w_shape
.
min_lens
());
auto
opt_spatial_dims
=
calc_conv_lens
(
x_shape
.
opt_lens
(),
w_shape
.
opt_lens
());
for
(
size_t
i
=
0
;
i
<
num_spatial_dims
;
++
i
)
for
(
size_t
i
=
0
;
i
<
num_spatial_dims
;
++
i
)
{
{
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
output_dyn_dims
.
push_back
(
min_spatial_dims
[
i
],
max_spatial_dims
[
i
],
opt_spatial_dims
[
i
]
});
shape
::
dynamic_dimension
{
min_spatial_dims
[
i
],
max_spatial_dims
[
i
],
{}
});
}
}
}
}
return
shape
{
x_shape
.
type
(),
output_dyn_dims
};
return
shape
{
x_shape
.
type
(),
output_dyn_dims
};
...
@@ -201,6 +213,37 @@ struct convolution
...
@@ -201,6 +213,37 @@ struct convolution
check_attribute_size
();
check_attribute_size
();
return
stride
.
size
();
return
stride
.
size
();
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
std
::
vector
<
std
::
size_t
>
new_padding
;
if
(
padding_mode
!=
op
::
padding_mode_t
::
default_
)
{
auto
input_lens
=
args
[
0
].
get_shape
().
lens
();
auto
weights_lens
=
args
[
1
].
get_shape
().
lens
();
new_padding
=
padding_mode
==
op
::
same_upper
?
calc_dyn_auto_pad
(
input_lens
,
weights_lens
,
stride
,
dilation
,
true
)
:
calc_dyn_auto_pad
(
input_lens
,
weights_lens
,
stride
,
dilation
,
false
);
output_shape
=
compute_padded_shape
(
args
[
0
].
get_shape
(),
args
[
1
].
get_shape
(),
new_padding
,
stride
,
dilation
);
}
else
{
new_padding
=
padding
;
if
(
output_shape
.
dynamic
())
{
output_shape
=
normalize_compute_shape
({
args
.
at
(
0
).
get_shape
(),
args
.
at
(
1
).
get_shape
()});
}
}
argument
result
{
output_shape
};
visit_all
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input
,
auto
weights
)
{
migraphx
::
convolution
(
output
,
input
,
weights
,
new_padding
,
stride
,
group
);
});
return
result
;
}
};
};
}
// namespace op
}
// namespace op
...
...
src/include/migraphx/op/dequantizelinear.hpp
View file @
cd4ab535
...
@@ -37,10 +37,23 @@ namespace op {
...
@@ -37,10 +37,23 @@ namespace op {
struct
dequantizelinear
struct
dequantizelinear
{
{
value
attributes
()
const
{
// Note: point_op attribute is not used in this op. Instead, in
// gpu compilation pipeline, rewrite_quantization will be invoked
// from generate_pointwise() to rewrite this op.
return
{{
"pointwise"
,
true
}};
}
std
::
string
name
()
const
{
return
"dequantizelinear"
;
}
std
::
string
name
()
const
{
return
"dequantizelinear"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
same_dims
();
check_shapes
{
inputs
,
*
this
}.
same_dims
().
has
(
2
,
3
);
if
(
inputs
.
size
()
==
3
and
inputs
[
0
].
type
()
!=
inputs
[
2
].
type
())
{
MIGRAPHX_THROW
(
"DEQUANTIZELINEAR: Zero point and input should be the same type."
);
}
return
{
inputs
[
1
].
type
(),
inputs
[
0
].
lens
(),
inputs
[
0
].
strides
()};
return
{
inputs
[
1
].
type
(),
inputs
[
0
].
lens
(),
inputs
[
0
].
strides
()};
}
}
...
...
src/include/migraphx/op/flatten.hpp
View file @
cd4ab535
...
@@ -29,6 +29,7 @@
...
@@ -29,6 +29,7 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -59,27 +60,22 @@ struct flatten
...
@@ -59,27 +60,22 @@ struct flatten
auto
s
=
inputs
[
0
];
auto
s
=
inputs
[
0
];
if
(
s
.
dynamic
())
if
(
s
.
dynamic
())
{
{
// Doesn't handle optimals
auto
min_lens
=
s
.
min_lens
();
auto
min_lens
=
s
.
min_lens
();
auto
max_lens
=
s
.
max_lens
();
auto
max_lens
=
s
.
max_lens
();
auto
opt_lens
=
s
.
opt_lens
();
// If any of the opt values is 0, output opt will be 0
// If any of the opt values is 0, output opt will be 0
shape
::
dynamic_dimension
x
=
{
shape
::
dynamic_dimension
x
=
{
std
::
accumulate
(
std
::
accumulate
(
min_lens
.
begin
(),
min_lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
min_lens
.
begin
(),
min_lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
std
::
accumulate
(
std
::
accumulate
(
max_lens
.
begin
(),
max_lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
max_lens
.
begin
(),
max_lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
std
::
accumulate
(
opt_lens
.
begin
(),
{}};
opt_lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{})};
shape
::
dynamic_dimension
y
=
{
shape
::
dynamic_dimension
y
=
{
std
::
accumulate
(
std
::
accumulate
(
min_lens
.
begin
()
+
axis
,
min_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
min_lens
.
begin
()
+
axis
,
min_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
std
::
accumulate
(
std
::
accumulate
(
max_lens
.
begin
()
+
axis
,
max_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
max_lens
.
begin
()
+
axis
,
max_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
std
::
accumulate
(
{}};
opt_lens
.
begin
()
+
axis
,
opt_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{}),
};
return
{
s
.
type
(),
{
x
,
y
}};
return
{
s
.
type
(),
{
x
,
y
}};
}
}
else
else
...
...
src/include/migraphx/op/gathernd.hpp
View file @
cd4ab535
...
@@ -121,7 +121,7 @@ struct gathernd
...
@@ -121,7 +121,7 @@ struct gathernd
// A rank 0 output is a scalar
// A rank 0 output is a scalar
if
(
output_ndim
==
0
)
if
(
output_ndim
==
0
)
return
shape
(
data_shape
.
type
(),
{
shape
::
dynamic_dimension
({
1
,
1
,
0
})});
return
shape
(
data_shape
.
type
(),
{
shape
::
dynamic_dimension
({
1
,
1
})});
// Part of the output shape comes from indices tensor, part from data tensor
// Part of the output shape comes from indices tensor, part from data tensor
std
::
vector
<
shape
::
dynamic_dimension
>
output_dims
(
output_ndim
);
std
::
vector
<
shape
::
dynamic_dimension
>
output_dims
(
output_ndim
);
...
...
src/include/migraphx/op/multibroadcast.hpp
View file @
cd4ab535
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -36,9 +36,9 @@ namespace op {
...
@@ -36,9 +36,9 @@ namespace op {
/**
/**
* Broadcast multiple dimensions between two tensors.
* Broadcast multiple dimensions between two tensors.
* Two versions of this operator:
one
input and
two
inputs.
* Two versions of this operator:
1
input and
2+
inputs.
* One input version uses output_lens attribute and broadcasts to it.
* One input version uses output_lens attribute and broadcasts to it.
*
Two
inputs version broadcasts
both
input
s
to the common shape at evaluation time.
*
2+
inputs version broadcasts
first
input to the common shape at evaluation time.
*/
*/
struct
multibroadcast
struct
multibroadcast
{
{
...
@@ -57,12 +57,12 @@ struct multibroadcast
...
@@ -57,12 +57,12 @@ struct multibroadcast
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
_at_least
(
1
);
auto
t
=
inputs
.
at
(
0
).
type
();
auto
t
=
inputs
.
at
(
0
).
type
();
auto
s0
=
inputs
.
at
(
0
);
auto
s0
=
inputs
.
at
(
0
);
if
(
s0
.
max_lens
().
empty
()
)
if
(
s0
.
ndim
()
<
1
)
{
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: input dimensions should be > 0"
);
MIGRAPHX_THROW
(
"MULTIBROADCAST: input dimensions should be > 0"
);
}
}
...
@@ -81,6 +81,9 @@ struct multibroadcast
...
@@ -81,6 +81,9 @@ struct multibroadcast
if
(
inputs
.
size
()
==
1
)
if
(
inputs
.
size
()
==
1
)
{
{
if
(
s0
.
dynamic
())
MIGRAPHX_THROW
(
"MULTIBROADCAST: Single dynamic input shape not supported. Use two inputs."
);
if
(
s0
.
lens
().
size
()
>
output_lens
.
size
())
if
(
s0
.
lens
().
size
()
>
output_lens
.
size
())
{
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: input dimensions should <= output size"
);
MIGRAPHX_THROW
(
"MULTIBROADCAST: input dimensions should <= output size"
);
...
@@ -102,19 +105,20 @@ struct multibroadcast
...
@@ -102,19 +105,20 @@ struct multibroadcast
}
}
else
else
{
{
//
two
inputs
//
2+
inputs
auto
s1
=
inputs
.
at
(
1
);
if
(
std
::
any_of
(
if
(
s0
.
dynamic
()
or
s1
.
dynamic
())
inputs
.
cbegin
(),
inputs
.
cend
(),
[](
auto
input
)
{
return
input
.
dynamic
()
;
})
)
{
{
if
(
not
output_dyn_dims
.
empty
())
if
(
not
output_dyn_dims
.
empty
())
{
{
return
{
t
,
output_dyn_dims
};
return
{
t
,
output_dyn_dims
};
}
}
return
{
t
,
compute_
broadcasted
_dyn_dims
(
s0
,
s1
)};
return
{
t
,
compute_
common
_dyn_dims
(
inputs
)};
}
}
else
else
{
{
auto
bcast_lens
=
compute_broadcasted_lens
(
s0
.
lens
(),
s1
.
lens
());
// output_lens will not be set for 2+ input version
auto
bcast_lens
=
compute_common_lens
(
inputs
);
auto
offset
=
bcast_lens
.
size
()
-
s0
.
lens
().
size
();
auto
offset
=
bcast_lens
.
size
()
-
s0
.
lens
().
size
();
auto
bcast_strides
=
make_bcast_strides
(
bcast_lens
,
offset
);
auto
bcast_strides
=
make_bcast_strides
(
bcast_lens
,
offset
);
return
{
t
,
std
::
move
(
bcast_lens
),
std
::
move
(
bcast_strides
)};
return
{
t
,
std
::
move
(
bcast_lens
),
std
::
move
(
bcast_strides
)};
...
...
src/include/migraphx/op/nonmaxsuppression.hpp
View file @
cd4ab535
...
@@ -119,8 +119,8 @@ struct nonmaxsuppression
...
@@ -119,8 +119,8 @@ struct nonmaxsuppression
fixed_shape_error_check
();
fixed_shape_error_check
();
}
}
std
::
vector
<
shape
::
dynamic_dimension
>
out_lens
=
{};
std
::
vector
<
shape
::
dynamic_dimension
>
out_lens
=
{};
out_lens
.
push_back
({
0
,
max_num_boxes
,
0
});
out_lens
.
push_back
({
0
,
max_num_boxes
});
out_lens
.
push_back
({
3
,
3
,
0
});
out_lens
.
push_back
({
3
,
3
});
return
{
shape
::
int64_type
,
out_lens
};
return
{
shape
::
int64_type
,
out_lens
};
}
}
else
else
...
...
src/include/migraphx/op/pointwise.hpp
View file @
cd4ab535
...
@@ -45,14 +45,15 @@ struct pointwise
...
@@ -45,14 +45,15 @@ struct pointwise
{
{
MIGRAPHX_THROW
(
"should have one submodule."
);
MIGRAPHX_THROW
(
"should have one submodule."
);
}
}
auto
*
pm
=
mods
.
front
();
auto
*
pm
=
mods
.
front
();
if
(
pm
->
get_output_shapes
().
size
()
!=
1
)
MIGRAPHX_THROW
(
"pointwise should have only one output."
);
if
(
inputs
.
empty
())
MIGRAPHX_THROW
(
"pointwise should have at least one input"
);
auto
pnames
=
pm
->
get_parameter_names
();
auto
pnames
=
pm
->
get_parameter_names
();
std
::
sort
(
pnames
.
begin
(),
pnames
.
end
());
std
::
sort
(
pnames
.
begin
(),
pnames
.
end
());
check_shapes
{
inputs
,
*
this
}.
has
(
pnames
.
size
()).
same_dims
();
check_shapes
{
inputs
,
*
this
}.
has
(
pnames
.
size
()).
same_dims
();
if
(
pm
->
get_output_shapes
().
size
()
!=
1
)
MIGRAPHX_THROW
(
"submodule should have only one output."
);
auto
type
=
pm
->
get_output_shapes
().
front
().
type
();
auto
type
=
pm
->
get_output_shapes
().
front
().
type
();
// Scalar output if all inputs are scalar
// Scalar output if all inputs are scalar
...
...
src/include/migraphx/op/pooling.hpp
View file @
cd4ab535
...
@@ -89,25 +89,17 @@ struct pooling
...
@@ -89,25 +89,17 @@ struct pooling
std
::
vector
<
std
::
size_t
>
output_lens
{};
std
::
vector
<
std
::
size_t
>
output_lens
{};
for
(
size_t
i
=
0
;
i
<
kdims
;
++
i
)
for
(
size_t
i
=
0
;
i
<
kdims
;
++
i
)
{
{
if
(
input_lens
[
i
+
2
]
==
0
)
std
::
size_t
padding_factor
=
2
*
padding
[
i
];
{
if
(
padding
.
size
()
==
2
*
kdims
)
// handle opt = 0
padding_factor
=
padding
[
i
]
+
padding
[
i
+
kdims
];
output_lens
.
push_back
(
0
);
assert
(
input_lens
[
i
+
2
]
+
padding_factor
>=
lengths
[
i
]);
}
std
::
size_t
dim_size
=
input_lens
[
i
+
2
]
+
padding_factor
-
lengths
[
i
];
else
std
::
size_t
len
=
{
(
ceil_mode
)
std
::
size_t
padding_factor
=
2
*
padding
[
i
];
?
dim_size
/
stride
[
i
]
+
if
(
padding
.
size
()
==
2
*
kdims
)
static_cast
<
std
::
size_t
>
((
dim_size
%
stride
[
i
]
!=
0
))
// ceil uint divide
padding_factor
=
padding
[
i
]
+
padding
[
i
+
kdims
];
:
dim_size
/
stride
[
i
];
// floor divide
assert
(
input_lens
[
i
+
2
]
+
padding_factor
>=
lengths
[
i
]);
output_lens
.
push_back
(
len
+
1
);
std
::
size_t
dim_size
=
input_lens
[
i
+
2
]
+
padding_factor
-
lengths
[
i
];
std
::
size_t
len
=
(
ceil_mode
)
?
dim_size
/
stride
[
i
]
+
static_cast
<
std
::
size_t
>
((
dim_size
%
stride
[
i
]
!=
0
))
// ceil uint divide
:
dim_size
/
stride
[
i
];
// floor divide
output_lens
.
push_back
(
len
+
1
);
}
}
}
return
output_lens
;
return
output_lens
;
}
}
...
@@ -134,19 +126,19 @@ struct pooling
...
@@ -134,19 +126,19 @@ struct pooling
{
{
for
(
size_t
i
=
0
;
i
<
kdims
;
++
i
)
for
(
size_t
i
=
0
;
i
<
kdims
;
++
i
)
{
{
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
1
,
1
,
1
});
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
1
,
1
});
}
}
return
{
input
.
type
(),
output_dyn_dims
};
return
{
input
.
type
(),
output_dyn_dims
};
}
}
else
else
{
{
// does not compute for optimals
auto
min_spatial_dims
=
calc_spatial_dim_out
(
input
.
min_lens
(),
kdims
);
auto
min_spatial_dims
=
calc_spatial_dim_out
(
input
.
min_lens
(),
kdims
);
auto
max_spatial_dims
=
calc_spatial_dim_out
(
input
.
max_lens
(),
kdims
);
auto
max_spatial_dims
=
calc_spatial_dim_out
(
input
.
max_lens
(),
kdims
);
auto
opt_spatial_dims
=
calc_spatial_dim_out
(
input
.
opt_lens
(),
kdims
);
for
(
size_t
i
=
0
;
i
<
kdims
;
++
i
)
for
(
size_t
i
=
0
;
i
<
kdims
;
++
i
)
{
{
output_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
output_dyn_dims
.
push_back
(
min_spatial_dims
[
i
],
max_spatial_dims
[
i
],
opt_spatial_dims
[
i
]
});
shape
::
dynamic_dimension
{
min_spatial_dims
[
i
],
max_spatial_dims
[
i
],
{}
});
}
}
return
{
input
.
type
(),
output_dyn_dims
};
return
{
input
.
type
(),
output_dyn_dims
};
}
}
...
...
src/include/migraphx/op/quant_convolution.hpp
View file @
cd4ab535
...
@@ -25,8 +25,10 @@
...
@@ -25,8 +25,10 @@
#define MIGRAPHX_GUARD_OPERATORS_QUANT_CONVOLUTION_HPP
#define MIGRAPHX_GUARD_OPERATORS_QUANT_CONVOLUTION_HPP
#include <migraphx/op/common.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/convolution.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <cmath>
#include <cmath>
#include <utility>
#include <utility>
...
@@ -114,6 +116,17 @@ struct quant_convolution
...
@@ -114,6 +116,17 @@ struct quant_convolution
check_attribute_size
();
check_attribute_size
();
return
stride
.
size
();
return
stride
.
size
();
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
result
.
visit
([
&
](
auto
output
)
{
visit_all
(
args
[
0
],
args
[
1
])([
&
](
auto
input
,
auto
weights
)
{
migraphx
::
convolution
(
output
,
input
,
weights
,
padding
,
stride
,
group
);
});
});
return
result
;
}
};
};
}
// namespace op
}
// namespace op
...
...
Prev
1
2
3
4
5
6
7
8
9
…
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment