gaoqiong / MIGraphX / Commits

Commit e7268884, authored Aug 30, 2022 by Khalique Ahmed

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into layernorm_eps

Parents: 57e3784f, ed7973d1
Showing 20 changed files with 641 additions and 228 deletions (+641, -228).
src/include/migraphx/matcher.hpp                             (+6, -2)
src/include/migraphx/target_assignments.hpp                  (+1, -0)
src/pad_calc.cpp                                             (+1, -1)
src/simplify_algebra.cpp                                     (+73, -25)
src/targets/gpu/code_object_op.cpp                           (+2, -1)
src/targets/gpu/driver/compile_op.cpp                        (+5, -2)
src/targets/gpu/driver/include/migraphx/gpu/driver/perf.hpp  (+2, -1)
src/targets/gpu/driver/perf.cpp                              (+19, -10)
src/targets/gpu/driver/run_op.cpp                            (+2, -2)
src/targets/gpu/fuse_ops.cpp                                 (+138, -18)
src/targets/gpu/gemm_impl.cpp                                (+25, -8)
src/targets/gpu/include/migraphx/gpu/context.hpp             (+48, -0)
src/targets/gpu/include/migraphx/gpu/gemm.hpp                (+13, -8)
src/targets/gpu/include/migraphx/gpu/hip.hpp                 (+2, -0)
src/targets/gpu/include/migraphx/gpu/kernel.hpp              (+9, -4)
src/targets/gpu/kernel.cpp                                   (+30, -7)
src/targets/gpu/target.cpp                                   (+2, -2)
test/include/basic_ops.hpp                                   (+3, -2)
test/simplify_algebra_test.cpp                               (+211, -135)
test/verify/gemm_add_broadcast1.cpp                          (+49, -0)
src/include/migraphx/matcher.hpp

@@ -564,6 +564,11 @@ MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins)
     return nullopt;
 }

+MIGRAPHX_PRED_MATCHER(broadcast, instruction_ref ins)
+{
+    return contains({"broadcast", "multibroadcast"}, ins->name());
+}
+
 template <class... Ms>
 auto skip(Ms... ms)
 {

@@ -813,8 +818,7 @@ inline auto has_attribute(const std::string& name)
 template <class... Ms>
 auto pointwise(Ms... ms)
 {
-    return match::has_attribute("pointwise")(match::any_of(match::nargs(1), match::nargs(2)),
-                                             ms...);
+    return match::has_attribute("pointwise")(ms...);
 }
 } // namespace match
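The new match::broadcast() predicate accepts either a broadcast or a multibroadcast instruction, so passes no longer need to spell out both names. A minimal sketch of how such a predicate composes with other matchers (illustration only, not part of the patch; it mirrors the updated find_inner_broadcast matcher in simplify_algebra.cpp below):

    // Illustrative only: match a pointwise instruction whose inputs are all
    // (multi)broadcasts.
    auto matcher() const
    {
        return pointwise(match::all_of[match::inputs()](match::broadcast()));
    }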
src/include/migraphx/target_assignments.hpp

@@ -25,6 +25,7 @@
 #define MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
 #include <unordered_map>
 #include <string>
 #include <migraphx/instruction_ref.hpp>
src/pad_calc.cpp

@@ -60,7 +60,7 @@ std::vector<std::size_t> calc_dyn_auto_pad(std::vector<std::size_t> tensor_lens,
 {
     std::vector<std::size_t> padding;
     padding.resize(2 * k_lens.size());
-    for(size_t i = 0; i < padding.size() / 2; i++)
+    for(std::size_t i = 0; i < padding.size() / 2; i++)
     {
         std::ptrdiff_t input_dim = tensor_lens[i];
         std::ptrdiff_t stride    = strides[i];
src/simplify_algebra.cpp

@@ -208,6 +208,42 @@ struct find_mul_add
     }
 };

+struct find_dot_add
+{
+    auto matcher() const
+    {
+        return match::name("dot")(match::either_arg(0, 1)(
+            match::name("add")(
+                match::either_arg(0, 1)(match::any().bind("x"),
+                                        match::any_of(match::is_constant()).bind("b")),
+                match::none_of(match::args(match::is_constant(), match::is_constant())),
+                match::used_once()),
+            match::is_constant().bind("a")));
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins   = r.result;
+        auto a_ins = r.instructions["a"];
+        auto b_ins = r.instructions["b"];
+        auto x_ins = r.instructions["x"];
+        assert(x_ins != b_ins);
+
+        const bool flipped = a_ins == ins->inputs().back();
+
+        auto insert_dot = [&](auto x, auto y) {
+            if(flipped)
+                return m.insert_instruction(ins, make_op("dot"), y, x);
+            else
+                return m.insert_instruction(ins, make_op("dot"), x, y);
+        };
+        auto ax_ins = insert_dot(a_ins, x_ins);
+        auto ab_ins = insert_dot(a_ins, b_ins);
+        m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);
+    }
+};
+
 struct find_add_lit_broadcast
 {
     auto matcher() const
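find_dot_add distributes a dot over an add when the other dot operand and one add operand are constant, so the constant-by-constant product can be folded later. A rough before/after sketch, adapted from the simplify_dot_add test added further down in this commit (names are illustrative):

    // Before:  dot(add(one, x), two)   with one, two constant
    // After :  add(dot(x, two), dot(one, two))
    migraphx::module m;
    auto x   = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
    auto one = m.add_literal(get_2x2());
    auto two = m.add_literal(get_2x2(1));
    auto sum = m.add_instruction(migraphx::make_op("add"), one, x);
    m.add_instruction(migraphx::make_op("dot"), sum, two);
    // run_pass(m) rewrites this to add(dot(x, two), dot(one, two)).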
@@ -267,28 +303,26 @@ struct find_double_add_lit_broadcast
 struct find_inner_broadcast
 {
-    auto matcher() const
-    {
-        return pointwise(match::nargs(2),
-                         match::args(match::name("broadcast").bind("x"),
-                                     match::name("broadcast").bind("y")));
-    }
+    auto matcher() const
+    {
+        return pointwise(match::all_of[match::inputs()](match::broadcast()));
+    }

     void apply(module& m, const match::matcher_result& r) const
     {
-        auto ins   = r.result;
-        auto x_ins = r.instructions["x"];
-        auto y_ins = r.instructions["y"];
-
-        auto xbroadcast = any_cast<op::broadcast>(x_ins->get_operator());
-        auto ybroadcast = any_cast<op::broadcast>(y_ins->get_operator());
-        if(xbroadcast.axis != ybroadcast.axis)
+        auto ins        = r.result;
+        auto broadcasts = ins->inputs();
+        if(broadcasts.empty())
             return;
+        std::vector<instruction_ref> inputs;
+        std::transform(broadcasts.begin(),
+                       broadcasts.end(),
+                       std::back_inserter(inputs),
+                       [](auto i) { return i->inputs().front(); });
+        if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
+               return i->get_shape() != inputs.front()->get_shape();
+           }))
+            return;

-        auto op = m.insert_instruction(
-            ins, ins->get_operator(), x_ins->inputs().front(), y_ins->inputs().front());
-        m.replace_instruction(ins, xbroadcast, op);
+        auto op = m.insert_instruction(ins, ins->get_operator(), inputs);
+        m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
     }
 };
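find_inner_broadcast now hoists the broadcast past the pointwise op for any number of inputs, and for multibroadcast as well as broadcast, as long as every pre-broadcast input has the same shape. In sketch form (illustrative; the in-tree check is the new simplify_inner_broadcast2 test below):

    // add(multibroadcast(x), multibroadcast(y))  ==>  multibroadcast(add(x, y))
    // and likewise for n-ary pointwise ops; the rewrite is skipped when the
    // un-broadcast inputs do not all share the same shape.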
@@ -416,8 +450,9 @@ struct find_splits
 {
     auto matcher() const
     {
-        return match::any(match::any_of[match::outputs()](match::name("slice")(
-            match::any_of[match::outputs()](match::pointwise(), reduction()))));
+        return match::any(match::any_of[match::outputs()](match::name("slice")(
+            match::any_of[match::outputs()](
+                match::pointwise(match::any_of(match::nargs(1), match::nargs(2))), reduction()))));
     }

     static bool is_dependent(const module& m, instruction_ref ins1, instruction_ref ins2)

@@ -580,10 +615,9 @@ struct find_splits
                 auto outputs = i->outputs();
                 for(auto output : outputs)
                 {
-                    if(not contains({"reshape", "squeeze", "unsqueeze"}, output->name()))
+                    if(output->name() != "reshape")
                         continue;
-                    auto x = m.insert_instruction(output, make_op("contiguous"), output->inputs());
+                    auto x = m.insert_instruction(output, make_op("contiguous"), i);
                     m.replace_instruction(output, output->get_operator(), x);
                 }

@@ -773,7 +807,7 @@ struct find_conv_dot_horiz_fusion
         auto y = j->inputs()[1]->get_shape().lens();
         if(x.size() != y.size())
             return false;
-        // Check that non-axises match
+        // Check that non-axes match
         int axis = 1;
         if(i->name() == "dot")
         {

@@ -809,13 +843,22 @@ struct find_conv_dot_horiz_fusion
         for(auto arg : args)
             m.move_instructions(arg, input);
-        // TODO: Check if axises match
+        // TODO: Check if axes match
         auto concat =
             m.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args);
         auto fused     = m.insert_instruction(std::next(input), op, input, concat);
         int64_t offset = 0;
         for(auto arg : range(start, last))
         {
+            auto outputs = arg->outputs();
+            for(auto output : outputs)
+            {
+                if(output->name() != "reshape")
+                    continue;
+                auto x = m.insert_instruction(output, make_op("contiguous"), arg);
+                m.replace_instruction(output, output->get_operator(), x);
+            }
             int64_t len = arg->get_shape().lens()[axis];
             m.replace_instruction(arg,

@@ -958,7 +1001,11 @@ struct find_split_reshape
         std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end());
         rsp_out_lens[rsp_axis] =
             std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0});

-        // insert the reshape instruction
+        // insert the reshape instruction and add contiguous if needed
+        if(not input->get_shape().standard())
+        {
+            input = m.insert_instruction(std::next(input), make_op("contiguous"), input);
+        }
         auto rsp_ins = m.insert_instruction(
             std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);

@@ -1048,6 +1095,7 @@ void simplify_algebra::apply(module& m) const
                             find_mul_conv{},
                             find_mul_slice_conv{},
                             find_mul_add{},
+                            find_dot_add{},
                             find_div_const{},
                             find_sub_const{},
                             find_rsqrt{},
src/targets/gpu/code_object_op.cpp

@@ -51,7 +51,8 @@ code_object_op::compute(context& ctx, const shape&, const std::vector<argument>&
     std::vector<void*> kargs(args.size());
     std::transform(
         args.begin(), args.end(), kargs.begin(), [](const argument& a) { return a.data(); });
-    k.launch(ctx.get_stream().get(), global, local, std::move(kargs));
+    auto [start, stop] = ctx.get_perf_events();
+    k.launch(ctx.get_stream().get(), global, local, std::move(kargs), start, stop);
     return args[get_output_arg(args.size())];
 }
 void code_object_op::finalize(context&, const shape&, const std::vector<shape>&)
src/targets/gpu/driver/compile_op.cpp

@@ -38,8 +38,11 @@ struct compile_op : action<compile_op>
         context ctx;
         auto inputs = p.parse_shapes(v.at("inputs"));
         auto op     = gpu::compile_op(v.at("name").to<std::string>(), ctx, inputs, v);
-        double t    = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
-        std::cout << op << ": " << t << "ms" << std::endl;
+        auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
+        std::cout << op << ": " << host_time << "ms";
+        if(device_time > 0)
+            std::cout << ", " << device_time << "ms";
+        std::cout << std::endl;
     }
 };
src/targets/gpu/driver/include/migraphx/gpu/driver/perf.hpp

@@ -33,7 +33,8 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace driver {

-double time_op(context& ctx, operation op, const std::vector<shape>& inputs, int n = 100);
+std::pair<double, double>
+time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n = 100);

 } // namespace driver
 } // namespace gpu
src/targets/gpu/driver/perf.cpp

@@ -42,22 +42,31 @@ std::vector<argument> generate_arguments(const std::vector<shape>& shapes, unsig
 }

 using milliseconds = std::chrono::duration<double, std::milli>;
-double time_op(context& ctx, operation op, const std::vector<shape>& inputs, int n)
+std::pair<double, double>
+time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n)
 {
     // TODO: Use std::ref
-    migraphx::context gctx = ctx;
-    auto output            = op.compute_shape(inputs);
-    op.finalize(gctx, output, inputs);
+    migraphx::context ctx = ictx;
+    auto& gctx            = any_cast<migraphx::gpu::context>(ctx);
+    auto output           = op.compute_shape(inputs);
+    op.finalize(ctx, output, inputs);
     auto args = generate_arguments(inputs);
     auto run  = [&] {
-        op.compute(gctx, output, args);
-        gctx.finish();
+        op.compute(ctx, output, args);
+        ctx.finish();
     };
+    gctx.enable_perf_measurement();
     run();
-    auto r   = range(n);
-    double t = std::accumulate(r.begin(), r.end(), double{0.0}, [&](auto x, auto) {
-        return x + time<milliseconds>(run);
-    });
-    return t / n;
+    double host_time   = 0.0;
+    double device_time = 0.0;
+    for(auto i : range(n))
+    {
+        (void)i;
+        host_time += time<milliseconds>(run);
+        device_time += gctx.get_elapsed_ms();
+    }
+    return std::make_pair(host_time / n, device_time / n);
 }

 } // namespace driver
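time_op now reports a (host, device) pair: the wall-clock time around op.compute plus ctx.finish, and the elapsed time between the context's HIP start/stop events. A hedged sketch of a caller, mirroring the driver actions changed in this commit:

    auto [host_time, device_time] = time_op(ctx, op, inputs, /*n=*/100);
    std::cout << op << ": " << host_time << "ms (host)";
    if(device_time > 0)  // stays 0 when event timing is unavailable
        std::cout << ", " << device_time << "ms (device)";
    std::cout << std::endl;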
src/targets/gpu/driver/run_op.cpp

@@ -43,8 +43,8 @@ struct run_op : action<run_op>
         auto op = make_op(name);
         if(v.contains("fields"))
             op.from_value(v.at("fields"));
-        double t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
-        std::cout << op << ": " << t << "ms" << std::endl;
+        auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
+        std::cout << op << ": " << host_time << "ms" << std::endl;
     }
 };
src/targets/gpu/fuse_ops.cpp

@@ -48,8 +48,10 @@
 #include <migraphx/instruction.hpp>
 #include <migraphx/register_op.hpp>
 #include <migraphx/array.hpp>
+#include <migraphx/permutation.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/op/clip.hpp>
+#include <migraphx/op/contiguous.hpp>
 #include <cmath>
 #include <set>

@@ -279,6 +281,11 @@ MIGRAPHX_REGISTER_OP(hip_layernorm)
 struct hip_triadd_layernorm : ternary_device<hip_triadd_layernorm, &device::triadd_layernorm>
 {
     shape compute_shape(const std::vector<shape>& inputs) const
     {
         check_shapes{inputs, *this}.has(4).standard();
         return inputs[0];
     }
+
+    // Empty finalize to skip dimension reduction
+    void finalize(context&, const shape&, const std::vector<shape>&) {}
 };
@@ -943,28 +950,70 @@ struct find_gemm_add
     }
 };

 auto pointwise_name(const std::string& s)
 {
     return precompile_name("pointwise")(match::make_basic_pred_matcher([=](auto ins) {
         module_ref pm = ins->module_inputs().front();
         auto n = std::count_if(pm->begin(), pm->end(), [&](auto& i) { return i.name() == s; });
         if(n != 1)
             return false;
         return std::all_of(pm->begin(), pm->end(), [&](auto& i) {
             return starts_with(i.name(), "@") or i.name() == s;
         });
     }));
 }

 struct find_gemm_pointwise
 {
     auto matcher() const
     {
-        return pointwise_name("add")(
+        return precompile_name("pointwise")(
             match::nargs(3),
             match::all_of[match::inputs()](match::standard_shape()),
-            match::either_arg(0, 1)(match::used_once().bind("c"),
-                                    match::name("gpu::gemm")(match::nargs(3)).bind("gemm")));
+            match::either_arg(0, 1)(
+                match::any_of(match::standard_shape(), match::is_constant()).bind("c"),
+                match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm")));
     }

+    // TODO: Move to matcher.hpp
+    static auto match_param(const std::string& name)
+    {
+        return match::make_basic_pred_matcher([=](auto ins) {
+            if(ins->name() != "@param")
+                return false;
+            auto p = any_cast<builtin::param>(ins->get_operator());
+            return p.parameter == name;
+        });
+    }
+
+    template <class M>
+    static auto match_mul_const(M m, const std::string& var)
+    {
+        return match::name("mul")(match::either_arg(0, 1)(match::name("@literal").bind(var), m))
+            .bind(var + "_mul");
+    }
+
+    static auto match_add(const std::string& input, const std::string& output)
+    {
+        auto param     = match::name("@param");
+        auto add       = match::name("add")(match::args(param, param));
+        auto inner_mul = match::any_of(match_mul_const(match_param(input), "alpha"),
+                                       match_mul_const(match_param(output), "beta"));
+        auto mul_add   = match::name("add")(match::either_arg(0, 1)(inner_mul, param));
+        auto add_mul   = match_mul_const(add, "gamma");
+        return match::name("@return")(match::args(match::any_of(add, mul_add, add_mul)));
+    }
+
+    static float get_float(instruction_ref ins) { return ins->get_literal().at<float>(); }
+
+    template <class Gemm>
+    static bool update_gemm(Gemm& gemm, module_ref pm, unsigned input)
+    {
+        auto names = pm->get_parameter_names();
+        if(names.size() != 2)
+            return false;
+        std::sort(names.begin(), names.end());
+        unsigned output = input == 0 ? 1 : 0;
+        auto mr         = match::match_instruction(
+            *pm, std::prev(pm->end()), match_add(names[input], names[output]));
+        if(mr.result == pm->end())
+            return false;
+        if(contains(mr.instructions, "alpha_mul"))
+            gemm.alpha *= get_float(mr.instructions["alpha"]);
+        else if(contains(mr.instructions, "beta_mul"))
+            gemm.beta *= get_float(mr.instructions["beta"]);
+        else if(contains(mr.instructions, "gamma_mul"))
+        {
+            gemm.alpha *= get_float(mr.instructions["gamma"]);
+            gemm.beta *= get_float(mr.instructions["gamma"]);
+        }
+        return true;
+    }
+
     void apply(module& m, const match::matcher_result& r) const
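update_gemm walks the fused pointwise module and folds literal scales into the gemm's alpha/beta. A summary of the module shapes match_add recognizes (this is an illustration of the matcher above, not extra patch code; x_in/x_out are the module's two parameters):

    // add(x_in, x_out)                     -> plain fusion, no scale folded
    // add(mul(@literal, x_in), x_out)      -> alpha *= literal   ("alpha_mul" binding)
    // add(x_in, mul(@literal, x_out))      -> beta  *= literal   ("beta_mul" binding)
    // mul(@literal, add(x_in, x_out))      -> alpha *= literal and beta *= literal ("gamma_mul")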
@@ -978,6 +1027,19 @@ struct find_gemm_pointwise
         // Already fused gemm
         if(not float_equal(gemm.beta, 0))
             return;
         gemm.beta = 1;
+        if(not update_gemm(
+               gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 0 : 1))
+            return;
+        // const-fold input if not standard shape since rocblas can't handle it
+        if(not c_ins->get_shape().standard())
+        {
+            auto c = op::contiguous{};
+            auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()});
+            c_ins  = m.add_literal(l.get_shape(), l.data());
+        }
         auto inputs = gemm_ins->inputs();
         inputs.pop_back();

@@ -985,11 +1047,68 @@ struct find_gemm_pointwise
         inputs.push_back(c_ins);
         inputs.push_back(ins->inputs().back());
         gemm.beta = 1;
         m.replace_instruction(ins, gemm, inputs);
     }
 };

+struct find_contiguous_tranpose_gemm
+{
+    auto matcher() const
+    {
+        return match::name("gpu::contiguous")(match::arg(0)(
+            match::name("transpose")(
+                match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm")))
+                .bind("transpose")));
+    }
+
+    template <class Vector>
+    static bool is_swapped(const Vector& perm, std::size_t i, std::size_t j)
+    {
+        if(i >= perm.size() or j >= perm.size())
+            return false;
+        auto perm2 = perm;
+        std::iota(perm2.begin(), perm2.end(), 0);
+        std::swap(perm2[i], perm2[j]);
+        return perm2 == perm;
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins       = r.result;
+        auto gemm      = r.instructions["gemm"];
+        auto alloc     = gemm->inputs().back();
+        auto transpose = r.instructions["transpose"];
+        auto perm      = transpose->get_operator().to_value()["permutation"].to_vector<int64_t>();
+        auto iperm     = invert_permutation(perm);
+
+        if(perm.size() < 3)
+            return;
+
+        if(not is_swapped(perm, perm.size() - 3, perm.size() - 2))
+            return;
+
+        auto lens = gemm->get_shape().lens();
+        if(lens.size() > 3 and
+           not std::all_of(lens.begin(), lens.end() - 3, [](auto i) { return i == 1; }))
+            return;
+
+        auto gemmv           = gemm->get_operator().to_value();
+        gemmv["trans_batch"] = 1;
+
+        auto s = shape{alloc->get_shape().type(), reorder_dims(alloc->get_shape().lens(), iperm)};
+        auto new_alloc = m.insert_instruction(gemm, make_op("allocate", {{"shape", to_value(s)}}));
+        auto alloc_transpose =
+            m.insert_instruction(gemm, make_op("transpose", {{"permutation", perm}}), new_alloc);
+
+        auto inputs   = gemm->inputs();
+        inputs.back() = alloc_transpose;
+        auto new_gemm = m.insert_instruction(gemm, make_op("gpu::gemm", gemmv), inputs);
+
+        auto gemm_transpoe = m.insert_instruction(gemm, transpose->get_operator(), new_gemm);
+
+        m.replace_instruction(ins, gemm_transpoe);
+    }
+};
+
 struct find_commutative_broadcast
 {
     auto matcher() const

@@ -1091,6 +1210,7 @@ void fuse_ops::apply(module& m) const
                         find_gemm_add{},
                         find_layernorm_pointwise{},
                         find_gemm_pointwise{},
+                        find_contiguous_tranpose_gemm{},
                         find_commutative_broadcast{});
     match::find_matches(m, find_contiguous{});
 }
src/targets/gpu/gemm_impl.cpp

@@ -24,6 +24,7 @@
 #include <rocblas.h>
 #include <migraphx/gpu/gemm_impl.hpp>
 #include <migraphx/reduce_dims.hpp>
+#include <migraphx/permutation.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

@@ -67,6 +68,19 @@ void blas_shape(const shape& s)
         MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible");
 }

+shape transpose_batch(const shape& s, unsigned trans_batch)
+{
+    if(trans_batch == 0)
+        return s;
+    if(s.lens().size() < 3)
+        return s;
+    auto batch = s.lens().size() - 3;
+    std::vector<int64_t> perm(s.lens().size());
+    std::iota(perm.begin(), perm.end(), 0);
+    std::swap(perm[batch], perm[batch + trans_batch]);
+    return shape::from_permutation(s.type(), s.lens(), perm);
+}
+
 template <class R, class... Ts, class... Us>
 R rocblas_invoke(R (*f)(Ts...), Us... xs)
 {
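transpose_batch leaves rank-2 (and trans_batch == 0) shapes alone; otherwise it swaps the batch dimension with the one trans_batch positions after it and rebuilds the shape from that permutation. A rough illustration, assuming shape::from_permutation keeps the given lens and derives the strides of the permuted layout:

    migraphx::shape s{migraphx::shape::float_type, {4, 8, 16}};  // standard layout
    auto t = transpose_batch(s, 1);  // lens stay {4, 8, 16}, strides follow perm {1, 0, 2}
    auto u = transpose_batch(s, 0);  // returned unchanged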
@@ -97,6 +111,12 @@ void gemm_impl(context& ctx,
                bool int8_x4_format,
                bool compute_fp32)
 {
+    const bool is_3inputs = (args.size() == 4);
+    if(!is_3inputs)
+    {
+        beta = 0;
+    }
+
     bool transa = is_transposed(args[0].get_shape());
     bool transb = is_transposed(args[1].get_shape());
     auto n_dim  = output_shape.lens().size();

@@ -105,12 +125,8 @@ void gemm_impl(context& ctx,
     rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
     rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
     rocblas_int ldc = args[2].get_shape().strides()[dim_0];
+    rocblas_int ldd = is_3inputs ? args[3].get_shape().strides()[dim_0] : ldc;
-    bool is_3inputs = (args.size() == 4);
-    if(!is_3inputs)
-    {
-        beta = 0;
-    }
     rocblas_datatype arg_type = get_type(args[0].get_shape().type());
     auto output_type          = arg_type;
     if(output_type == rocblas_datatype_i8_r)

@@ -186,7 +202,7 @@ void gemm_impl(context& ctx,
                        ldc,
                        is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                        output_type,
-                       ldc,
+                       ldd,
                        compute_type,
                        rocblas_gemm_algo_standard,
                        0,

@@ -197,6 +213,7 @@ void gemm_impl(context& ctx,
         auto a_stride = get_batch_stride(args[0]);
         auto b_stride = get_batch_stride(args[1]);
         auto c_stride = get_batch_stride(args[2]);
+        auto d_stride = is_3inputs ? get_batch_stride(args[3]) : c_stride;
         rocblas_invoke(&rocblas_gemm_strided_batched_ex,
                        ctx.get_stream().get_rocblas(),
                        transb ? rocblas_operation_transpose : rocblas_operation_none,

@@ -220,8 +237,8 @@ void gemm_impl(context& ctx,
                        c_stride,
                        is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                        output_type,
-                       ldc,
-                       c_stride,
+                       ldd,
+                       d_stride,
                        num_matrices,
                        compute_type,
                        rocblas_gemm_algo_standard,
src/targets/gpu/include/migraphx/gpu/context.hpp

@@ -244,6 +244,15 @@ struct context
         return hip_event_ptr{event};
     }

+    static hip_event_ptr create_event_for_timing()
+    {
+        hipEvent_t event;
+        auto status = hipEventCreate(&event);
+        if(status != hipSuccess)
+            MIGRAPHX_THROW("Failed to create event");
+        return hip_event_ptr{event};
+    }
+
     value to_value() const
     {
         value result;

@@ -267,10 +276,49 @@ struct context

     any_ptr get_queue() { return get_stream().get(); }

+    void enable_perf_measurement(bool b = true)
+    {
+        if(b)
+        {
+            start_event = create_event_for_timing();
+            stop_event  = create_event_for_timing();
+            get_stream().record(start_event.get());
+            get_stream().record(stop_event.get());
+        }
+        else
+        {
+            start_event = nullptr;
+            stop_event  = nullptr;
+        }
+        measure_perf = b;
+    }
+
+    std::pair<hipEvent_t, hipEvent_t> get_perf_events() const
+    {
+        if(measure_perf)
+            return std::make_pair(start_event.get(), stop_event.get());
+        return std::make_pair(nullptr, nullptr);
+    }
+
+    float get_elapsed_ms() const
+    {
+        float result = 0;
+        if(start_event != nullptr and stop_event != nullptr)
+        {
+            auto status = hipEventElapsedTime(&result, start_event.get(), stop_event.get());
+            if(status != hipSuccess)
+                MIGRAPHX_THROW("Failed hipEventElapsedTime: " + hip_error(status));
+        }
+        return result;
+    }
+
     private:
     // TODO: Make this a vector to support multiple devices
     std::shared_ptr<hip_device> current_device;
     std::vector<shared<hip_event_ptr>> events;
+    bool measure_perf                 = false;
+    shared<hip_event_ptr> start_event = nullptr;
+    shared<hip_event_ptr> stop_event  = nullptr;
 };

 inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); }
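The context now owns optional start/stop HIP events for kernel timing. A hedged sketch of the intended flow, pieced together from perf.cpp, code_object_op.cpp and kernel.cpp in this commit (k, global, local and kargs are assumed to come from the surrounding code):

    ctx.enable_perf_measurement();               // create and record start/stop events
    auto [start, stop] = ctx.get_perf_events();  // both nullptr when measurement is off
    k.launch(ctx.get_stream().get(), global, local, std::move(kargs), start, stop);
    float ms = ctx.get_elapsed_ms();             // hipEventElapsedTime between the two events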
src/targets/gpu/include/migraphx/gpu/gemm.hpp

@@ -42,15 +42,17 @@ namespace gpu {
 struct context;

 void blas_shape(const shape& s);
+shape transpose_batch(const shape& s, unsigned trans_batch);

 template <class Op>
 struct rocblas_gemm
 {
     Op op;
-    float alpha         = 1;
-    float beta          = 0;
-    bool int8_x4_format = true;
-    bool compute_fp32   = false;
+    float alpha          = 1;
+    float beta           = 0;
+    bool int8_x4_format  = true;
+    bool compute_fp32    = false;
+    unsigned trans_batch = 0;

     template <class Self, class F>
     static auto reflect(Self& self, F f)

@@ -58,7 +60,9 @@ struct rocblas_gemm
         return pack_join(migraphx::reflect(self.op, f),
                          pack(f(self.alpha, "alpha"),
                               f(self.beta, "beta"),
-                              f(self.int8_x4_format, "int8_x4_format")));
+                              f(self.int8_x4_format, "int8_x4_format"),
+                              f(self.compute_fp32, "compute_fp32"),
+                              f(self.trans_batch, "trans_batch")));
     }

     std::string name() const

@@ -74,13 +78,14 @@ struct rocblas_gemm
     {
         std::vector<shape> in_shapes(inputs);
         in_shapes.pop_back();
-        check_shapes{in_shapes, *this}.not_broadcasted();
+        check_shapes{in_shapes, *this}.has(2, 3);
         blas_shape(inputs[0]);
         blas_shape(inputs[1]);
         // if gemm and add are fused
         if(in_shapes.size() > 2)
         {
             auto cmat_shape = in_shapes.back();
+            check_shapes{{cmat_shape}, *this}.not_transposed().not_broadcasted();
             in_shapes.pop_back();
             blas_shape(cmat_shape);
             auto op_out_shape = op.compute_shape(in_shapes);

@@ -97,10 +102,10 @@ struct rocblas_gemm
                                to_string(cmat_shape.type()) +
                                ", it must be: " + to_string(op_out_shape.type()));
             }
-            return op_out_shape;
+            return transpose_batch(op_out_shape, trans_batch);
         }

-        return op.compute_shape(in_shapes);
+        return transpose_batch(op.compute_shape(in_shapes), trans_batch);
     }

     argument
src/targets/gpu/include/migraphx/gpu/hip.hpp

@@ -37,6 +37,8 @@ namespace gpu {
 struct context;

+std::string hip_error(int error);
+
 argument allocate_gpu(const shape& s, bool host = false);

 argument register_on_gpu(const argument& arg);
src/targets/gpu/include/migraphx/gpu/kernel.hpp

@@ -50,17 +50,22 @@ struct kernel
     void launch(hipStream_t stream,
                 std::size_t global,
                 std::size_t local,
-                const std::vector<kernel_argument>& args) const;
+                const std::vector<kernel_argument>& args,
+                hipEvent_t start = nullptr,
+                hipEvent_t stop  = nullptr) const;

     void launch(hipStream_t stream,
                 std::size_t global,
                 std::size_t local,
-                std::vector<void*> args) const;
+                std::vector<void*> args,
+                hipEvent_t start = nullptr,
+                hipEvent_t stop  = nullptr) const;

-    auto launch(hipStream_t stream, std::size_t global, std::size_t local) const
+    template <class... Ts>
+    auto launch(hipStream_t stream, std::size_t global, std::size_t local, Ts... zs) const
     {
         return [=](auto&&... xs) {
-            launch(stream, global, local, std::vector<kernel_argument>{xs...});
+            launch(stream, global, local, std::vector<kernel_argument>{xs...}, zs...);
         };
     }
src/targets/gpu/kernel.cpp

@@ -80,7 +80,9 @@ void launch_kernel(hipFunction_t fun,
                    std::size_t global,
                    std::size_t local,
                    void* kernargs,
-                   std::size_t size)
+                   std::size_t size,
+                   hipEvent_t start,
+                   hipEvent_t stop)
 {
     assert(global > 0);
     assert(local > 0);

@@ -97,34 +99,55 @@ void launch_kernel(hipFunction_t fun,
 #endif
     };
-    auto status = hipExtModuleLaunchKernel(
-        fun, global, 1, 1, local, 1, 1, 0, stream, nullptr, reinterpret_cast<void**>(&config));
+    auto status = hipExtModuleLaunchKernel(
+        fun, global, 1, 1, local, 1, 1, 0, stream, nullptr, reinterpret_cast<void**>(&config),
+        start, stop);
     if(status != hipSuccess)
         MIGRAPHX_THROW("Failed to launch kernel: " + hip_error(status));
+    if(stop != nullptr)
+    {
+        status = hipEventSynchronize(stop);
+        if(status != hipSuccess)
+            MIGRAPHX_THROW("Failed to sync event: " + hip_error(status));
+    }
 }

 void kernel::launch(hipStream_t stream,
                     std::size_t global,
                     std::size_t local,
-                    std::vector<void*> args) const
+                    std::vector<void*> args,
+                    hipEvent_t start,
+                    hipEvent_t stop) const
 {
     assert(impl != nullptr);
     void* kernargs   = args.data();
     std::size_t size = args.size() * sizeof(void*);
-    launch_kernel(impl->fun, stream, global, local, kernargs, size);
+    launch_kernel(impl->fun, stream, global, local, kernargs, size, start, stop);
 }

 void kernel::launch(hipStream_t stream,
                     std::size_t global,
                     std::size_t local,
-                    const std::vector<kernel_argument>& args) const
+                    const std::vector<kernel_argument>& args,
+                    hipEvent_t start,
+                    hipEvent_t stop) const
 {
     assert(impl != nullptr);
     std::vector<char> kernargs = pack_args(args);
     std::size_t size           = kernargs.size();
-    launch_kernel(impl->fun, stream, global, local, kernargs.data(), size);
+    launch_kernel(impl->fun, stream, global, local, kernargs.data(), size, start, stop);
 }

 } // namespace gpu
src/targets/gpu/target.cpp

@@ -134,8 +134,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         lowering{&ctx, options.offload_copy},
         eliminate_contiguous{"gpu::contiguous"},
         dead_code_elimination{},
-        replace_allocate{gpu_allocation_model{}, options.offload_copy},
-        dead_code_elimination{},
         eliminate_concat{concat_gpu_optimization{}},
         dead_code_elimination{},
         pack_int8_args{},

@@ -144,6 +142,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         dead_code_elimination{},
         fuse_ops{&ctx, options.fast_math},
         dead_code_elimination{},
+        replace_allocate{gpu_allocation_model{}, options.offload_copy},
+        dead_code_elimination{},
         compile_ops{&ctx},
         dead_code_elimination{},
         write_literals{&ctx},
test/include/basic_ops.hpp

@@ -186,9 +186,10 @@ struct nop
     migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; }
 };

-inline migraphx::literal get_2x2()
+inline migraphx::literal get_2x2(int base = 0)
 {
-    return migraphx::literal{{migraphx::shape::float_type, {2, 2}}, {1, 2, 3, 4}};
+    return migraphx::literal{{migraphx::shape::float_type, {2, 2}},
+                             {base + 1, base + 2, base + 3, base + 4}};
 }

 inline migraphx::literal get_2x2_transposed()
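get_2x2() still returns {1, 2, 3, 4}; the new base offset lets a test create a second, distinct 2x2 constant without another helper, as the simplify_dot_add test below does:

    auto one = m.add_literal(get_2x2());   // {1, 2, 3, 4}
    auto two = m.add_literal(get_2x2(1));  // {2, 3, 4, 5}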
test/simplify_algebra_test.cpp

@@ -30,7 +30,6 @@
 #include <migraphx/instruction.hpp>
 #include <basic_ops.hpp>
 #include <migraphx/make_op.hpp>
 #include <test.hpp>

 void run_pass(migraphx::module& m)

@@ -358,7 +357,33 @@ TEST_CASE(simplify_mul_add)
     EXPECT(m1 == m2);
 }

-TEST_CASE(simplify_inner_broadcast)
+TEST_CASE(simplify_dot_add)
+{
+    migraphx::module m1;
+    {
+        auto x   = m1.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
+        auto one = m1.add_literal(get_2x2());
+        auto two = m1.add_literal(get_2x2(1));
+        auto sum = m1.add_instruction(migraphx::make_op("add"), one, x);
+        auto dot = m1.add_instruction(migraphx::make_op("dot"), sum, two);
+        m1.add_instruction(pass_op{}, dot);
+    }
+    run_pass(m1);
+
+    migraphx::module m2;
+    {
+        auto x    = m2.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
+        auto one  = m2.add_literal(get_2x2());
+        auto two  = m2.add_literal(get_2x2(1));
+        auto dot1 = m2.add_instruction(migraphx::make_op("dot"), x, two);
+        auto dot2 = m2.add_instruction(migraphx::make_op("dot"), one, two);
+        auto sum  = m2.add_instruction(migraphx::make_op("add"), dot1, dot2);
+        m2.add_instruction(pass_op{}, sum);
+    }
+    EXPECT(m1 == m2);
+}
+
+TEST_CASE(simplify_inner_broadcast1)
 {
     auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
     migraphx::module m1;

@@ -383,6 +408,31 @@ TEST_CASE(simplify_inner_broadcast)
     EXPECT(m1 == m2);
 }

+TEST_CASE(simplify_inner_broadcast2)
+{
+    auto b = migraphx::op::multibroadcast{{2, 1, 4, 5}};
+    migraphx::module m1;
+    {
+        auto x   = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 1, 1, 1}});
+        auto y   = m1.add_parameter("y", {migraphx::shape::int32_type, {1, 1, 1, 1}});
+        auto xb  = m1.add_instruction(b, x);
+        auto yb  = m1.add_instruction(b, y);
+        auto sum = m1.add_instruction(migraphx::make_op("add"), xb, yb);
+        m1.add_instruction(pass_op{}, sum);
+    }
+    run_pass(m1);
+
+    migraphx::module m2;
+    {
+        auto x    = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 1, 1, 1}});
+        auto y    = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1, 1, 1}});
+        auto sum  = m2.add_instruction(migraphx::make_op("add"), x, y);
+        auto sumb = m2.add_instruction(b, sum);
+        m2.add_instruction(pass_op{}, sumb);
+    }
+    EXPECT(m1 == m2);
+}
+
 TEST_CASE(simplify_add_conv1)
 {
     migraphx::module m;
@@ -1477,6 +1527,48 @@ TEST_CASE(simplify_dot_horiz_flipped)
     EXPECT(m1.sort() == m2.sort());
 }

+// test if contiguous is added as necessary for reshapes
+TEST_CASE(simplify_dot_horiz_reshape)
+{
+    auto s = migraphx::shape{migraphx::shape::int32_type, {3, 4, 4}};
+    migraphx::module m1;
+    {
+        auto input = m1.add_parameter("input", s);
+        auto a     = m1.add_literal(migraphx::generate_literal(s, 0));
+        auto b     = m1.add_literal(migraphx::generate_literal(s, 1));
+        auto x     = m1.add_instruction(migraphx::make_op("dot"), input, a);
+        auto y     = m1.add_instruction(migraphx::make_op("dot"), input, b);
+        auto x_rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 2, 2}}}), x);
+        auto y_rsp =
+            m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2}}}), y);
+        auto sum = m1.add_instruction(migraphx::make_op("add"), {x_rsp, y_rsp});
+        m1.add_instruction(pass_op{}, sum);
+    }
+    run_pass(m1);
+
+    migraphx::module m2;
+    {
+        auto input  = m2.add_parameter("input", s);
+        auto a      = m2.add_literal(migraphx::generate_literal(s, 0));
+        auto b      = m2.add_literal(migraphx::generate_literal(s, 1));
+        auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), a, b);
+        auto dot    = m2.add_instruction(migraphx::make_op("dot"), input, concat);
+        auto x      = m2.add_instruction(
+            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {4}}}), dot);
+        auto y = m2.add_instruction(
+            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {4}}, {"ends", {8}}}), dot);
+        auto x_cont = m2.add_instruction(migraphx::make_op("contiguous"), x);
+        auto x_rsp =
+            m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 2, 2}}}), x_cont);
+        auto y_rsp =
+            m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2}}}), y);
+        auto sum = m2.add_instruction(migraphx::make_op("add"), {x_rsp, y_rsp});
+        m2.add_instruction(pass_op{}, sum);
+    }
+    EXPECT(m1.sort() == m2.sort());
+}
+
 TEST_CASE(simplify_conv_horiz)
 {
     auto s = migraphx::shape{migraphx::shape::int32_type, {8, 3, 64, 64}};
@@ -1782,13 +1874,19 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion)
}
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
TEST_CASE
(
reorder_reshape_slice
)
template
<
std
::
size_t
BS
,
bool
TransposeInput
>
void
reorder_reshape_slice
()
{
std
::
vector
<
int64_t
>
perm0
=
{
0
,
2
,
1
,
3
};
std
::
vector
<
int64_t
>
perm1
=
{
0
,
2
,
3
,
1
};
auto
create_m1
=
[
&
](
std
::
size_t
batch_size
)
{
migraphx
::
module
m1
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
128
,
1920
}};
migraphx
::
module
m1
;
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
BS
,
128
,
1920
}};
if
(
TransposeInput
)
{
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
BS
,
128
,
1920
},
{
165120
,
1
,
128
}};
}
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
slc0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
640
}}}),
input
);
...
...
@@ -1803,7 +1901,7 @@ TEST_CASE(reorder_reshape_slice)
auto
c1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc1
);
auto
c2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc2
);
std
::
vector
<
int64_t
>
lens
=
{
static_cast
<
int64_t
>
(
batch_size
),
128
,
10
,
64
};
std
::
vector
<
int64_t
>
lens
=
{
static_cast
<
int64_t
>
(
BS
),
128
,
10
,
64
};
auto
r0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c0
);
auto
r1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c1
);
auto
r2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c2
);
...
...
@@ -1815,16 +1913,23 @@ TEST_CASE(reorder_reshape_slice)
auto
sum
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
t0
,
t1
);
auto
ret
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
sum
,
t2
);
m1
.
add_return
({
ret
});
return
m1
;
};
auto
create_m2
=
[
&
](
std
::
size_t
batch_size
)
{
migraphx
::
module
m2
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
128
,
1920
}};
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
std
::
vector
<
int64_t
>
lens
=
{
static_cast
<
int64_t
>
(
batch_size
),
128
,
30
,
64
};
auto
r
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
input
);
migraphx
::
module
m2
;
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
BS
,
128
,
1920
}};
if
(
TransposeInput
)
{
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
BS
,
128
,
1920
},
{
165120
,
1
,
128
}};
}
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
rsp_input
=
input
;
if
(
TransposeInput
)
{
rsp_input
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
{
input
});
}
std
::
vector
<
int64_t
>
lens
=
{
static_cast
<
int64_t
>
(
BS
),
128
,
30
,
64
};
auto
r
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
rsp_input
);
auto
slc0
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
10
}}}),
r
);
...
...
@@ -1843,27 +1948,25 @@ TEST_CASE(reorder_reshape_slice)
         auto sum = m2.add_instruction(migraphx::make_op("add"), t0, t1);
         auto ret = m2.add_instruction(migraphx::make_op("dot"), sum, t2);
         m2.add_return({ret});
-        return m2;
-    };
-
-    auto test = [&](std::size_t batch_size) {
-        auto m1 = create_m1(batch_size);
-        run_pass(m1);
-        auto m2 = create_m2(batch_size);
-        EXPECT(m1.sort() == m2.sort());
-    };
-    test(1);
-    test(4);
-    test(8);
+    }
+    run_pass(m1);
+    EXPECT(m1.sort() == m2.sort());
 }

-TEST_CASE(reorder_reshape_slice_move_axis1)
+TEST_CASE_REGISTER(reorder_reshape_slice<1, true>);
+// test if contiguous is added as necessary if input is transposed
+TEST_CASE_REGISTER(reorder_reshape_slice<4, true>);
+TEST_CASE_REGISTER(reorder_reshape_slice<8, true>);
+TEST_CASE_REGISTER(reorder_reshape_slice<1, false>);
+TEST_CASE_REGISTER(reorder_reshape_slice<4, false>);
+TEST_CASE_REGISTER(reorder_reshape_slice<8, false>);
+
+template <std::size_t BS>
+void reorder_reshape_slice_move_axis1()
 {
-    auto create_m1 = [](std::size_t batch_size) {
-        migraphx::module m1;
-        auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 256, 96}};
+    migraphx::module m1;
+    {
+        auto s = migraphx::shape{migraphx::shape::float_type, {BS, 256, 96}};
         std::vector<int64_t> perm0 = {0, 2, 1, 3};
         std::vector<int64_t> perm1 = {0, 2, 3, 1};
         auto input = m1.add_parameter("input", s);
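The per-batch-size lambdas are replaced by function templates registered explicitly, so each batch size (and the transposed-input variant) runs as its own named test case. The pattern, in sketch form:

    template <std::size_t BS, bool TransposeInput>
    void reorder_reshape_slice();                        // body defined above
    TEST_CASE_REGISTER(reorder_reshape_slice<4, true>);  // one registration per instantiation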
@@ -1878,7 +1981,7 @@ TEST_CASE(reorder_reshape_slice_move_axis1)
         auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1);
         auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2);

-        std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 64, 4, 32};
+        std::vector<int64_t> lens = {static_cast<int64_t>(BS), 64, 4, 32};
         auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
         auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
         auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);

@@ -1890,50 +1993,45 @@ TEST_CASE(reorder_reshape_slice_move_axis1)
         auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
         auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2);
         m1.add_return({ret});
-        return m1;
-    };
+    }

-    auto create_m2 = [](std::size_t batch_size) {
-        migraphx::module m;
-        auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 256, 96}};
+    migraphx::module m2;
+    {
+        auto s = migraphx::shape{migraphx::shape::float_type, {BS, 256, 96}};
         std::vector<int64_t> perm0 = {0, 2, 1, 3};
         std::vector<int64_t> perm1 = {0, 2, 3, 1};
-        auto input = m.add_parameter("input", s);
-
-        std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 64, 4, 96};
-        auto rsp  = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
-        auto slc0 = m.add_instruction(
+        auto input = m2.add_parameter("input", s);
+
+        std::vector<int64_t> lens = {static_cast<int64_t>(BS), 64, 4, 96};
+        auto rsp  = m2.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
+        auto slc0 = m2.add_instruction(
             migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {32}}}), rsp);
-        auto t0 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc0);
-        auto slc1 = m.add_instruction(
+        auto t0 = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc0);
+        auto slc1 = m2.add_instruction(
             migraphx::make_op("slice", {{"axes", {3}}, {"starts", {32}}, {"ends", {64}}}), rsp);
-        auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc1);
-        auto slc2 = m.add_instruction(
+        auto t1 = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc1);
+        auto slc2 = m2.add_instruction(
             migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {96}}}), rsp);
-        auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), slc2);
-        auto sum = m.add_instruction(migraphx::make_op("add"), t0, t1);
-        auto ret = m.add_instruction(migraphx::make_op("dot"), sum, t2);
-        m.add_return({ret});
-        return m;
-    };
-
-    auto test = [&](std::size_t batch_size) {
-        auto m1 = create_m1(batch_size);
-        auto m2 = create_m2(batch_size);
-        run_pass(m1);
-        EXPECT(m1.sort() == m2.sort());
-    };
-    test(4);
-    test(8);
+        auto t2 = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), slc2);
+        auto sum = m2.add_instruction(migraphx::make_op("add"), t0, t1);
+        auto ret = m2.add_instruction(migraphx::make_op("dot"), sum, t2);
+        m2.add_return({ret});
+    }
+    run_pass(m1);
+    EXPECT(m1.sort() == m2.sort());
 }

+TEST_CASE_REGISTER(reorder_reshape_slice_move_axis1<4>);
+TEST_CASE_REGISTER(reorder_reshape_slice_move_axis1<8>);
+
 TEST_CASE(reorder_reshape_slice_move_axis2)
 {
-    auto create_m1 = [] {
-        migraphx::module m1;
+    migraphx::module m1;
+    {
         migraphx::shape s{migraphx::shape::float_type, {128, 96}};
         auto input = m1.add_parameter("input", s);
         auto slc0  = m1.add_instruction(

@@ -1955,32 +2053,26 @@ TEST_CASE(reorder_reshape_slice_move_axis2)
         auto sum = m1.add_instruction(migraphx::make_op("add"), r0, r1);
         auto ret = m1.add_instruction(migraphx::make_op("mul"), sum, r2);
         m1.add_return({ret});
-        return m1;
-    };
+    }

-    auto create_m2 = [] {
-        migraphx::module m;
+    migraphx::module m2;
+    {
         auto s = migraphx::shape{migraphx::shape::float_type, {128, 96}};
-        auto input = m.add_parameter("input", s);
+        auto input = m2.add_parameter("input", s);
         std::vector<int64_t> lens = {1, 16, 8, 96};
-        auto rsp  = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
-        auto slc0 = m.add_instruction(
+        auto rsp  = m2.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
+        auto slc0 = m2.add_instruction(
             migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {32}}}), rsp);
-        auto slc1 = m.add_instruction(
+        auto slc1 = m2.add_instruction(
             migraphx::make_op("slice", {{"axes", {3}}, {"starts", {32}}, {"ends", {64}}}), rsp);
-        auto slc2 = m.add_instruction(
+        auto slc2 = m2.add_instruction(
             migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {96}}}), rsp);
-        auto sum = m.add_instruction(migraphx::make_op("add"), slc0, slc1);
-        auto ret = m.add_instruction(migraphx::make_op("mul"), sum, slc2);
-        m.add_return({ret});
-        return m;
-    };
-
-    auto m1 = create_m1();
-    auto m2 = create_m2();
+        auto sum = m2.add_instruction(migraphx::make_op("add"), slc0, slc1);
+        auto ret = m2.add_instruction(migraphx::make_op("mul"), sum, slc2);
+        m2.add_return({ret});
+    }
     run_pass(m1);
     EXPECT(m1.sort() == m2.sort());
 }

@@ -2020,15 +2112,14 @@ TEST_CASE(reorder_reshape_slice_not_apply)
     EXPECT(m1.sort() == m2.sort());
 }

-TEST_CASE(reorder_reshape_slice_diff_dims)
+template <std::size_t BS>
+void reorder_reshape_slice_diff_dims()
 {
-    auto create_m1 = [](std::size_t batch_size) {
-        migraphx::module m1;
-        auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 96, 96}};
-        std::vector<int64_t> perm0 = {0, 2, 1, 3};
-        std::vector<int64_t> perm1 = {0, 2, 3, 1};
-        auto input = m1.add_parameter("input", s);
-        auto slc0  = m1.add_instruction(
+    migraphx::module m1;
+    {
+        auto s = migraphx::shape{migraphx::shape::float_type, {BS, 96, 96}};
+        auto input = m1.add_parameter("input", s);
+        auto slc0  = m1.add_instruction(
             migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), input);
         auto slc1 = m1.add_instruction(
             migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), input);

@@ -2039,34 +2130,31 @@ TEST_CASE(reorder_reshape_slice_diff_dims)
         auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1);
         auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2);

-        std::vector<int64_t> lens  = {static_cast<int64_t>(batch_size), 32, 3, 32};
-        std::vector<int64_t> lens1 = {static_cast<int64_t>(batch_size), 48, 2, 32};
+        std::vector<int64_t> lens  = {static_cast<int64_t>(BS), 32, 3, 32};
+        std::vector<int64_t> lens1 = {static_cast<int64_t>(BS), 48, 2, 32};
         auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
         auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
         auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens1}}), c2);
         m1.add_return({r0, r1, r2});
-        return m1;
-    };
-
-    auto test = [&](std::size_t batch_size) {
-        auto m1 = create_m1(batch_size);
-        auto m2 = m1;
-        run_pass(m1);
-        EXPECT(m1.sort() == m2.sort());
-    };
-    test(4);
-    test(8);
+    }
+    auto m2 = m1;
+    run_pass(m1);
+    EXPECT(m1.sort() == m2.sort());
 }

-TEST_CASE(reorder_slice_trans)
+TEST_CASE_REGISTER(reorder_reshape_slice_diff_dims<4>);
+TEST_CASE_REGISTER(reorder_reshape_slice_diff_dims<8>);
+
+template <std::size_t BS>
+void reorder_slice_trans()
 {
     std::vector<int64_t> perm = {0, 2, 1};
-    auto create_m1            = [&](std::size_t batch_size) {
-        migraphx::module m1;
-        auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
+    migraphx::module m1;
+    {
+        auto s = migraphx::shape{migraphx::shape::float_type, {BS, 128, 1920}};
         auto input = m1.add_parameter("input", s);
         auto slc0  = m1.add_instruction(
             migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {640}}}), input);

@@ -2084,13 +2172,11 @@ TEST_CASE(reorder_slice_trans)
         auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
         auto ret = m1.add_instruction(migraphx::make_op("mul"), sum, t2);
         m1.add_return({ret});
-        return m1;
-    };
+    }

-    auto create_m2 = [&](std::size_t batch_size) {
-        migraphx::module m2;
-        auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
+    migraphx::module m2;
+    {
+        auto s = migraphx::shape{migraphx::shape::float_type, {BS, 128, 1920}};
         auto input = m2.add_parameter("input", s);
         auto r =
             m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), input);

@@ -2104,26 +2190,21 @@ TEST_CASE(reorder_slice_trans)
         auto sum = m2.add_instruction(migraphx::make_op("add"), slc0, slc1);
         auto ret = m2.add_instruction(migraphx::make_op("mul"), sum, slc2);
         m2.add_return({ret});
-        return m2;
-    };
-
-    auto test = [&](std::size_t batch_size) {
-        auto m1 = create_m1(batch_size);
-        run_pass(m1);
-        auto m2 = create_m2(batch_size);
-        EXPECT(m1.sort() == m2.sort());
-    };
-    test(1);
-    test(8);
+    }
+    run_pass(m1);
+    EXPECT(m1.sort() == m2.sort());
 }

-TEST_CASE(reorder_slice_trans_diff_perm)
+TEST_CASE_REGISTER(reorder_slice_trans<1>);
+TEST_CASE_REGISTER(reorder_slice_trans<8>);
+
+template <std::size_t BS>
+void reorder_slice_trans_diff_perm()
 {
-    auto create_m1 = [](std::size_t batch_size) {
-        migraphx::module m1;
-        auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
+    migraphx::module m1;
+    {
+        auto s = migraphx::shape{migraphx::shape::float_type, {BS, 128, 1920}};
         std::vector<int64_t> perm0 = {0, 2, 1};
         std::vector<int64_t> perm1 = {0, 1, 2};
         auto input = m1.add_parameter("input", s);

@@ -2146,21 +2227,16 @@ TEST_CASE(reorder_slice_trans_diff_perm)
         auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
         auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2);
         m1.add_return({ret});
-        return m1;
-    };
-
-    auto test = [&](std::size_t batch_size) {
-        auto m1 = create_m1(batch_size);
-        run_pass(m1);
-        auto m2 = m1;
-        EXPECT(m1.sort() == m2.sort());
-    };
-    test(1);
-    test(4);
+    }
+    run_pass(m1);
+    auto m2 = m1;
+    EXPECT(m1.sort() == m2.sort());
 }

+TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<1>);
+TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<4>);
+
 TEST_CASE(reorder_slice_ins_deps)
 {
     auto create_module = [] {
test/verify/gemm_add_broadcast1.cpp (new file, mode 100644)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct gemm_add_broadcast1 : verify_program<gemm_add_broadcast1>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}};
        migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};
        migraphx::shape m3_shape{migraphx::shape::float_type, {1, 1, 4}};
        auto l1   = mm->add_parameter("1", m1_shape);
        auto l2   = mm->add_parameter("2", m2_shape);
        auto l3   = mm->add_parameter("3", m3_shape);
        auto l3_b = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {1, 2, 4}}}), l3);
        auto dot = mm->add_instruction(migraphx::make_op("dot"), l1, l2);
        mm->add_instruction(migraphx::make_op("add"), dot, l3_b);
        return p;
    }
};