Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5ee35bf0
"vscode:/vscode.git/clone" did not exist on "4df8b7a2cec4f44da220b9a2ae5b40dbcd3ef288"
Commit
5ee35bf0
authored
Nov 14, 2023
by
Paul
Browse files
Format
parent
d1dc3e35
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
147 additions
and
106 deletions
+147
-106
src/fuse_concat.cpp
src/fuse_concat.cpp
+60
-41
src/include/migraphx/module.hpp
src/include/migraphx/module.hpp
+1
-0
src/module.cpp
src/module.cpp
+1
-4
src/targets/gpu/jit/concat.cpp
src/targets/gpu/jit/concat.cpp
+49
-35
test/fuse_concat.cpp
test/fuse_concat.cpp
+36
-26
No files found.
src/fuse_concat.cpp
View file @
5ee35bf0
...
...
@@ -27,29 +27,37 @@ struct fused_concat
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
,
const
std
::
vector
<
module_ref
>&
mods
)
const
{
check_shapes
{
inputs
,
*
this
}.
same_ndims
();
if
((
inputs
.
size
()
+
1
)
==
mods
.
size
())
if
((
inputs
.
size
()
+
1
)
==
mods
.
size
())
MIGRAPHX_THROW
(
"FUSED_CONCAT: Missing fused modules"
);
auto
input_iter
=
inputs
.
begin
();
std
::
vector
<
shape
>
concat_inputs
;
for
(
module_ref
mod
:
range
(
mods
.
begin
(),
mods
.
end
()
-
1
))
for
(
module_ref
mod
:
range
(
mods
.
begin
(),
mods
.
end
()
-
1
))
{
concat_inputs
.
push_back
(
*
input_iter
);
input_iter
+=
mod
->
get_parameter_names
().
size
();
}
module_ref
post_mod
=
mods
.
back
();
auto
type
=
std
::
prev
(
post_mod
->
end
())
->
get_shape
().
type
();
module_ref
post_mod
=
mods
.
back
();
auto
type
=
std
::
prev
(
post_mod
->
end
())
->
get_shape
().
type
();
const
auto
&
first_shape_lens
=
concat_inputs
.
front
().
lens
();
if
(
not
std
::
all_of
(
concat_inputs
.
begin
()
+
1
,
concat_inputs
.
end
(),
[
&
](
auto
s
)
{
const
auto
&
lens
=
s
.
lens
();
return
std
::
equal
(
lens
.
begin
(),
lens
.
begin
()
+
axis
,
first_shape_lens
.
begin
(),
first_shape_lens
.
begin
()
+
axis
)
and
std
::
equal
(
lens
.
begin
()
+
axis
+
1
,
lens
.
end
(),
first_shape_lens
.
begin
()
+
axis
+
1
,
first_shape_lens
.
end
());
}))
MIGRAPHX_THROW
(
"FUSED_CONCAT: all input dimensions should match along non-axis: "
+
std
::
to_string
(
axis
));
std
::
size_t
new_dim_axis
=
transform_accumulate
(
concat_inputs
.
begin
(),
concat_inputs
.
end
(),
0
,
std
::
plus
<>
{},
[
&
](
const
auto
&
input
)
{
return
input
.
lens
()[
axis
];
});
auto
new_lens
=
concat_inputs
.
front
().
lens
();
if
(
not
std
::
all_of
(
concat_inputs
.
begin
()
+
1
,
concat_inputs
.
end
(),
[
&
](
auto
s
)
{
const
auto
&
lens
=
s
.
lens
();
return
std
::
equal
(
lens
.
begin
(),
lens
.
begin
()
+
axis
,
first_shape_lens
.
begin
(),
first_shape_lens
.
begin
()
+
axis
)
and
std
::
equal
(
lens
.
begin
()
+
axis
+
1
,
lens
.
end
(),
first_shape_lens
.
begin
()
+
axis
+
1
,
first_shape_lens
.
end
());
}))
MIGRAPHX_THROW
(
"FUSED_CONCAT: all input dimensions should match along non-axis: "
+
std
::
to_string
(
axis
));
std
::
size_t
new_dim_axis
=
transform_accumulate
(
concat_inputs
.
begin
(),
concat_inputs
.
end
(),
0
,
std
::
plus
<>
{},
[
&
](
const
auto
&
input
)
{
return
input
.
lens
()[
axis
];
});
auto
new_lens
=
concat_inputs
.
front
().
lens
();
new_lens
[
axis
]
=
new_dim_axis
;
return
shape
::
from_permutation
(
type
,
new_lens
,
find_permutation
(
inputs
));
}
...
...
@@ -63,55 +71,66 @@ struct find_pointwise_concat_pointwise
{
auto
matcher
()
const
{
auto
concat
=
match
::
name
(
"concat"
)(
match
::
used_once
(),
match
::
any_of
[
match
::
inputs
()](
match
::
name
(
"pointwise"
)(
match
::
used_once
())));
auto
concat
=
match
::
name
(
"concat"
)(
match
::
used_once
(),
match
::
any_of
[
match
::
inputs
()](
match
::
name
(
"pointwise"
)(
match
::
used_once
())));
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
concat
.
bind
(
"concat"
)));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
auto
concat_ins
=
r
.
instructions
[
"concat"
];
auto
concat_arg
=
std
::
find
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
concat_ins
)
-
ins
->
inputs
().
begin
();
auto
concat_arg
=
std
::
find
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
concat_ins
)
-
ins
->
inputs
().
begin
();
std
::
vector
<
instruction_ref
>
inputs
;
for
(
auto
input
:
concat_ins
->
inputs
())
for
(
auto
input
:
concat_ins
->
inputs
())
inputs
.
insert
(
inputs
.
end
(),
input
->
inputs
().
begin
(),
input
->
inputs
().
end
());
std
::
copy_if
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
[
&
](
auto
input
)
{
return
input
!=
concat_ins
;
});
std
::
copy_if
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
[
&
](
auto
input
)
{
return
input
!=
concat_ins
;
});
std
::
vector
<
module_ref
>
module_inputs
;
std
::
transform
(
concat_ins
->
inputs
().
begin
(),
concat_ins
->
inputs
().
end
(),
std
::
back_inserter
(
module_inputs
),
[
&
](
instruction_ref
input
)
{
if
(
input
->
name
()
==
"pointwise"
)
{
auto
*
pm
=
input
->
module_inputs
().
front
();
return
mpm
.
create_module
(
"concat:"
+
pm
->
name
(),
*
pm
);
}
auto
*
pm
=
mpm
.
create_module
(
"concat"
+
std
::
to_string
(
counter
++
));
auto
x
=
pm
->
add_parameter
(
"x"
,
shape
{
input
->
get_shape
().
type
()});
auto
id
=
pm
->
add_instruction
(
make_op
(
"identity"
),
x
);
pm
->
add_return
({
id
});
return
pm
;
});
auto
*
post_pm
=
ins
->
module_inputs
().
front
();
auto
*
rm
=
mpm
.
create_module
(
post_pm
->
name
()
+
":concat"
,
*
post_pm
);
std
::
transform
(
concat_ins
->
inputs
().
begin
(),
concat_ins
->
inputs
().
end
(),
std
::
back_inserter
(
module_inputs
),
[
&
](
instruction_ref
input
)
{
if
(
input
->
name
()
==
"pointwise"
)
{
auto
*
pm
=
input
->
module_inputs
().
front
();
return
mpm
.
create_module
(
"concat:"
+
pm
->
name
(),
*
pm
);
}
auto
*
pm
=
mpm
.
create_module
(
"concat"
+
std
::
to_string
(
counter
++
));
auto
x
=
pm
->
add_parameter
(
"x"
,
shape
{
input
->
get_shape
().
type
()});
auto
id
=
pm
->
add_instruction
(
make_op
(
"identity"
),
x
);
pm
->
add_return
({
id
});
return
pm
;
});
auto
*
post_pm
=
ins
->
module_inputs
().
front
();
auto
*
rm
=
mpm
.
create_module
(
post_pm
->
name
()
+
":concat"
,
*
post_pm
);
std
::
vector
<
std
::
string
>
names
=
rm
->
get_parameter_names
();
std
::
sort
(
names
.
begin
(),
names
.
end
());
auto
concat_param_name
=
names
[
concat_arg
];
auto
concat_param
=
rm
->
get_parameter
(
concat_param_name
);
auto
concat_param
=
rm
->
get_parameter
(
concat_param_name
);
auto
param
=
rm
->
add_parameter
(
"!"
+
concat_param_name
,
concat_param
->
get_shape
());
rm
->
replace_instruction
(
concat_param
,
param
);
rm
->
remove_instruction
(
concat_param
);
module_inputs
.
push_back
(
rm
);
mpm
.
get_module
().
replace_instruction
(
ins
,
make_op
(
"fused_concat"
,
concat_ins
->
normalized_operator
().
to_value
()),
inputs
,
module_inputs
);
mpm
.
get_module
().
replace_instruction
(
ins
,
make_op
(
"fused_concat"
,
concat_ins
->
normalized_operator
().
to_value
()),
inputs
,
module_inputs
);
}
};
}
}
// namespace
void
fuse_concat
::
apply
(
module_pass_manager
&
mpm
)
const
{
...
...
src/include/migraphx/module.hpp
View file @
5ee35bf0
...
...
@@ -240,6 +240,7 @@ struct MIGRAPHX_EXPORT module
friend
bool
operator
!=
(
const
module
&
x
,
const
module
&
y
)
{
return
not
(
x
==
y
);
}
friend
struct
program
;
private:
void
set_name
(
const
std
::
string
&
name
);
void
assign
(
const
module
&
m
);
...
...
src/module.cpp
View file @
5ee35bf0
...
...
@@ -134,10 +134,7 @@ module& module::operator=(module m)
std
::
string
module
::
name
()
const
{
return
impl
->
name
;
}
void
module
::
set_name
(
const
std
::
string
&
name
)
{
impl
->
name
=
name
;
}
void
module
::
set_name
(
const
std
::
string
&
name
)
{
impl
->
name
=
name
;
}
bool
module
::
bypass
()
const
{
return
impl
->
bypass
;
}
void
module
::
set_bypass
(
bool
b
)
{
impl
->
bypass
=
b
;
}
...
...
src/targets/gpu/jit/concat.cpp
View file @
5ee35bf0
...
...
@@ -61,7 +61,6 @@ MIGRAPHX_GLOBAL void ${kernel}(${params})
)__migraphx__"
;
struct
concat_compiler
:
compiler
<
concat_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"concat"
};
}
...
...
@@ -152,19 +151,18 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
options
.
params
=
"-Wno-float-equal"
;
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"concat_kernel"
);
auto
axis
=
find_fast_axis
(
options
.
inputs
);
auto
op_names
=
v
.
at
(
"ops"
).
to_vector
<
std
::
string
>
();
auto
args
=
v
.
at
(
"args"
);
auto
op_names
=
v
.
at
(
"ops"
).
to_vector
<
std
::
string
>
();
auto
args
=
v
.
at
(
"args"
);
vectorize
vec
{};
if
(
axis
!=
v
.
at
(
"axis"
).
to
<
std
::
size_t
>
())
vec
=
vectorize
::
elements
(
ctx
,
axis
,
options
.
inputs
);
auto
nelements_per_op
=
options
.
inputs
.
back
().
elements
()
/
op_names
.
size
();
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
nelements_per_op
/
vec
.
size
,
256
));
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
nelements_per_op
/
vec
.
size
,
256
));
std
::
vector
<
std
::
string
>
concat_params
;
std
::
vector
<
std
::
string
>
concat_args
;
for
(
const
auto
&
name
:
op_names
)
for
(
const
auto
&
name
:
op_names
)
{
auto
n
=
args
.
at
(
name
).
to
<
std
::
size_t
>
();
auto
n
=
args
.
at
(
name
).
to
<
std
::
size_t
>
();
auto
prefix
=
name
+
"_concat_x"
;
transform
(
range
(
n
),
std
::
back_inserter
(
concat_params
),
[
&
](
auto
i
)
{
return
"auto "
+
prefix
+
std
::
to_string
(
i
);
...
...
@@ -175,17 +173,16 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
});
concat_args
.
push_back
(
"pack("
+
join_strings
(
pack_args
,
", "
)
+
")"
);
}
auto
src
=
interpolate_string
(
fused_concat_kernel
,
{{
"kernel"
,
options
.
kernel_name
},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"concat_params"
,
join_strings
(
concat_params
,
", "
)},
{
"concat_args"
,
join_strings
(
concat_args
,
", "
)},
{
"post"
,
v
.
get
(
"post"
,
std
::
string
{
"op::id{}"
})},
{
"transformers"
,
make_transformer_args
(
vec
)},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})},
{
"axis"
,
v
.
at
(
"axis"
).
to
<
std
::
string
>
()}});
auto
src
=
interpolate_string
(
fused_concat_kernel
,
{{
"kernel"
,
options
.
kernel_name
},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"concat_params"
,
join_strings
(
concat_params
,
", "
)},
{
"concat_args"
,
join_strings
(
concat_args
,
", "
)},
{
"post"
,
v
.
get
(
"post"
,
std
::
string
{
"op::id{}"
})},
{
"transformers"
,
make_transformer_args
(
vec
)},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})},
{
"axis"
,
v
.
at
(
"axis"
).
to
<
std
::
string
>
()}});
return
compile_hip_code_object
(
src
,
options
);
}
...
...
@@ -193,26 +190,43 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
{
auto
v
=
op
.
to_value
();
std
::
unordered_map
<
std
::
string
,
std
::
string
>
mod_names_lookup
;
transform
(
range
(
ins
->
module_inputs
().
size
()),
std
::
inserter
(
mod_names_lookup
,
mod_names_lookup
.
end
()),
[
&
](
auto
i
)
{
return
std
::
make_pair
(
ins
->
module_inputs
()[
i
]
->
name
(),
"pointwise"
+
std
::
to_string
(
i
));
});
v
[
"preamble"
]
=
transform_accumulate
(
ins
->
module_inputs
().
begin
(),
ins
->
module_inputs
().
end
(),
std
::
string
{},
std
::
plus
<>
{},
[
&
](
module_ref
mod
)
{
return
generate_pointwise
(
*
mod
,
mod_names_lookup
.
at
(
mod
->
name
()))
+
"
\n
"
;
});
transform
(
range
(
ins
->
module_inputs
().
size
()),
std
::
inserter
(
mod_names_lookup
,
mod_names_lookup
.
end
()),
[
&
](
auto
i
)
{
return
std
::
make_pair
(
ins
->
module_inputs
()[
i
]
->
name
(),
"pointwise"
+
std
::
to_string
(
i
));
});
v
[
"preamble"
]
=
transform_accumulate
(
ins
->
module_inputs
().
begin
(),
ins
->
module_inputs
().
end
(),
std
::
string
{},
std
::
plus
<>
{},
[
&
](
module_ref
mod
)
{
return
generate_pointwise
(
*
mod
,
mod_names_lookup
.
at
(
mod
->
name
()))
+
"
\n
"
;
});
std
::
vector
<
std
::
string
>
mod_names
;
std
::
transform
(
ins
->
module_inputs
().
begin
(),
ins
->
module_inputs
().
end
(),
std
::
back_inserter
(
mod_names
),
[
&
](
module_ref
mod
)
{
return
mod_names_lookup
.
at
(
mod
->
name
());
});
std
::
transform
(
ins
->
module_inputs
().
begin
(),
ins
->
module_inputs
().
end
(),
std
::
back_inserter
(
mod_names
),
[
&
](
module_ref
mod
)
{
return
mod_names_lookup
.
at
(
mod
->
name
());
});
v
[
"ops"
]
=
mod_names
;
std
::
unordered_map
<
std
::
string
,
std
::
size_t
>
mod_args
;
std
::
transform
(
ins
->
module_inputs
().
begin
(),
ins
->
module_inputs
().
end
(),
std
::
inserter
(
mod_args
,
mod_args
.
end
()),
[
&
](
module_ref
mod
)
{
const
auto
&
name
=
mod_names_lookup
.
at
(
mod
->
name
());
return
std
::
make_pair
(
name
,
mod
->
get_parameter_names
().
size
());
});
v
[
"args"
]
=
mod_args
;
v
[
"kernel"
]
=
transform_accumulate
(
ins
->
module_inputs
().
begin
(),
ins
->
module_inputs
().
end
()
-
1
,
std
::
string
{},
std
::
plus
<>
{},
[
&
](
module_ref
mod
)
{
return
generate_name_from_ops
(
*
mod
)
+
"_"
;
})
+
"concat_"
+
generate_name_from_ops
(
*
(
ins
->
module_inputs
().
back
()))
+
"_kernel"
;
std
::
transform
(
ins
->
module_inputs
().
begin
(),
ins
->
module_inputs
().
end
(),
std
::
inserter
(
mod_args
,
mod_args
.
end
()),
[
&
](
module_ref
mod
)
{
const
auto
&
name
=
mod_names_lookup
.
at
(
mod
->
name
());
return
std
::
make_pair
(
name
,
mod
->
get_parameter_names
().
size
());
});
v
[
"args"
]
=
mod_args
;
v
[
"kernel"
]
=
transform_accumulate
(
ins
->
module_inputs
().
begin
(),
ins
->
module_inputs
().
end
()
-
1
,
std
::
string
{},
std
::
plus
<>
{},
[
&
](
module_ref
mod
)
{
return
generate_name_from_ops
(
*
mod
)
+
"_"
;
})
+
"concat_"
+
generate_name_from_ops
(
*
(
ins
->
module_inputs
().
back
()))
+
"_kernel"
;
return
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
v
);
}
};
...
...
test/fuse_concat.cpp
View file @
5ee35bf0
...
...
@@ -37,7 +37,7 @@ void run_pass(migraphx::program& p)
migraphx
::
run_passes
(
p
,
{
migraphx
::
fuse_concat
{},
migraphx
::
dead_code_elimination
{}});
}
template
<
class
F
>
template
<
class
F
>
struct
concat_arg
{
std
::
string
name
;
...
...
@@ -45,35 +45,40 @@ struct concat_arg
F
f
;
};
template
<
class
F
>
template
<
class
F
>
concat_arg
<
F
>
arg
(
std
::
string
name
,
std
::
vector
<
migraphx
::
instruction_ref
>
inputs
,
F
f
)
{
return
{
std
::
move
(
name
),
std
::
move
(
inputs
),
std
::
move
(
f
)};
}
template
<
class
Arg
,
class
...
Args
>
migraphx
::
instruction_ref
add_concat
(
migraphx
::
program
&
p
,
std
::
size_t
axis
,
Arg
post_arg
,
Args
...
args
)
migraphx
::
instruction_ref
add_concat
(
migraphx
::
program
&
p
,
std
::
size_t
axis
,
Arg
post_arg
,
Args
...
args
)
{
std
::
vector
<
migraphx
::
module_ref
>
module_inputs
;
std
::
vector
<
migraphx
::
instruction_ref
>
ins_inputs
;
migraphx
::
each_args
([
&
](
auto
arg
)
{
module_inputs
.
push_back
(
create_pointwise_module
(
p
,
arg
.
name
,
arg
.
inputs
,
arg
.
f
));
ins_inputs
.
insert
(
ins_inputs
.
end
(),
arg
.
inputs
.
begin
(),
arg
.
inputs
.
end
());
},
args
...);
migraphx
::
each_args
(
[
&
](
auto
arg
)
{
module_inputs
.
push_back
(
create_pointwise_module
(
p
,
arg
.
name
,
arg
.
inputs
,
arg
.
f
));
ins_inputs
.
insert
(
ins_inputs
.
end
(),
arg
.
inputs
.
begin
(),
arg
.
inputs
.
end
());
},
args
...);
module_inputs
.
push_back
(
create_pointwise_module
(
p
,
post_arg
.
name
,
{},
[
&
](
auto
*
pm
,
auto
&&
)
{
std
::
vector
<
migraphx
::
instruction_ref
>
params
;
params
.
push_back
(
pm
->
add_parameter
(
"!x0"
,
migraphx
::
shape
{
ins_inputs
.
back
()
->
get_shape
().
type
()}));
std
::
transform
(
post_arg
.
inputs
.
begin
(),
post_arg
.
inputs
.
end
(),
std
::
back_inserter
(
params
),
[
&
](
auto
input
)
{
return
pm
->
add_parameter
(
"x"
+
std
::
to_string
(
params
.
size
()),
migraphx
::
shape
{
input
->
get_shape
().
type
()});
});
params
.
push_back
(
pm
->
add_parameter
(
"!x0"
,
migraphx
::
shape
{
ins_inputs
.
back
()
->
get_shape
().
type
()}));
std
::
transform
(
post_arg
.
inputs
.
begin
(),
post_arg
.
inputs
.
end
(),
std
::
back_inserter
(
params
),
[
&
](
auto
input
)
{
return
pm
->
add_parameter
(
"x"
+
std
::
to_string
(
params
.
size
()),
migraphx
::
shape
{
input
->
get_shape
().
type
()});
});
return
post_arg
.
f
(
pm
,
params
);
}));
auto
*
mm
=
p
.
get_main_module
();
return
mm
->
add_instruction
(
migraphx
::
make_op
(
"fused_concat"
,
{{
"axis"
,
axis
}}),
ins_inputs
,
module_inputs
);
return
mm
->
add_instruction
(
migraphx
::
make_op
(
"fused_concat"
,
{{
"axis"
,
axis
}}),
ins_inputs
,
module_inputs
);
}
TEST_CASE
(
simple_pointwise_concat
)
...
...
@@ -81,22 +86,27 @@ TEST_CASE(simple_pointwise_concat)
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
migraphx
::
program
p1
;
{
auto
*
mm
=
p1
.
get_main_module
();
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s
);
auto
add
=
add_pointwise
(
p1
,
"main:pointwise0"
,
{
x
,
y
},
single_pointwise
(
"add"
));
auto
sub
=
add_pointwise
(
p1
,
"main:pointwise1"
,
{
x
,
y
},
single_pointwise
(
"sub"
));
auto
*
mm
=
p1
.
get_main_module
();
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s
);
auto
add
=
add_pointwise
(
p1
,
"main:pointwise0"
,
{
x
,
y
},
single_pointwise
(
"add"
));
auto
sub
=
add_pointwise
(
p1
,
"main:pointwise1"
,
{
x
,
y
},
single_pointwise
(
"sub"
));
auto
concat
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
1
}}),
add
,
sub
);
auto
relu
=
add_pointwise
(
p1
,
"main:pointwise2"
,
{
concat
},
single_pointwise
(
"relu"
));
auto
relu
=
add_pointwise
(
p1
,
"main:pointwise2"
,
{
concat
},
single_pointwise
(
"relu"
));
mm
->
add_return
({
relu
});
}
run_pass
(
p1
);
migraphx
::
program
p2
;
{
auto
*
mm
=
p2
.
get_main_module
();
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s
);
auto
fused_concat
=
add_concat
(
p2
,
1
,
arg
(
"main:pointwise2:concat"
,
{},
single_pointwise
(
"relu"
)),
arg
(
"concat:main:pointwise0"
,
{
x
,
y
},
single_pointwise
(
"add"
)),
arg
(
"concat:main:pointwise1"
,
{
x
,
y
},
single_pointwise
(
"sub"
)));
auto
*
mm
=
p2
.
get_main_module
();
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s
);
auto
fused_concat
=
add_concat
(
p2
,
1
,
arg
(
"main:pointwise2:concat"
,
{},
single_pointwise
(
"relu"
)),
arg
(
"concat:main:pointwise0"
,
{
x
,
y
},
single_pointwise
(
"add"
)),
arg
(
"concat:main:pointwise1"
,
{
x
,
y
},
single_pointwise
(
"sub"
)));
mm
->
add_return
({
fused_concat
});
}
EXPECT
(
p1
==
p2
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment