Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f94d77fc
Commit
f94d77fc
authored
Aug 04, 2021
by
Khalique Ahmed
Browse files
Merge branch 'develop' of
https://github.com/ROCmSoftwarePlatform/AMDMIGraphX
into mi100_opts
parents
03929873
6403d482
Changes
126
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
300 additions
and
16 deletions
+300
-16
src/include/migraphx/algorithm.hpp
src/include/migraphx/algorithm.hpp
+8
-0
src/include/migraphx/allocation_model.hpp
src/include/migraphx/allocation_model.hpp
+19
-3
src/include/migraphx/argument.hpp
src/include/migraphx/argument.hpp
+2
-0
src/include/migraphx/common.hpp
src/include/migraphx/common.hpp
+26
-0
src/include/migraphx/lifetime.hpp
src/include/migraphx/lifetime.hpp
+18
-0
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+2
-1
src/include/migraphx/op/as_shape.hpp
src/include/migraphx/op/as_shape.hpp
+2
-1
src/include/migraphx/op/broadcast.hpp
src/include/migraphx/op/broadcast.hpp
+2
-1
src/include/migraphx/op/dequantizelinear.hpp
src/include/migraphx/op/dequantizelinear.hpp
+61
-0
src/include/migraphx/op/flatten.hpp
src/include/migraphx/op/flatten.hpp
+3
-2
src/include/migraphx/op/load.hpp
src/include/migraphx/op/load.hpp
+2
-1
src/include/migraphx/op/multibroadcast.hpp
src/include/migraphx/op/multibroadcast.hpp
+2
-1
src/include/migraphx/op/prefix_scan_op.hpp
src/include/migraphx/op/prefix_scan_op.hpp
+1
-1
src/include/migraphx/op/quantizelinear.hpp
src/include/migraphx/op/quantizelinear.hpp
+71
-0
src/include/migraphx/op/reshape.hpp
src/include/migraphx/op/reshape.hpp
+2
-1
src/include/migraphx/op/scalar.hpp
src/include/migraphx/op/scalar.hpp
+2
-1
src/include/migraphx/op/scatter.hpp
src/include/migraphx/op/scatter.hpp
+71
-0
src/include/migraphx/op/squeeze.hpp
src/include/migraphx/op/squeeze.hpp
+2
-1
src/include/migraphx/op/step.hpp
src/include/migraphx/op/step.hpp
+2
-1
src/include/migraphx/op/transpose.hpp
src/include/migraphx/op/transpose.hpp
+2
-1
No files found.
src/include/migraphx/algorithm.hpp
View file @
f94d77fc
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_ALGORITHM_HPP
#define MIGRAPHX_GUARD_RTGLIB_ALGORITHM_HPP
#include <algorithm>
#include <algorithm>
#include <numeric>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -21,6 +22,13 @@ void transform_if(Iterator start, Iterator last, Output out, Predicate pred, F f
...
@@ -21,6 +22,13 @@ void transform_if(Iterator start, Iterator last, Output out, Predicate pred, F f
}
}
}
}
template
<
class
Iterator
,
class
T
,
class
BinaryOp
,
class
UnaryOp
>
T
transform_accumulate
(
Iterator
first
,
Iterator
last
,
T
init
,
BinaryOp
binop
,
UnaryOp
unaryop
)
{
return
std
::
inner_product
(
first
,
last
,
first
,
init
,
binop
,
[
&
](
auto
&&
x
,
auto
&&
)
{
return
unaryop
(
x
);
});
}
template
<
class
Iterator
,
class
Output
,
class
Predicate
>
template
<
class
Iterator
,
class
Output
,
class
Predicate
>
void
group_by
(
Iterator
start
,
Iterator
last
,
Output
out
,
Predicate
pred
)
void
group_by
(
Iterator
start
,
Iterator
last
,
Output
out
,
Predicate
pred
)
{
{
...
...
src/include/migraphx/allocation_model.hpp
View file @
f94d77fc
...
@@ -26,6 +26,8 @@ struct allocation_model
...
@@ -26,6 +26,8 @@ struct allocation_model
std
::
string
copy
()
const
;
std
::
string
copy
()
const
;
/// Create an allocation operator for the given shape
/// Create an allocation operator for the given shape
operation
allocate
(
const
shape
&
s
)
const
;
operation
allocate
(
const
shape
&
s
)
const
;
/// Create a preallocated operator for the given shape
operation
preallocate
(
const
shape
&
s
,
const
std
::
string
&
id
)
const
;
};
};
#else
#else
...
@@ -38,6 +40,7 @@ struct allocation_model
...
@@ -38,6 +40,7 @@ struct allocation_model
* std::string name() const;
* std::string name() const;
* std::string copy() const;
* std::string copy() const;
* operation allocate(const shape& s) const;
* operation allocate(const shape& s) const;
* operation preallocate(const shape& s,std::string id) const;
* };
* };
*
*
*/
*/
...
@@ -123,6 +126,12 @@ struct allocation_model
...
@@ -123,6 +126,12 @@ struct allocation_model
return
(
*
this
).
private_detail_te_get_handle
().
allocate
(
s
);
return
(
*
this
).
private_detail_te_get_handle
().
allocate
(
s
);
}
}
operation
preallocate
(
const
shape
&
s
,
std
::
string
id
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
preallocate
(
s
,
std
::
move
(
id
));
}
friend
bool
is_shared
(
const
allocation_model
&
private_detail_x
,
friend
bool
is_shared
(
const
allocation_model
&
private_detail_x
,
const
allocation_model
&
private_detail_y
)
const
allocation_model
&
private_detail_y
)
{
{
...
@@ -137,9 +146,10 @@ struct allocation_model
...
@@ -137,9 +146,10 @@ struct allocation_model
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
std
::
string
copy
()
const
=
0
;
virtual
std
::
string
copy
()
const
=
0
;
virtual
operation
allocate
(
const
shape
&
s
)
const
=
0
;
virtual
operation
allocate
(
const
shape
&
s
)
const
=
0
;
virtual
operation
preallocate
(
const
shape
&
s
,
std
::
string
id
)
const
=
0
;
};
};
template
<
typename
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedT
>
...
@@ -180,6 +190,12 @@ struct allocation_model
...
@@ -180,6 +190,12 @@ struct allocation_model
return
private_detail_te_value
.
allocate
(
s
);
return
private_detail_te_value
.
allocate
(
s
);
}
}
operation
preallocate
(
const
shape
&
s
,
std
::
string
id
)
const
override
{
return
private_detail_te_value
.
preallocate
(
s
,
std
::
move
(
id
));
}
PrivateDetailTypeErasedT
private_detail_te_value
;
PrivateDetailTypeErasedT
private_detail_te_value
;
};
};
...
...
src/include/migraphx/argument.hpp
View file @
f94d77fc
...
@@ -60,6 +60,8 @@ struct argument : raw_data<argument>
...
@@ -60,6 +60,8 @@ struct argument : raw_data<argument>
argument
reshape
(
const
shape
&
s
)
const
;
argument
reshape
(
const
shape
&
s
)
const
;
argument
copy
()
const
;
/// Make copy of the argument that is always sharing the data
/// Make copy of the argument that is always sharing the data
argument
share
()
const
;
argument
share
()
const
;
...
...
src/include/migraphx/common.hpp
0 → 100644
View file @
f94d77fc
#ifndef MIGRAPHX_GUARD_MIGRAPHX_COMMON_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_COMMON_HPP
#include <migraphx/config.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/instruction_ref.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
module
;
struct
operation
;
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
);
shape
common_shape
(
const
std
::
vector
<
shape
>&
shapes
);
instruction_ref
insert_common_op
(
module
&
m
,
instruction_ref
ins
,
const
operation
&
op
,
std
::
vector
<
instruction_ref
>
inputs
);
instruction_ref
add_common_op
(
module
&
m
,
const
operation
&
op
,
std
::
vector
<
instruction_ref
>
inputs
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_HPP
src/include/migraphx/lifetime.hpp
0 → 100755
View file @
f94d77fc
#ifndef MIGRAPHX_GUARD_MIGRAPHX_LIFETIME_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_LIFETIME_HPP
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
enum
class
lifetime
{
local
,
global
,
borrow
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_LIFETIME_HPP
src/include/migraphx/matcher.hpp
View file @
f94d77fc
...
@@ -626,7 +626,8 @@ auto tree(M main_op, Ms... ms)
...
@@ -626,7 +626,8 @@ auto tree(M main_op, Ms... ms)
if
(
idx
!=
leafs
.
size
())
if
(
idx
!=
leafs
.
size
())
return
nullopt
;
return
nullopt
;
// Use explicit captures to workaround ICE on gcc
// Use explicit captures to workaround ICE on gcc
bool
found
=
sequence_c
<
sizeof
...(
Ms
)
>
([
&
ms
...,
&
ctx
,
&
leafs
](
auto
...
is
)
{
// Capture by value to workaround compile error on gcc 9
bool
found
=
sequence_c
<
sizeof
...(
Ms
)
>
([
ms
...,
&
ctx
,
&
leafs
](
auto
...
is
)
{
return
fold
(
lazy_and
{})(
ctx
.
lazy_match
(
ms
,
leafs
[
is
])...)();
return
fold
(
lazy_and
{})(
ctx
.
lazy_match
(
ms
,
leafs
[
is
])...)();
});
});
if
(
not
found
)
if
(
not
found
)
...
...
src/include/migraphx/op/as_shape.hpp
View file @
f94d77fc
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <cmath>
#include <utility>
#include <utility>
...
@@ -35,7 +36,7 @@ struct as_shape
...
@@ -35,7 +36,7 @@ struct as_shape
{
{
return
args
.
front
().
reshape
(
output_shape
);
return
args
.
front
().
reshape
(
output_shape
);
}
}
bool
is_borrowed
()
const
{
return
true
;
}
lifetime
get_lifetime
()
const
{
return
lifetime
::
borrow
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
...
src/include/migraphx/op/broadcast.hpp
View file @
f94d77fc
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <cmath>
#include <utility>
#include <utility>
...
@@ -66,7 +67,7 @@ struct broadcast
...
@@ -66,7 +67,7 @@ struct broadcast
{
{
return
args
[
0
].
reshape
(
output_shape
);
return
args
[
0
].
reshape
(
output_shape
);
}
}
bool
is_borrowed
()
const
{
return
true
;
}
lifetime
get_lifetime
()
const
{
return
lifetime
::
borrow
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
...
src/include/migraphx/op/dequantizelinear.hpp
0 → 100644
View file @
f94d77fc
#ifndef MIGRAPHX_GUARD_OPERATORS_DEQUANTIZE_LINEAR_HPP
#define MIGRAPHX_GUARD_OPERATORS_DEQUANTIZE_LINEAR_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/config.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/tune_axis.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
dequantizelinear
{
std
::
string
name
()
const
{
return
"dequantizelinear"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
{
shape
::
float_type
,
inputs
[
0
].
lens
(),
inputs
[
0
].
strides
()};
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
auto
x
=
args
.
at
(
0
);
auto
x_scale
=
args
.
at
(
1
);
std
::
vector
<
int8_t
>
zeros
(
output_shape
.
elements
(),
0
);
argument
x_zero_point
{{
x
.
get_shape
().
type
(),
output_shape
.
lens
()},
zeros
.
data
()};
if
(
args
.
size
()
==
3
)
{
x_zero_point
=
args
.
at
(
2
);
}
argument
result
{
output_shape
};
visit_all
(
x
,
x_zero_point
)([
&
](
auto
input
,
auto
zero_pts
)
{
visit_all
(
result
,
x_scale
)([
&
](
auto
output
,
auto
scales
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
output
[
i
]
=
static_cast
<
double
>
(
static_cast
<
int64_t
>
(
input
[
i
])
-
static_cast
<
int64_t
>
(
zero_pts
[
i
]))
*
scales
[
i
];
});
});
});
return
result
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/flatten.hpp
View file @
f94d77fc
...
@@ -10,6 +10,7 @@
...
@@ -10,6 +10,7 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <cmath>
#include <utility>
#include <utility>
...
@@ -38,7 +39,7 @@ struct flatten
...
@@ -38,7 +39,7 @@ struct flatten
std
::
string
name
()
const
{
return
"flatten"
;
}
std
::
string
name
()
const
{
return
"flatten"
;
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
}.
has
(
1
)
.
standard
()
;
auto
&&
lens
=
inputs
.
front
().
lens
();
auto
&&
lens
=
inputs
.
front
().
lens
();
auto
x
=
auto
x
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
std
::
accumulate
(
lens
.
begin
(),
lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
...
@@ -50,7 +51,7 @@ struct flatten
...
@@ -50,7 +51,7 @@ struct flatten
{
{
return
args
[
0
].
reshape
(
output_shape
);
return
args
[
0
].
reshape
(
output_shape
);
}
}
bool
is_borrowed
()
const
{
return
true
;
}
lifetime
get_lifetime
()
const
{
return
lifetime
::
borrow
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
...
src/include/migraphx/op/load.hpp
View file @
f94d77fc
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#include <migraphx/functional.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <cmath>
#include <utility>
#include <utility>
...
@@ -36,7 +37,7 @@ struct load
...
@@ -36,7 +37,7 @@ struct load
MIGRAPHX_THROW
(
"Load access is out of bounds"
);
MIGRAPHX_THROW
(
"Load access is out of bounds"
);
return
argument
::
load
(
s
,
args
[
0
].
data
()
+
offset
);
return
argument
::
load
(
s
,
args
[
0
].
data
()
+
offset
);
}
}
bool
is_borrowed
()
const
{
return
true
;
}
lifetime
get_lifetime
()
const
{
return
lifetime
::
borrow
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
load
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
load
&
op
)
...
...
src/include/migraphx/op/multibroadcast.hpp
View file @
f94d77fc
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <cmath>
#include <utility>
#include <utility>
...
@@ -68,7 +69,7 @@ struct multibroadcast
...
@@ -68,7 +69,7 @@ struct multibroadcast
{
{
return
args
[
0
].
reshape
(
output_shape
);
return
args
[
0
].
reshape
(
output_shape
);
}
}
bool
is_borrowed
()
const
{
return
true
;
}
lifetime
get_lifetime
()
const
{
return
lifetime
::
borrow
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
...
src/include/migraphx/op/prefix_scan_op.hpp
100644 → 100755
View file @
f94d77fc
...
@@ -43,7 +43,7 @@ struct prefix_scan_op : op_name<Derived>
...
@@ -43,7 +43,7 @@ struct prefix_scan_op : op_name<Derived>
argument
compute
(
const
shape
&
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
shape
&
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
=
args
[
0
];
argument
result
=
args
[
0
]
.
copy
()
;
auto
s
=
result
.
get_shape
();
auto
s
=
result
.
get_shape
();
auto
slice
=
shape
{
s
.
type
(),
{
s
.
lens
()[
axis
]},
{
s
.
strides
()[
axis
]}};
auto
slice
=
shape
{
s
.
type
(),
{
s
.
lens
()[
axis
]},
{
s
.
strides
()[
axis
]}};
auto
lens
=
s
.
lens
();
auto
lens
=
s
.
lens
();
...
...
src/include/migraphx/op/quantizelinear.hpp
0 → 100644
View file @
f94d77fc
#ifndef MIGRAPHX_GUARD_OPERATORS_QUANTIZE_LINEAR_HPP
#define MIGRAPHX_GUARD_OPERATORS_QUANTIZE_LINEAR_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/config.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/tune_axis.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
quantizelinear
{
std
::
string
name
()
const
{
return
"quantizelinear"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
if
(
inputs
.
size
()
==
3
)
{
return
{
inputs
[
2
].
type
(),
inputs
[
0
].
lens
(),
inputs
[
0
].
strides
()};
}
return
{
shape
::
uint8_type
,
inputs
[
0
].
lens
(),
inputs
[
0
].
strides
()};
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
auto
x
=
args
.
at
(
0
);
auto
y_scale
=
args
.
at
(
1
);
std
::
vector
<
int8_t
>
zeros
(
output_shape
.
elements
(),
0
);
argument
y_zero_point
{
output_shape
,
zeros
.
data
()};
if
(
args
.
size
()
==
3
)
{
y_zero_point
=
args
.
at
(
2
);
}
argument
result
{
output_shape
};
visit_all
(
result
,
y_zero_point
)([
&
](
auto
output
,
auto
zero_pts
)
{
x
.
visit
([
&
](
auto
input
)
{
y_scale
.
visit
([
&
](
auto
scales
)
{
using
quant_type
=
typename
decltype
(
output
)
::
value_type
;
auto
min_value
=
std
::
numeric_limits
<
quant_type
>::
min
();
auto
max_value
=
std
::
numeric_limits
<
quant_type
>::
max
();
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
int64_t
quantized
=
static_cast
<
int64_t
>
(
std
::
round
(
input
[
i
]
/
scales
[
i
]))
+
static_cast
<
int64_t
>
(
zero_pts
[
i
]);
output
[
i
]
=
std
::
max
(
static_cast
<
int64_t
>
(
min_value
),
std
::
min
(
static_cast
<
int64_t
>
(
max_value
),
quantized
));
});
});
});
});
return
result
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/reshape.hpp
View file @
f94d77fc
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <cmath>
#include <utility>
#include <utility>
...
@@ -71,7 +72,7 @@ struct reshape
...
@@ -71,7 +72,7 @@ struct reshape
return
args
[
0
].
reshape
(
output_shape
);
return
args
[
0
].
reshape
(
output_shape
);
}
}
bool
is_borrowed
()
const
{
return
true
;
}
lifetime
get_lifetime
()
const
{
return
lifetime
::
borrow
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
...
src/include/migraphx/op/scalar.hpp
View file @
f94d77fc
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <cmath>
#include <utility>
#include <utility>
...
@@ -39,7 +40,7 @@ struct scalar
...
@@ -39,7 +40,7 @@ struct scalar
{
{
return
args
[
0
].
reshape
(
output_shape
);
return
args
[
0
].
reshape
(
output_shape
);
}
}
bool
is_borrowed
()
const
{
return
true
;
}
lifetime
get_lifetime
()
const
{
return
lifetime
::
borrow
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
...
src/include/migraphx/op/scatter.hpp
0 → 100644
View file @
f94d77fc
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
scatter
{
int64_t
axis
=
0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axis
,
"axis"
));
}
value
attributes
()
const
{
value
normalize
;
normalize
[
"axis"
]
=
value
::
array
{
normalize_attribute
::
include_min
};
return
{{
"normalize_axes"
,
normalize
}};
}
std
::
string
name
()
const
{
return
"scatter"
;
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
3
).
standard
();
return
inputs
.
front
();
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
// max dimension in axis
auto
axis_dim_size
=
output_shape
.
lens
()[
axis
];
visit_all
(
result
,
args
[
0
],
args
[
2
])([
&
](
auto
output
,
auto
data
,
auto
update
)
{
std
::
copy
(
data
.
begin
(),
data
.
end
(),
output
.
begin
());
args
[
1
].
visit
([
&
](
auto
indices
)
{
auto
ind_s
=
indices
.
get_shape
();
shape_for_each
(
ind_s
,
[
&
](
const
auto
&
idx
)
{
auto
out_idx
=
idx
;
auto
index
=
indices
[
ind_s
.
index
(
idx
)];
index
=
(
index
<
0
)
?
index
+
axis_dim_size
:
index
;
out_idx
[
axis
]
=
index
;
output
[
output_shape
.
index
(
out_idx
)]
=
update
[
ind_s
.
index
(
idx
)];
});
});
});
return
result
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/squeeze.hpp
View file @
f94d77fc
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <cmath>
#include <utility>
#include <utility>
...
@@ -77,7 +78,7 @@ struct squeeze
...
@@ -77,7 +78,7 @@ struct squeeze
{
{
return
args
[
0
].
reshape
(
output_shape
);
return
args
[
0
].
reshape
(
output_shape
);
}
}
bool
is_borrowed
()
const
{
return
true
;
}
lifetime
get_lifetime
()
const
{
return
lifetime
::
borrow
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
...
src/include/migraphx/op/step.hpp
View file @
f94d77fc
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <cmath>
#include <utility>
#include <utility>
...
@@ -71,7 +72,7 @@ struct step
...
@@ -71,7 +72,7 @@ struct step
return
args
[
0
].
reshape
(
output_shape
);
return
args
[
0
].
reshape
(
output_shape
);
}
}
bool
is_borrowed
()
const
{
return
true
;
}
lifetime
get_lifetime
()
const
{
return
lifetime
::
borrow
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
...
src/include/migraphx/op/transpose.hpp
View file @
f94d77fc
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <cmath>
#include <utility>
#include <utility>
...
@@ -63,7 +64,7 @@ struct transpose
...
@@ -63,7 +64,7 @@ struct transpose
{
{
return
args
[
0
].
reshape
(
output_shape
);
return
args
[
0
].
reshape
(
output_shape
);
}
}
bool
is_borrowed
()
const
{
return
true
;
}
lifetime
get_lifetime
()
const
{
return
lifetime
::
borrow
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment