Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2f268bc2
Commit
2f268bc2
authored
Jun 12, 2022
by
Paul
Browse files
Merge branch 'develop' into mlir-c
parents
f75c5a38
aa7ff911
Changes
205
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
497 additions
and
31 deletions
+497
-31
src/include/migraphx/eliminate_contiguous.hpp
src/include/migraphx/eliminate_contiguous.hpp
+1
-1
src/include/migraphx/eliminate_identity.hpp
src/include/migraphx/eliminate_identity.hpp
+1
-1
src/include/migraphx/filesystem.hpp
src/include/migraphx/filesystem.hpp
+4
-1
src/include/migraphx/gemm.hpp
src/include/migraphx/gemm.hpp
+4
-2
src/include/migraphx/generate.hpp
src/include/migraphx/generate.hpp
+4
-4
src/include/migraphx/make_op.hpp
src/include/migraphx/make_op.hpp
+13
-1
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+68
-4
src/include/migraphx/memory_coloring.hpp
src/include/migraphx/memory_coloring.hpp
+1
-1
src/include/migraphx/op/common.hpp
src/include/migraphx/op/common.hpp
+2
-1
src/include/migraphx/op/gathernd.hpp
src/include/migraphx/op/gathernd.hpp
+131
-0
src/include/migraphx/op/pooling.hpp
src/include/migraphx/op/pooling.hpp
+112
-1
src/include/migraphx/op/scatter.hpp
src/include/migraphx/op/scatter.hpp
+34
-8
src/include/migraphx/op/scatter_add.hpp
src/include/migraphx/op/scatter_add.hpp
+38
-0
src/include/migraphx/op/scatter_mul.hpp
src/include/migraphx/op/scatter_mul.hpp
+36
-0
src/include/migraphx/op/scatter_none.hpp
src/include/migraphx/op/scatter_none.hpp
+37
-0
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+4
-1
src/include/migraphx/optional.hpp
src/include/migraphx/optional.hpp
+4
-1
src/include/migraphx/propagate_constant.hpp
src/include/migraphx/propagate_constant.hpp
+1
-1
src/include/migraphx/raw_data.hpp
src/include/migraphx/raw_data.hpp
+1
-2
src/include/migraphx/rewrite_batchnorm.hpp
src/include/migraphx/rewrite_batchnorm.hpp
+1
-1
No files found.
src/include/migraphx/eliminate_contiguous.hpp
View file @
2f268bc2
...
...
@@ -17,7 +17,7 @@ struct eliminate_contiguous
{
std
::
string
op_name
;
std
::
string
name
()
const
{
return
"eliminate_contiguous"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/eliminate_identity.hpp
View file @
2f268bc2
...
...
@@ -18,7 +18,7 @@ struct module;
struct
eliminate_identity
{
std
::
string
name
()
const
{
return
"eliminate_identity"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/filesystem.hpp
View file @
2f268bc2
...
...
@@ -3,7 +3,10 @@
#include <migraphx/config.hpp>
#if defined(__has_include) && !defined(CPPCHECK)
#if defined(CPPCHECK)
#define MIGRAPHX_HAS_FILESYSTEM 1
#define MIGRAPHX_HAS_FILESYSTEM_TS 1
#elif defined(__has_include)
#if __has_include(<filesystem>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_FILESYSTEM 1
#else
...
...
src/include/migraphx/gemm.hpp
View file @
2f268bc2
...
...
@@ -3,7 +3,7 @@
#include <migraphx/config.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/
shape_for_each
.hpp>
#include <migraphx/
par_for
.hpp>
#include <migraphx/tensor_view.hpp>
namespace
migraphx
{
...
...
@@ -20,8 +20,10 @@ void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha
assert
(
amat
.
get_shape
().
lens
()[
dim_1
]
==
bmat
.
get_shape
().
lens
()[
dim_0
]);
assert
(
cmat
.
get_shape
().
lens
()[
dim_0
]
==
amat
.
get_shape
().
lens
()[
dim_0
]);
assert
(
cmat
.
get_shape
().
lens
()[
dim_1
]
==
bmat
.
get_shape
().
lens
()[
dim_1
]);
auto
cs
=
cmat
.
get_shape
();
shape_for_each
(
cmat
.
get_shape
(),
[
&
](
const
auto
&
c_idx
)
{
par_for
(
cs
.
elements
(),
[
&
](
auto
i
)
{
auto
c_idx
=
cs
.
multi
(
i
);
auto
a_idx
=
c_idx
;
auto
b_idx
=
c_idx
;
double
s
=
0.0
;
...
...
src/include/migraphx/generate.hpp
View file @
2f268bc2
...
...
@@ -88,16 +88,16 @@ struct xorshift_generator
template
<
class
T
>
auto
generate_tensor_data
(
const
migraphx
::
shape
&
s
,
unsigned
long
seed
=
0
)
{
auto
result
=
make_shared_array
<
T
>
(
s
.
element
s
());
std
::
generate
(
result
.
get
(),
result
.
get
()
+
s
.
element
s
(),
xorshf96_generator
<
T
>
{
seed
});
auto
result
=
make_shared_array
<
T
>
(
s
.
element
_space
());
std
::
generate
(
result
.
get
(),
result
.
get
()
+
s
.
element
_space
(),
xorshf96_generator
<
T
>
{
seed
});
return
result
;
}
template
<
class
T
>
auto
fill_tensor_data
(
const
migraphx
::
shape
&
s
,
unsigned
long
value
=
0
)
{
auto
result
=
make_shared_array
<
T
>
(
s
.
element
s
());
std
::
generate
(
result
.
get
(),
result
.
get
()
+
s
.
element
s
(),
[
=
]
{
return
value
;
});
auto
result
=
make_shared_array
<
T
>
(
s
.
element
_space
());
std
::
generate
(
result
.
get
(),
result
.
get
()
+
s
.
element
_space
(),
[
=
]
{
return
value
;
});
return
result
;
}
...
...
src/include/migraphx/make_op.hpp
View file @
2f268bc2
...
...
@@ -9,7 +9,19 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
operation
make_op
(
const
std
::
string
&
name
);
operation
make_op
(
const
std
::
string
&
name
,
const
value
&
v
);
operation
make_op
(
const
std
::
string
&
name
,
const
std
::
initializer_list
<
std
::
pair
<
std
::
string
,
value
>>&
v
);
operation
make_op_from_value
(
const
std
::
string
&
name
,
const
value
&
v
);
// A template overload is added for migraphx::value so the initializer_list
// cannot be passed in directly. This is to enforce at compile-time that all
// initializer_list are key-value pairs, whereas migraphx::value allows other
// types of initializer_list such as for arrays.
template
<
class
Value
>
operation
make_op
(
const
std
::
string
&
name
,
const
Value
&
v
)
{
return
make_op_from_value
(
name
,
v
);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/matcher.hpp
View file @
2f268bc2
...
...
@@ -156,6 +156,19 @@ struct id_matcher
}
};
// Forward declare class and constructors
template
<
class
M
>
struct
basic_matcher
;
template
<
class
M
>
basic_matcher
<
M
>
make_basic_matcher
(
M
m
);
template
<
class
F
>
basic_matcher
<
function_matcher
<
F
>>
make_basic_fun_matcher
(
F
f
);
template
<
class
P
>
basic_matcher
<
predicate_matcher
<
P
>>
make_basic_pred_matcher
(
P
p
);
/// The basic matcher provides the all_of composability of the matcher
template
<
class
M
>
struct
basic_matcher
...
...
@@ -167,8 +180,8 @@ struct basic_matcher
{
// Copy m because we cant capture `this` by value
auto
mm
=
m
;
return
make_b
f
_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
->
optional
<
instruction_ref
>
{
return
make_b
asic_fun
_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
->
optional
<
instruction_ref
>
{
auto
result
=
mm
.
match
(
ctx
,
ins
);
if
(
result
)
{
...
...
@@ -239,7 +252,39 @@ struct any_matcher : any_matcher_base
struct
matcher_result
{
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
struct
instruction_container
{
instruction_container
()
=
default
;
instruction_container
(
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
x
)
:
ins_map
(
std
::
move
(
x
))
{
}
instruction_ref
operator
[](
const
std
::
string
&
name
)
const
{
auto
it
=
ins_map
.
find
(
name
);
if
(
it
==
ins_map
.
end
())
MIGRAPHX_THROW
(
"Accessing name that wasn't bound in matcher: "
+
name
);
return
it
->
second
;
}
auto
find
(
const
std
::
string
&
name
)
const
{
return
ins_map
.
find
(
name
);
}
auto
begin
()
const
{
return
ins_map
.
cbegin
();
}
auto
end
()
const
{
return
ins_map
.
cend
();
}
bool
has_instructions_in
(
const
module
&
mod
)
const
{
return
std
::
all_of
(
ins_map
.
begin
(),
ins_map
.
end
(),
[
&
](
auto
&&
p
)
{
return
mod
.
has_instruction
(
p
.
second
);
});
}
private:
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
ins_map
;
};
instruction_container
instructions
;
instruction_ref
result
;
};
...
...
@@ -255,6 +300,7 @@ matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
{
result
.
result
=
ins
;
result
.
instructions
=
ctx
.
instructions
;
assert
(
result
.
instructions
.
has_instructions_in
(
mod
));
}
else
{
...
...
@@ -535,6 +581,18 @@ auto skip_output(Ms... ms)
});
}
inline
auto
var
(
std
::
string
s
)
{
return
make_basic_fun_matcher
(
[
=
,
s
=
std
::
move
(
s
)](
const
matcher_context
&
ctx
,
instruction_ref
)
->
optional
<
instruction_ref
>
{
auto
it
=
ctx
.
instructions
.
find
(
s
);
if
(
it
==
ctx
.
instructions
.
end
())
return
nullopt
;
return
it
->
second
;
});
}
inline
auto
name
(
std
::
string
s
)
{
return
make_basic_pred_matcher
(
...
...
@@ -698,10 +756,16 @@ auto skip_broadcasts(Ms... ms)
return
skip
(
name
(
"broadcast"
,
"multibroadcast"
,
"contiguous"
))(
ms
...);
}
template
<
class
...
Ms
>
auto
skip_broadcasts_converts
(
Ms
...
ms
)
{
return
skip
(
name
(
"broadcast"
,
"multibroadcast"
,
"contiguous"
,
"convert"
))(
ms
...);
}
template
<
class
T
>
inline
auto
has_value
(
T
x
,
float
tolerance
=
1e-6
)
{
return
skip_broadcasts
(
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
return
skip_broadcasts
_converts
(
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"@literal"
)
return
false
;
auto
l
=
ins
->
get_literal
();
...
...
src/include/migraphx/memory_coloring.hpp
View file @
2f268bc2
...
...
@@ -17,7 +17,7 @@ struct memory_coloring
std
::
string
allocation_op
{};
bool
verify
=
false
;
std
::
string
name
()
const
{
return
"memory coloring"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/op/common.hpp
View file @
2f268bc2
...
...
@@ -22,7 +22,8 @@ enum padding_mode_t
enum
class
pooling_mode
{
average
,
max
max
,
lpnorm
};
// indicate rnn computation direction
...
...
src/include/migraphx/op/gathernd.hpp
0 → 100644
View file @
2f268bc2
#ifndef MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
gathernd
{
int
batch_dims
=
0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
batch_dims
,
"batch_dims"
));
}
std
::
string
name
()
const
{
return
"gathernd"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
auto
r
=
inputs
.
front
().
lens
().
size
();
auto
q
=
inputs
.
back
().
lens
().
size
();
auto
k
=
inputs
.
back
().
lens
().
back
();
if
(
k
>
r
-
batch_dims
)
{
MIGRAPHX_THROW
(
"GATHERND: Indices of length "
+
std
::
to_string
(
k
)
+
" cannot be used to access data of rank "
+
std
::
to_string
(
r
-
batch_dims
));
}
auto
indices_lens_iter
=
inputs
.
back
().
lens
().
begin
();
auto
output_lens_size
=
q
+
r
-
k
-
batch_dims
-
1
;
std
::
vector
<
std
::
size_t
>
output_lens
(
output_lens_size
);
std
::
copy
(
indices_lens_iter
,
indices_lens_iter
+
(
q
-
1
),
output_lens
.
begin
());
if
(
k
<
r
-
batch_dims
)
{
auto
data_lens
=
inputs
.
front
().
lens
();
std
::
copy
(
data_lens
.
begin
()
+
batch_dims
+
k
,
data_lens
.
end
(),
output_lens
.
begin
()
+
q
-
1
);
}
shape
output_shape
{
inputs
.
front
().
type
(),
output_lens
};
return
output_shape
;
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
data
)
{
args
[
1
].
visit
([
&
](
auto
indices
)
{
auto
indices_shape
=
indices
.
get_shape
();
auto
indices_shape_lens
=
indices_shape
.
lens
();
auto
data_shape
=
data
.
get_shape
();
auto
data_shape_lens
=
data_shape
.
lens
();
auto
k
=
indices_shape
.
lens
().
back
();
const
auto
num_slice_dims
=
k
;
std
::
size_t
num_slices
=
std
::
accumulate
(
indices_shape_lens
.
begin
(),
indices_shape_lens
.
end
()
-
1
,
1
,
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
slice_size
=
std
::
accumulate
(
data_shape_lens
.
begin
()
+
k
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
num_batches
=
std
::
accumulate
(
data_shape_lens
.
begin
(),
data_shape_lens
.
begin
()
+
batch_dims
,
1
,
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
data_batch_stride
=
std
::
accumulate
(
data_shape_lens
.
begin
()
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
auto
num_slices_per_batch
=
num_slices
/
num_batches
;
std
::
vector
<
std
::
size_t
>
sizes_from_slice_dims
(
num_slice_dims
);
{
auto
running_product
=
slice_size
;
for
(
std
::
size_t
i
=
0
;
i
<
num_slice_dims
;
++
i
)
{
sizes_from_slice_dims
[
num_slice_dims
-
1
-
i
]
=
running_product
;
running_product
*=
data_shape_lens
[
batch_dims
+
num_slice_dims
-
1
-
i
];
}
}
std
::
vector
<
std
::
size_t
>
input_slice_offsets
(
num_slices
);
par_for
(
num_slices
,
[
&
](
const
auto
i
)
{
std
::
size_t
batch_idx
=
i
/
num_slices_per_batch
;
auto
slice_indices
=
indices
.
begin
()
+
(
i
*
num_slice_dims
);
std
::
size_t
relative_slice_offset
=
0
;
for
(
size_t
dim_idx
=
0
;
dim_idx
<
num_slice_dims
;
++
dim_idx
)
{
int64_t
index
=
*
(
slice_indices
+
dim_idx
);
const
std
::
size_t
input_dim_idx
=
batch_dims
+
dim_idx
;
const
auto
input_dim
=
data_shape_lens
[
input_dim_idx
];
if
(
index
<
-
static_cast
<
int64_t
>
(
input_dim
)
or
index
>=
static_cast
<
int64_t
>
(
input_dim
))
MIGRAPHX_THROW
(
"GatherND: index "
+
std
::
to_string
(
index
)
+
" is out of bounds for dim of len "
+
std
::
to_string
(
input_dim
));
if
(
index
<
0
)
index
+=
input_dim
;
relative_slice_offset
+=
index
*
sizes_from_slice_dims
[
dim_idx
];
}
input_slice_offsets
[
i
]
=
(
batch_idx
*
data_batch_stride
)
+
relative_slice_offset
;
});
par_for
(
num_slices
*
slice_size
,
[
&
](
const
auto
i
)
{
auto
slice_offset
=
input_slice_offsets
[
i
/
slice_size
];
output
[
i
]
=
data
[
slice_offset
+
i
%
slice_size
];
});
});
});
return
result
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/pooling.hpp
View file @
2f268bc2
...
...
@@ -8,6 +8,7 @@
#include <migraphx/streamutils.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/value.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/int_divide.hpp>
...
...
@@ -27,6 +28,7 @@ struct pooling
std
::
vector
<
std
::
size_t
>
stride
=
{
1
,
1
};
std
::
vector
<
std
::
size_t
>
lengths
=
{
1
,
1
};
bool
ceil_mode
=
false
;
int
lp_order
=
2
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
...
...
@@ -35,7 +37,8 @@ struct pooling
f
(
self
.
padding
,
"padding"
),
f
(
self
.
stride
,
"stride"
),
f
(
self
.
lengths
,
"lengths"
),
f
(
self
.
ceil_mode
,
"ceil_mode"
));
f
(
self
.
ceil_mode
,
"ceil_mode"
),
f
(
self
.
lp_order
,
"lp_order"
));
}
std
::
string
name
()
const
{
return
"pooling"
;
}
...
...
@@ -89,6 +92,114 @@ struct pooling
check_attribute_size
();
return
stride
.
size
();
}
struct
lpnorm_pool
{
int
p
=
0
;
lpnorm_pool
()
=
delete
;
explicit
lpnorm_pool
(
int
x
)
:
p
{
x
}
{};
template
<
class
T
>
double
init
()
const
{
return
0.0
;
}
double
operator
()(
double
x
,
double
y
)
const
{
return
x
+
std
::
pow
(
std
::
abs
(
y
),
p
);
}
double
final
(
double
x
,
std
::
size_t
)
const
{
return
std
::
pow
(
x
,
1.
/
p
);
}
};
struct
avg_pool
{
template
<
class
T
>
double
init
()
const
{
return
0.0
;
}
double
operator
()(
double
x
,
double
y
)
const
{
return
x
+
y
;
}
double
final
(
double
x
,
std
::
size_t
y
)
const
{
return
(
y
==
0
)
?
0.0
:
(
x
/
y
);
}
};
struct
max_pool
{
template
<
class
T
>
T
init
()
const
{
return
std
::
numeric_limits
<
T
>::
lowest
();
}
double
operator
()(
double
x
,
double
y
)
const
{
return
std
::
max
(
x
,
y
);
}
double
final
(
double
x
,
std
::
size_t
)
const
{
return
(
x
);
}
};
template
<
class
Type
,
class
Out
,
class
In
,
class
Op
>
void
calc_pooling
(
const
shape
&
output_shape
,
Out
&
output
,
const
In
&
input
,
Op
op
)
const
{
auto
in_s
=
input
.
get_shape
();
auto
in_lens
=
in_s
.
lens
();
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
auto
idx_o
=
output_shape
.
multi
(
i
);
auto
n_dim
=
idx_o
.
size
();
std
::
vector
<
std
::
size_t
>
win_start
;
std
::
vector
<
std
::
size_t
>
win_size
;
for
(
std
::
size_t
dim
=
2
;
dim
<
n_dim
;
++
dim
)
{
auto
d_2
=
dim
-
2
;
int
start
=
static_cast
<
int
>
(
idx_o
[
dim
]
*
stride
[
d_2
])
-
static_cast
<
int
>
(
padding
[
d_2
]);
int
end
=
std
::
min
(
start
+
lengths
[
d_2
],
in_lens
[
dim
]);
start
=
std
::
max
(
start
,
0
);
win_start
.
push_back
(
start
);
win_size
.
push_back
(
end
-
start
);
}
shape
win_shape
{
output_shape
.
type
(),
win_size
};
auto
pool_size
=
win_shape
.
elements
();
double
output_val
=
op
.
template
init
<
Type
>();
shape_for_each
(
win_shape
,
[
&
](
auto
idx_w
)
{
auto
idx
=
idx_o
;
std
::
transform
(
idx_w
.
begin
(),
idx_w
.
end
(),
win_start
.
begin
(),
idx
.
begin
()
+
2
,
[](
auto
ii
,
auto
jj
)
{
return
ii
+
jj
;
});
if
(
std
::
all_of
(
idx
.
begin
()
+
2
,
idx
.
end
(),
[
&
](
auto
ii
)
{
return
ii
>=
0
;
})
and
idx
<
in_lens
)
{
output_val
=
op
(
output_val
,
input
[
in_s
.
index
(
idx
)]);
}
});
output
[
i
]
=
Type
(
op
.
final
(
output_val
,
pool_size
));
});
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
using
type
=
typename
decltype
(
output
)
::
value_type
;
switch
(
mode
)
{
case
migraphx
::
op
::
pooling_mode
::
average
:
calc_pooling
<
type
>
(
output_shape
,
output
,
input
,
avg_pool
{});
break
;
case
migraphx
::
op
::
pooling_mode
::
max
:
calc_pooling
<
type
>
(
output_shape
,
output
,
input
,
max_pool
{});
break
;
case
migraphx
::
op
::
pooling_mode
::
lpnorm
:
calc_pooling
<
type
>
(
output_shape
,
output
,
input
,
lpnorm_pool
{
lp_order
});
break
;
}
});
return
result
;
}
};
}
// namespace op
...
...
src/include/migraphx/op/scatter.hpp
View file @
2f268bc2
...
...
@@ -8,6 +8,7 @@
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/name.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
...
...
@@ -16,7 +17,17 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
scatter
// The scatter operator fetches a subset of data given by an index array and then performs a
// reduction operation (add, multiply, or just set the data) on each element returned. We implement
// it as a separate derived struct for each of the three reduction methods. The related operator
// scatterND is a generalization that works on a set of 3 tensors of different ranks. The
// complementary operations are gather/gatherND.
//
// This is a template for deriving child structs from. Each child needs to define
// only a reduction() method. Names are automatically handled by the op_name template.
template
<
class
Derived
>
struct
scatter
:
op_name
<
Derived
>
{
int64_t
axis
=
0
;
...
...
@@ -33,29 +44,44 @@ struct scatter
return
{{
"normalize_axes"
,
normalize
}};
}
std
::
string
name
()
const
{
return
"scatter"
;
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
3
).
standard
();
return
inputs
.
front
();
// If non-packed, this converts to a packed output while preserving permutation of tensor
return
inputs
.
front
().
with_lens
(
inputs
.
front
().
lens
());
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
// max dimension in axis
auto
&
self
=
static_cast
<
const
Derived
&>
(
*
this
);
// max dimension in each axis
auto
axis_dim_size
=
output_shape
.
lens
()[
axis
];
// cast all arguments as correct type
visit_all
(
result
,
args
[
0
],
args
[
2
])([
&
](
auto
output
,
auto
data
,
auto
update
)
{
// copy all of data to output
std
::
copy
(
data
.
begin
(),
data
.
end
(),
output
.
begin
());
args
[
1
].
visit
([
&
](
auto
indices
)
{
auto
ind_s
=
indices
.
get_shape
();
// iterate through items in shape
shape_for_each
(
ind_s
,
[
&
](
const
auto
&
idx
)
{
auto
out_idx
=
idx
;
auto
index
=
indices
[
ind_s
.
index
(
idx
)];
auto
out_idx
=
idx
;
// Overloaded tensor_view::() invokes indexing logic of
// std::size_t shape::index(std::size_t i) const
// which handles nonstandard shapes correctly
auto
index
=
indices
(
idx
.
begin
(),
idx
.
end
());
// normalize negative indexes (may be redundant after using
// normalize_compute_shape())
index
=
(
index
<
0
)
?
index
+
axis_dim_size
:
index
;
out_idx
[
axis
]
=
index
;
output
[
output_shape
.
index
(
out_idx
)]
=
update
[
ind_s
.
index
(
idx
)];
// look up the appropriate locations in output, using idx and out_idx.
// call reduction() method of derived struct to copy and reduce that element
self
.
reduction
()(
output
(
out_idx
.
begin
(),
out_idx
.
end
()),
update
(
idx
.
begin
(),
idx
.
end
()));
});
});
});
...
...
src/include/migraphx/op/scatter_add.hpp
0 → 100644
View file @
2f268bc2
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_ADD_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_ADD_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
#include <migraphx/op/scatter.hpp>
// Scatter op. with "add" function as reduction.
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
scatter_add
:
scatter
<
scatter_add
>
{
// reduction (pointwise operation) is called by the parent struct's compute() method.
// It works much like a virtual function overload.
// For the scatter methods, there are three different reduction functions.
auto
reduction
()
const
{
return
[](
auto
&
x
,
const
auto
&
y
)
{
x
+=
y
;
};
}
// name of this struct is automatically assigned by the op_name<>
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/scatter_mul.hpp
0 → 100644
View file @
2f268bc2
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_MUL_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_MUL_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
#include <migraphx/op/scatter.hpp>
// Scatter op. with "multiply" as the reduction function.
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
scatter_mul
:
scatter
<
scatter_mul
>
{
// reduction (pointwise operation) is called by the parent struct's compute() method.
// It works much like a virtual function overload.
// For the scatter operators, there are three different reduction functions.
auto
reduction
()
const
{
return
[](
auto
&
x
,
const
auto
&
y
)
{
x
*=
y
;
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/scatter_none.hpp
0 → 100644
View file @
2f268bc2
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_NONE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_NONE_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/scatter.hpp>
#include <cmath>
#include <utility>
// Scatter op. with "none" as the reduction function (just copies the value). This is identical to
// the previously existing Scatter op.
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
scatter_none
:
scatter
<
scatter_none
>
{
// reduction (pointwise operation) is called by the parent struct's compute() method.
// It works much like a virtual function overload.
// For the scatter operators, there are three different reduction functions.
auto
reduction
()
const
{
return
[](
auto
&
x
,
const
auto
&
y
)
{
x
=
y
;
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/operators.hpp
View file @
2f268bc2
...
...
@@ -35,6 +35,7 @@
#include <migraphx/op/flatten.hpp>
#include <migraphx/op/floor.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/op/gathernd.hpp>
#include <migraphx/op/get_tuple_elem.hpp>
#include <migraphx/op/greater.hpp>
#include <migraphx/op/gru.hpp>
...
...
@@ -86,7 +87,9 @@
#include <migraphx/op/round.hpp>
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/scatter.hpp>
#include <migraphx/op/scatter_add.hpp>
#include <migraphx/op/scatter_mul.hpp>
#include <migraphx/op/scatter_none.hpp>
#include <migraphx/op/scatternd_add.hpp>
#include <migraphx/op/scatternd_none.hpp>
#include <migraphx/op/scatternd_mul.hpp>
...
...
src/include/migraphx/optional.hpp
View file @
2f268bc2
...
...
@@ -3,7 +3,10 @@
#include <migraphx/config.hpp>
#if defined(__has_include) && !defined(CPPCHECK)
#if defined(CPPCHECK)
#define MIGRAPHX_HAS_OPTIONAL 1
#define MIGRAPHX_HAS_OPTIONAL_TS 1
#elif defined(__has_include)
#if __has_include(<optional>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_OPTIONAL 1
#else
...
...
src/include/migraphx/propagate_constant.hpp
View file @
2f268bc2
...
...
@@ -15,7 +15,7 @@ struct module;
struct
propagate_constant
{
std
::
string
name
()
const
{
return
"propagate_constant"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/raw_data.hpp
View file @
2f268bc2
...
...
@@ -207,8 +207,7 @@ auto visit_all_pack(const shape& s, V1&& v1)
template
<
class
T
,
class
...
Ts
>
auto
visit_all
(
T
&&
x
,
Ts
&&
...
xs
)
{
auto
&&
s
=
x
.
get_shape
();
// cppcheck-suppress redundantInitialization
auto
&&
s
=
x
.
get_shape
();
std
::
initializer_list
<
shape
::
type_t
>
types
=
{
xs
.
get_shape
().
type
()...};
if
(
!
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
MIGRAPHX_THROW
(
"Types must be the same"
);
...
...
src/include/migraphx/rewrite_batchnorm.hpp
View file @
2f268bc2
...
...
@@ -16,7 +16,7 @@ struct module;
struct
rewrite_batchnorm
{
std
::
string
name
()
const
{
return
"rewrite_batchnorm"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
Prev
1
2
3
4
5
6
7
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment