Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2f268bc2
Commit
2f268bc2
authored
Jun 12, 2022
by
Paul
Browse files
Merge branch 'develop' into mlir-c
parents
f75c5a38
aa7ff911
Changes
205
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
497 additions
and
31 deletions
+497
-31
src/include/migraphx/eliminate_contiguous.hpp
src/include/migraphx/eliminate_contiguous.hpp
+1
-1
src/include/migraphx/eliminate_identity.hpp
src/include/migraphx/eliminate_identity.hpp
+1
-1
src/include/migraphx/filesystem.hpp
src/include/migraphx/filesystem.hpp
+4
-1
src/include/migraphx/gemm.hpp
src/include/migraphx/gemm.hpp
+4
-2
src/include/migraphx/generate.hpp
src/include/migraphx/generate.hpp
+4
-4
src/include/migraphx/make_op.hpp
src/include/migraphx/make_op.hpp
+13
-1
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+68
-4
src/include/migraphx/memory_coloring.hpp
src/include/migraphx/memory_coloring.hpp
+1
-1
src/include/migraphx/op/common.hpp
src/include/migraphx/op/common.hpp
+2
-1
src/include/migraphx/op/gathernd.hpp
src/include/migraphx/op/gathernd.hpp
+131
-0
src/include/migraphx/op/pooling.hpp
src/include/migraphx/op/pooling.hpp
+112
-1
src/include/migraphx/op/scatter.hpp
src/include/migraphx/op/scatter.hpp
+34
-8
src/include/migraphx/op/scatter_add.hpp
src/include/migraphx/op/scatter_add.hpp
+38
-0
src/include/migraphx/op/scatter_mul.hpp
src/include/migraphx/op/scatter_mul.hpp
+36
-0
src/include/migraphx/op/scatter_none.hpp
src/include/migraphx/op/scatter_none.hpp
+37
-0
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+4
-1
src/include/migraphx/optional.hpp
src/include/migraphx/optional.hpp
+4
-1
src/include/migraphx/propagate_constant.hpp
src/include/migraphx/propagate_constant.hpp
+1
-1
src/include/migraphx/raw_data.hpp
src/include/migraphx/raw_data.hpp
+1
-2
src/include/migraphx/rewrite_batchnorm.hpp
src/include/migraphx/rewrite_batchnorm.hpp
+1
-1
No files found.
src/include/migraphx/eliminate_contiguous.hpp
View file @
2f268bc2
...
...
@@ -17,7 +17,7 @@ struct eliminate_contiguous
{
std
::
string
op_name
;
std
::
string
name
()
const
{
return
"eliminate_contiguous"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/eliminate_identity.hpp
View file @
2f268bc2
...
...
@@ -18,7 +18,7 @@ struct module;
struct
eliminate_identity
{
std
::
string
name
()
const
{
return
"eliminate_identity"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/filesystem.hpp
View file @
2f268bc2
...
...
@@ -3,7 +3,10 @@
#include <migraphx/config.hpp>
#if defined(__has_include) && !defined(CPPCHECK)
#if defined(CPPCHECK)
#define MIGRAPHX_HAS_FILESYSTEM 1
#define MIGRAPHX_HAS_FILESYSTEM_TS 1
#elif defined(__has_include)
#if __has_include(<filesystem>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_FILESYSTEM 1
#else
...
...
src/include/migraphx/gemm.hpp
View file @
2f268bc2
...
...
@@ -3,7 +3,7 @@
#include <migraphx/config.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/
shape_for_each
.hpp>
#include <migraphx/
par_for
.hpp>
#include <migraphx/tensor_view.hpp>
namespace
migraphx
{
...
...
@@ -20,8 +20,10 @@ void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha
assert
(
amat
.
get_shape
().
lens
()[
dim_1
]
==
bmat
.
get_shape
().
lens
()[
dim_0
]);
assert
(
cmat
.
get_shape
().
lens
()[
dim_0
]
==
amat
.
get_shape
().
lens
()[
dim_0
]);
assert
(
cmat
.
get_shape
().
lens
()[
dim_1
]
==
bmat
.
get_shape
().
lens
()[
dim_1
]);
auto
cs
=
cmat
.
get_shape
();
shape_for_each
(
cmat
.
get_shape
(),
[
&
](
const
auto
&
c_idx
)
{
par_for
(
cs
.
elements
(),
[
&
](
auto
i
)
{
auto
c_idx
=
cs
.
multi
(
i
);
auto
a_idx
=
c_idx
;
auto
b_idx
=
c_idx
;
double
s
=
0.0
;
...
...
src/include/migraphx/generate.hpp
View file @
2f268bc2
...
...
@@ -88,16 +88,16 @@ struct xorshift_generator
template
<
class
T
>
auto
generate_tensor_data
(
const
migraphx
::
shape
&
s
,
unsigned
long
seed
=
0
)
{
auto
result
=
make_shared_array
<
T
>
(
s
.
element
s
());
std
::
generate
(
result
.
get
(),
result
.
get
()
+
s
.
element
s
(),
xorshf96_generator
<
T
>
{
seed
});
auto
result
=
make_shared_array
<
T
>
(
s
.
element
_space
());
std
::
generate
(
result
.
get
(),
result
.
get
()
+
s
.
element
_space
(),
xorshf96_generator
<
T
>
{
seed
});
return
result
;
}
template
<
class
T
>
auto
fill_tensor_data
(
const
migraphx
::
shape
&
s
,
unsigned
long
value
=
0
)
{
auto
result
=
make_shared_array
<
T
>
(
s
.
element
s
());
std
::
generate
(
result
.
get
(),
result
.
get
()
+
s
.
element
s
(),
[
=
]
{
return
value
;
});
auto
result
=
make_shared_array
<
T
>
(
s
.
element
_space
());
std
::
generate
(
result
.
get
(),
result
.
get
()
+
s
.
element
_space
(),
[
=
]
{
return
value
;
});
return
result
;
}
...
...
src/include/migraphx/make_op.hpp
View file @
2f268bc2
...
...
@@ -9,7 +9,19 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
operation
make_op
(
const
std
::
string
&
name
);
operation
make_op
(
const
std
::
string
&
name
,
const
value
&
v
);
operation
make_op
(
const
std
::
string
&
name
,
const
std
::
initializer_list
<
std
::
pair
<
std
::
string
,
value
>>&
v
);
operation
make_op_from_value
(
const
std
::
string
&
name
,
const
value
&
v
);
// A template overload is added for migraphx::value so the initializer_list
// cannot be passed in directly. This is to enforce at compile-time that all
// initializer_list are key-value pairs, whereas migraphx::value allows other
// types of initializer_list such as for arrays.
template
<
class
Value
>
operation
make_op
(
const
std
::
string
&
name
,
const
Value
&
v
)
{
return
make_op_from_value
(
name
,
v
);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/matcher.hpp
View file @
2f268bc2
...
...
@@ -156,6 +156,19 @@ struct id_matcher
}
};
// Forward declare class and constructors
template
<
class
M
>
struct
basic_matcher
;
template
<
class
M
>
basic_matcher
<
M
>
make_basic_matcher
(
M
m
);
template
<
class
F
>
basic_matcher
<
function_matcher
<
F
>>
make_basic_fun_matcher
(
F
f
);
template
<
class
P
>
basic_matcher
<
predicate_matcher
<
P
>>
make_basic_pred_matcher
(
P
p
);
/// The basic matcher provides the all_of composability of the matcher
template
<
class
M
>
struct
basic_matcher
...
...
@@ -167,8 +180,8 @@ struct basic_matcher
{
// Copy m because we cant capture `this` by value
auto
mm
=
m
;
return
make_b
f
_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
->
optional
<
instruction_ref
>
{
return
make_b
asic_fun
_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
->
optional
<
instruction_ref
>
{
auto
result
=
mm
.
match
(
ctx
,
ins
);
if
(
result
)
{
...
...
@@ -239,7 +252,39 @@ struct any_matcher : any_matcher_base
struct
matcher_result
{
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
struct
instruction_container
{
instruction_container
()
=
default
;
instruction_container
(
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
x
)
:
ins_map
(
std
::
move
(
x
))
{
}
instruction_ref
operator
[](
const
std
::
string
&
name
)
const
{
auto
it
=
ins_map
.
find
(
name
);
if
(
it
==
ins_map
.
end
())
MIGRAPHX_THROW
(
"Accessing name that wasn't bound in matcher: "
+
name
);
return
it
->
second
;
}
auto
find
(
const
std
::
string
&
name
)
const
{
return
ins_map
.
find
(
name
);
}
auto
begin
()
const
{
return
ins_map
.
cbegin
();
}
auto
end
()
const
{
return
ins_map
.
cend
();
}
bool
has_instructions_in
(
const
module
&
mod
)
const
{
return
std
::
all_of
(
ins_map
.
begin
(),
ins_map
.
end
(),
[
&
](
auto
&&
p
)
{
return
mod
.
has_instruction
(
p
.
second
);
});
}
private:
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
ins_map
;
};
instruction_container
instructions
;
instruction_ref
result
;
};
...
...
@@ -255,6 +300,7 @@ matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
{
result
.
result
=
ins
;
result
.
instructions
=
ctx
.
instructions
;
assert
(
result
.
instructions
.
has_instructions_in
(
mod
));
}
else
{
...
...
@@ -535,6 +581,18 @@ auto skip_output(Ms... ms)
});
}
inline
auto
var
(
std
::
string
s
)
{
return
make_basic_fun_matcher
(
[
=
,
s
=
std
::
move
(
s
)](
const
matcher_context
&
ctx
,
instruction_ref
)
->
optional
<
instruction_ref
>
{
auto
it
=
ctx
.
instructions
.
find
(
s
);
if
(
it
==
ctx
.
instructions
.
end
())
return
nullopt
;
return
it
->
second
;
});
}
inline
auto
name
(
std
::
string
s
)
{
return
make_basic_pred_matcher
(
...
...
@@ -698,10 +756,16 @@ auto skip_broadcasts(Ms... ms)
return
skip
(
name
(
"broadcast"
,
"multibroadcast"
,
"contiguous"
))(
ms
...);
}
template
<
class
...
Ms
>
auto
skip_broadcasts_converts
(
Ms
...
ms
)
{
return
skip
(
name
(
"broadcast"
,
"multibroadcast"
,
"contiguous"
,
"convert"
))(
ms
...);
}
template
<
class
T
>
inline
auto
has_value
(
T
x
,
float
tolerance
=
1e-6
)
{
return
skip_broadcasts
(
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
return
skip_broadcasts
_converts
(
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"@literal"
)
return
false
;
auto
l
=
ins
->
get_literal
();
...
...
src/include/migraphx/memory_coloring.hpp
View file @
2f268bc2
...
...
@@ -17,7 +17,7 @@ struct memory_coloring
std
::
string
allocation_op
{};
bool
verify
=
false
;
std
::
string
name
()
const
{
return
"memory coloring"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/op/common.hpp
View file @
2f268bc2
...
...
@@ -22,7 +22,8 @@ enum padding_mode_t
enum
class
pooling_mode
{
average
,
max
max
,
lpnorm
};
// indicate rnn computation direction
...
...
src/include/migraphx/op/gathernd.hpp
0 → 100644
View file @
2f268bc2
#ifndef MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
gathernd
{
int
batch_dims
=
0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
batch_dims
,
"batch_dims"
));
}
std
::
string
name
()
const
{
return
"gathernd"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
auto
r
=
inputs
.
front
().
lens
().
size
();
auto
q
=
inputs
.
back
().
lens
().
size
();
auto
k
=
inputs
.
back
().
lens
().
back
();
if
(
k
>
r
-
batch_dims
)
{
MIGRAPHX_THROW
(
"GATHERND: Indices of length "
+
std
::
to_string
(
k
)
+
" cannot be used to access data of rank "
+
std
::
to_string
(
r
-
batch_dims
));
}
auto
indices_lens_iter
=
inputs
.
back
().
lens
().
begin
();
auto
output_lens_size
=
q
+
r
-
k
-
batch_dims
-
1
;
std
::
vector
<
std
::
size_t
>
output_lens
(
output_lens_size
);
std
::
copy
(
indices_lens_iter
,
indices_lens_iter
+
(
q
-
1
),
output_lens
.
begin
());
if
(
k
<
r
-
batch_dims
)
{
auto
data_lens
=
inputs
.
front
().
lens
();
std
::
copy
(
data_lens
.
begin
()
+
batch_dims
+
k
,
data_lens
.
end
(),
output_lens
.
begin
()
+
q
-
1
);
}
shape
output_shape
{
inputs
.
front
().
type
(),
output_lens
};
return
output_shape
;
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
data
)
{
args
[
1
].
visit
([
&
](
auto
indices
)
{
auto
indices_shape
=
indices
.
get_shape
();
auto
indices_shape_lens
=
indices_shape
.
lens
();
auto
data_shape
=
data
.
get_shape
();
auto
data_shape_lens
=
data_shape
.
lens
();
auto
k
=
indices_shape
.
lens
().
back
();
const
auto
num_slice_dims
=
k
;
std
::
size_t
num_slices
=
std
::
accumulate
(
indices_shape_lens
.
begin
(),
indices_shape_lens
.
end
()
-
1
,
1
,
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
slice_size
=
std
::
accumulate
(
data_shape_lens
.
begin
()
+
k
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
num_batches
=
std
::
accumulate
(
data_shape_lens
.
begin
(),
data_shape_lens
.
begin
()
+
batch_dims
,
1
,
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
data_batch_stride
=
std
::
accumulate
(
data_shape_lens
.
begin
()
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
auto
num_slices_per_batch
=
num_slices
/
num_batches
;
std
::
vector
<
std
::
size_t
>
sizes_from_slice_dims
(
num_slice_dims
);
{
auto
running_product
=
slice_size
;
for
(
std
::
size_t
i
=
0
;
i
<
num_slice_dims
;
++
i
)
{
sizes_from_slice_dims
[
num_slice_dims
-
1
-
i
]
=
running_product
;
running_product
*=
data_shape_lens
[
batch_dims
+
num_slice_dims
-
1
-
i
];
}
}
std
::
vector
<
std
::
size_t
>
input_slice_offsets
(
num_slices
);
par_for
(
num_slices
,
[
&
](
const
auto
i
)
{
std
::
size_t
batch_idx
=
i
/
num_slices_per_batch
;
auto
slice_indices
=
indices
.
begin
()
+
(
i
*
num_slice_dims
);
std
::
size_t
relative_slice_offset
=
0
;
for
(
size_t
dim_idx
=
0
;
dim_idx
<
num_slice_dims
;
++
dim_idx
)
{
int64_t
index
=
*
(
slice_indices
+
dim_idx
);
const
std
::
size_t
input_dim_idx
=
batch_dims
+
dim_idx
;
const
auto
input_dim
=
data_shape_lens
[
input_dim_idx
];
if
(
index
<
-
static_cast
<
int64_t
>
(
input_dim
)
or
index
>=
static_cast
<
int64_t
>
(
input_dim
))
MIGRAPHX_THROW
(
"GatherND: index "
+
std
::
to_string
(
index
)
+
" is out of bounds for dim of len "
+
std
::
to_string
(
input_dim
));
if
(
index
<
0
)
index
+=
input_dim
;
relative_slice_offset
+=
index
*
sizes_from_slice_dims
[
dim_idx
];
}
input_slice_offsets
[
i
]
=
(
batch_idx
*
data_batch_stride
)
+
relative_slice_offset
;
});
par_for
(
num_slices
*
slice_size
,
[
&
](
const
auto
i
)
{
auto
slice_offset
=
input_slice_offsets
[
i
/
slice_size
];
output
[
i
]
=
data
[
slice_offset
+
i
%
slice_size
];
});
});
});
return
result
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/pooling.hpp
View file @
2f268bc2
...
...
@@ -8,6 +8,7 @@
#include <migraphx/streamutils.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/value.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/int_divide.hpp>
...
...
@@ -27,6 +28,7 @@ struct pooling
std
::
vector
<
std
::
size_t
>
stride
=
{
1
,
1
};
std
::
vector
<
std
::
size_t
>
lengths
=
{
1
,
1
};
bool
ceil_mode
=
false
;
int
lp_order
=
2
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
...
...
@@ -35,7 +37,8 @@ struct pooling
f
(
self
.
padding
,
"padding"
),
f
(
self
.
stride
,
"stride"
),
f
(
self
.
lengths
,
"lengths"
),
f
(
self
.
ceil_mode
,
"ceil_mode"
));
f
(
self
.
ceil_mode
,
"ceil_mode"
),
f
(
self
.
lp_order
,
"lp_order"
));
}
std
::
string
name
()
const
{
return
"pooling"
;
}
...
...
@@ -89,6 +92,114 @@ struct pooling
check_attribute_size
();
return
stride
.
size
();
}
struct
lpnorm_pool
{
int
p
=
0
;
lpnorm_pool
()
=
delete
;
explicit
lpnorm_pool
(
int
x
)
:
p
{
x
}
{};
template
<
class
T
>
double
init
()
const
{
return
0.0
;
}
double
operator
()(
double
x
,
double
y
)
const
{
return
x
+
std
::
pow
(
std
::
abs
(
y
),
p
);
}
double
final
(
double
x
,
std
::
size_t
)
const
{
return
std
::
pow
(
x
,
1.
/
p
);
}
};
struct
avg_pool
{
template
<
class
T
>
double
init
()
const
{
return
0.0
;
}
double
operator
()(
double
x
,
double
y
)
const
{
return
x
+
y
;
}
double
final
(
double
x
,
std
::
size_t
y
)
const
{
return
(
y
==
0
)
?
0.0
:
(
x
/
y
);
}
};
struct
max_pool
{
template
<
class
T
>
T
init
()
const
{
return
std
::
numeric_limits
<
T
>::
lowest
();
}
double
operator
()(
double
x
,
double
y
)
const
{
return
std
::
max
(
x
,
y
);
}
double
final
(
double
x
,
std
::
size_t
)
const
{
return
(
x
);
}
};
template
<
class
Type
,
class
Out
,
class
In
,
class
Op
>
void
calc_pooling
(
const
shape
&
output_shape
,
Out
&
output
,
const
In
&
input
,
Op
op
)
const
{
auto
in_s
=
input
.
get_shape
();
auto
in_lens
=
in_s
.
lens
();
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
auto
idx_o
=
output_shape
.
multi
(
i
);
auto
n_dim
=
idx_o
.
size
();
std
::
vector
<
std
::
size_t
>
win_start
;
std
::
vector
<
std
::
size_t
>
win_size
;
for
(
std
::
size_t
dim
=
2
;
dim
<
n_dim
;
++
dim
)
{
auto
d_2
=
dim
-
2
;
int
start
=
static_cast
<
int
>
(
idx_o
[
dim
]
*
stride
[
d_2
])
-
static_cast
<
int
>
(
padding
[
d_2
]);
int
end
=
std
::
min
(
start
+
lengths
[
d_2
],
in_lens
[
dim
]);
start
=
std
::
max
(
start
,
0
);
win_start
.
push_back
(
start
);
win_size
.
push_back
(
end
-
start
);
}
shape
win_shape
{
output_shape
.
type
(),
win_size
};
auto
pool_size
=
win_shape
.
elements
();
double
output_val
=
op
.
template
init
<
Type
>();
shape_for_each
(
win_shape
,
[
&
](
auto
idx_w
)
{
auto
idx
=
idx_o
;
std
::
transform
(
idx_w
.
begin
(),
idx_w
.
end
(),
win_start
.
begin
(),
idx
.
begin
()
+
2
,
[](
auto
ii
,
auto
jj
)
{
return
ii
+
jj
;
});
if
(
std
::
all_of
(
idx
.
begin
()
+
2
,
idx
.
end
(),
[
&
](
auto
ii
)
{
return
ii
>=
0
;
})
and
idx
<
in_lens
)
{
output_val
=
op
(
output_val
,
input
[
in_s
.
index
(
idx
)]);
}
});
output
[
i
]
=
Type
(
op
.
final
(
output_val
,
pool_size
));
});
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
using
type
=
typename
decltype
(
output
)
::
value_type
;
switch
(
mode
)
{
case
migraphx
::
op
::
pooling_mode
::
average
:
calc_pooling
<
type
>
(
output_shape
,
output
,
input
,
avg_pool
{});
break
;
case
migraphx
::
op
::
pooling_mode
::
max
:
calc_pooling
<
type
>
(
output_shape
,
output
,
input
,
max_pool
{});
break
;
case
migraphx
::
op
::
pooling_mode
::
lpnorm
:
calc_pooling
<
type
>
(
output_shape
,
output
,
input
,
lpnorm_pool
{
lp_order
});
break
;
}
});
return
result
;
}
};
}
// namespace op
...
...
src/include/migraphx/op/scatter.hpp
View file @
2f268bc2
...
...
@@ -8,6 +8,7 @@
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/name.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
...
...
@@ -16,7 +17,17 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
scatter
// The scatter operator fetches a subset of data given by an index array and then performs a
// reduction operation (add, multiply, or just set the data) on each element returned. We implement
// it as a separate derived struct for each of the three reduction methods. The related operator
// scatterND is a generalization that works on a set of 3 tensors of different ranks. The
// complementary operations are gather/gatherND.
//
// This is a template for deriving child structs from. Each child needs to define
// only a reduction() method. Names are automatically handled by the op_name template.
template
<
class
Derived
>
struct
scatter
:
op_name
<
Derived
>
{
int64_t
axis
=
0
;
...
...
@@ -33,29 +44,44 @@ struct scatter
return
{{
"normalize_axes"
,
normalize
}};
}
std
::
string
name
()
const
{
return
"scatter"
;
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
3
).
standard
();
return
inputs
.
front
();
// If non-packed, this converts to a packed output while preserving permutation of tensor
return
inputs
.
front
().
with_lens
(
inputs
.
front
().
lens
());
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
// max dimension in axis
auto
&
self
=
static_cast
<
const
Derived
&>
(
*
this
);
// max dimension in each axis
auto
axis_dim_size
=
output_shape
.
lens
()[
axis
];
// cast all arguments as correct type
visit_all
(
result
,
args
[
0
],
args
[
2
])([
&
](
auto
output
,
auto
data
,
auto
update
)
{
// copy all of data to output
std
::
copy
(
data
.
begin
(),
data
.
end
(),
output
.
begin
());
args
[
1
].
visit
([
&
](
auto
indices
)
{
auto
ind_s
=
indices
.
get_shape
();
// iterate through items in shape
shape_for_each
(
ind_s
,
[
&
](
const
auto
&
idx
)
{
auto
out_idx
=
idx
;
auto
index
=
indices
[
ind_s
.
index
(
idx
)];
auto
out_idx
=
idx
;
// Overloaded tensor_view::() invokes indexing logic of
// std::size_t shape::index(std::size_t i) const
// which handles nonstandard shapes correctly
auto
index
=
indices
(
idx
.
begin
(),
idx
.
end
());
// normalize negative indexes (may be redundant after using
// normalize_compute_shape())
index
=
(
index
<
0
)
?
index
+
axis_dim_size
:
index
;
out_idx
[
axis
]
=
index
;
output
[
output_shape
.
index
(
out_idx
)]
=
update
[
ind_s
.
index
(
idx
)];
// look up the appropriate locations in output, using idx and out_idx.
// call reduction() method of derived struct to copy and reduce that element
self
.
reduction
()(
output
(
out_idx
.
begin
(),
out_idx
.
end
()),
update
(
idx
.
begin
(),
idx
.
end
()));
});
});
});
...
...
src/include/migraphx/op/scatter_add.hpp
0 → 100644
View file @
2f268bc2
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_ADD_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_ADD_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
#include <migraphx/op/scatter.hpp>
// Scatter op. with "add" function as reduction.
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
scatter_add
:
scatter
<
scatter_add
>
{
// reduction (pointwise operation) is called by the parent struct's compute() method.
// It works much like a virtual function overload.
// For the scatter methods, there are three different reduction functions.
auto
reduction
()
const
{
return
[](
auto
&
x
,
const
auto
&
y
)
{
x
+=
y
;
};
}
// name of this struct is automatically assigned by the op_name<>
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/scatter_mul.hpp
0 → 100644
View file @
2f268bc2
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_MUL_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_MUL_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
#include <migraphx/op/scatter.hpp>
// Scatter op. with "multiply" as the reduction function.
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
scatter_mul
:
scatter
<
scatter_mul
>
{
// reduction (pointwise operation) is called by the parent struct's compute() method.
// It works much like a virtual function overload.
// For the scatter operators, there are three different reduction functions.
auto
reduction
()
const
{
return
[](
auto
&
x
,
const
auto
&
y
)
{
x
*=
y
;
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/scatter_none.hpp
0 → 100644
View file @
2f268bc2
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_NONE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_NONE_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/scatter.hpp>
#include <cmath>
#include <utility>
// Scatter op. with "none" as the reduction function (just copies the value). This is identical to
// the previously existing Scatter op.
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
scatter_none
:
scatter
<
scatter_none
>
{
// reduction (pointwise operation) is called by the parent struct's compute() method.
// It works much like a virtual function overload.
// For the scatter operators, there are three different reduction functions.
auto
reduction
()
const
{
return
[](
auto
&
x
,
const
auto
&
y
)
{
x
=
y
;
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/operators.hpp
View file @
2f268bc2
...
...
@@ -35,6 +35,7 @@
#include <migraphx/op/flatten.hpp>
#include <migraphx/op/floor.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/op/gathernd.hpp>
#include <migraphx/op/get_tuple_elem.hpp>
#include <migraphx/op/greater.hpp>
#include <migraphx/op/gru.hpp>
...
...
@@ -86,7 +87,9 @@
#include <migraphx/op/round.hpp>
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/scatter.hpp>
#include <migraphx/op/scatter_add.hpp>
#include <migraphx/op/scatter_mul.hpp>
#include <migraphx/op/scatter_none.hpp>
#include <migraphx/op/scatternd_add.hpp>
#include <migraphx/op/scatternd_none.hpp>
#include <migraphx/op/scatternd_mul.hpp>
...
...
src/include/migraphx/optional.hpp
View file @
2f268bc2
...
...
@@ -3,7 +3,10 @@
#include <migraphx/config.hpp>
#if defined(__has_include) && !defined(CPPCHECK)
#if defined(CPPCHECK)
#define MIGRAPHX_HAS_OPTIONAL 1
#define MIGRAPHX_HAS_OPTIONAL_TS 1
#elif defined(__has_include)
#if __has_include(<optional>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_OPTIONAL 1
#else
...
...
src/include/migraphx/propagate_constant.hpp
View file @
2f268bc2
...
...
@@ -15,7 +15,7 @@ struct module;
struct
propagate_constant
{
std
::
string
name
()
const
{
return
"propagate_constant"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/raw_data.hpp
View file @
2f268bc2
...
...
@@ -207,8 +207,7 @@ auto visit_all_pack(const shape& s, V1&& v1)
template
<
class
T
,
class
...
Ts
>
auto
visit_all
(
T
&&
x
,
Ts
&&
...
xs
)
{
auto
&&
s
=
x
.
get_shape
();
// cppcheck-suppress redundantInitialization
auto
&&
s
=
x
.
get_shape
();
std
::
initializer_list
<
shape
::
type_t
>
types
=
{
xs
.
get_shape
().
type
()...};
if
(
!
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
MIGRAPHX_THROW
(
"Types must be the same"
);
...
...
src/include/migraphx/rewrite_batchnorm.hpp
View file @
2f268bc2
...
...
@@ -16,7 +16,7 @@ struct module;
struct
rewrite_batchnorm
{
std
::
string
name
()
const
{
return
"rewrite_batchnorm"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
Prev
1
2
3
4
5
6
7
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment