Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b75c83d8
Unverified
Commit
b75c83d8
authored
Jun 24, 2022
by
Paul Fultz II
Committed by
GitHub
Jun 24, 2022
Browse files
Use jit for contiguous operator (#1217)
* Jit contiguous
parent
8c35fa94
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
212 additions
and
75 deletions
+212
-75
src/include/migraphx/ranges.hpp
src/include/migraphx/ranges.hpp
+3
-3
src/shape.cpp
src/shape.cpp
+14
-8
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+37
-1
src/targets/gpu/jit/pointwise.cpp
src/targets/gpu/jit/pointwise.cpp
+39
-28
src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp
...argets/gpu/kernels/include/migraphx/kernels/algorithm.hpp
+33
-1
src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
...argets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
+9
-2
src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp
src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp
+33
-25
test/reduce_dims.cpp
test/reduce_dims.cpp
+35
-7
test/shape_test.cpp
test/shape_test.cpp
+9
-0
No files found.
src/include/migraphx/ranges.hpp
View file @
b75c83d8
...
...
@@ -210,10 +210,10 @@ void replace(Range&& r, const T& old, const T& new_x)
std
::
replace
(
r
.
begin
(),
r
.
end
(),
old
,
new_x
);
}
template
<
class
R1
,
class
R2
>
bool
equal
(
R1
&&
r1
,
R2
&&
r2
)
template
<
class
R1
,
class
R2
,
class
...
Predicate
>
bool
equal
(
R1
&&
r1
,
R2
&&
r2
,
Predicate
...
pred
)
{
return
std
::
equal
(
r1
.
begin
(),
r1
.
end
(),
r2
.
begin
(),
r2
.
end
());
return
std
::
equal
(
r1
.
begin
(),
r1
.
end
(),
r2
.
begin
(),
r2
.
end
()
,
pred
...
);
}
template
<
class
R
>
...
...
src/shape.cpp
View file @
b75c83d8
...
...
@@ -61,9 +61,7 @@ struct shape_impl
{
assert
(
t
!=
shape
::
tuple_type
);
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
// assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and
// "At least one stride must be non-zero");
m_standard
=
this
->
elements
()
==
this
->
element_space
()
and
m_standard
=
this
->
elements
()
==
this
->
element_space
()
and
not
skips
()
and
std
::
is_sorted
(
m_strides
.
rbegin
(),
m_strides
.
rend
());
}
...
...
@@ -110,6 +108,15 @@ struct shape_impl
m_lens
.
begin
(),
m_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
}
// Does the shape skip over elements?
bool
skips
()
const
{
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
if
(
elements
()
==
1
)
return
false
;
return
std
::
none_of
(
m_strides
.
begin
(),
m_strides
.
end
(),
[](
auto
x
)
{
return
x
==
1
;
});
}
std
::
shared_ptr
<
shape_impl
>
copy
()
const
{
return
std
::
make_shared
<
shape_impl
>
(
*
this
);
}
};
...
...
@@ -260,7 +267,8 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end
bool
shape
::
packed
()
const
{
return
this
->
sub_shapes
().
empty
()
and
this
->
elements
()
==
this
->
element_space
();
return
this
->
sub_shapes
().
empty
()
and
not
impl
->
skips
()
and
this
->
elements
()
==
this
->
element_space
();
}
bool
shape
::
transposed
()
const
...
...
@@ -285,10 +293,8 @@ bool shape::transposed() const
bool
shape
::
broadcasted
()
const
{
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
return
std
::
accumulate
(
this
->
strides
().
begin
(),
this
->
strides
().
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
())
==
0
;
return
std
::
any_of
(
this
->
strides
().
begin
(),
this
->
strides
().
end
(),
[](
auto
x
)
{
return
x
==
0
;
});
}
bool
shape
::
scalar
()
const
...
...
src/targets/gpu/fuse_ops.cpp
View file @
b75c83d8
...
...
@@ -48,6 +48,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/array.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/clip.hpp>
#include <cmath>
#include <set>
...
...
@@ -1012,9 +1013,43 @@ struct find_commutative_broadcast
}
};
struct
find_contiguous
{
auto
matcher
()
const
{
return
match
::
name
(
"gpu::contiguous"
);
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
m
.
replace_instruction
(
ins
,
make_op
(
"gpu::precompile_op"
,
{{
"op"
,
to_value
(
make_op
(
"contiguous"
))}}),
ins
->
inputs
());
}
};
struct
find_contiguous_pointwise
{
auto
matcher
()
const
{
return
match
::
name
(
"gpu::contiguous"
)(
match
::
arg
(
0
)(
precompile_name
(
"pointwise"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
pw
=
ins
->
inputs
().
front
();
auto
alloc
=
ins
->
inputs
().
back
();
auto
args
=
pw
->
inputs
();
args
.
back
()
=
alloc
;
m
.
replace_instruction
(
ins
,
pw
->
get_operator
(),
args
,
pw
->
module_inputs
());
}
};
void
fuse_ops
::
apply
(
module
&
m
)
const
{
match
::
find_matches
(
m
,
find_gelu
{},
find_gelu_new
{
fast_math
});
match
::
find_matches
(
m
,
find_contiguous_pointwise
{},
find_gelu
{},
find_gelu_new
{
fast_math
});
run_passes
(
m
,
{
dead_code_elimination
{}});
match
::
find_matches
(
m
,
find_triadd
{});
match
::
find_matches
(
m
,
...
...
@@ -1036,6 +1071,7 @@ void fuse_ops::apply(module& m) const
find_gemm_add
{},
find_gemm_pointwise
{},
find_commutative_broadcast
{});
match
::
find_matches
(
m
,
find_contiguous
{});
}
}
// namespace gpu
...
...
src/targets/gpu/jit/pointwise.cpp
View file @
b75c83d8
...
...
@@ -79,7 +79,7 @@ static std::vector<std::string> get_op_names(const module& m)
struct
pointwise_compiler
:
compiler
<
pointwise_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"pointwise"
};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"pointwise"
,
"contiguous"
};
}
static
std
::
size_t
oversubscribe_if
(
bool
b
)
{
...
...
@@ -114,34 +114,45 @@ struct pointwise_compiler : compiler<pointwise_compiler>
return
compile_hip_code_object
(
src
,
options
);
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
)
const
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
assert
(
not
ins
->
module_inputs
().
empty
());
auto
*
pm
=
ins
->
module_inputs
().
front
();
run_passes
(
*
pm
,
{
eliminate_common_subexpression
{},
dead_code_elimination
{}});
cpp_generator
g
;
g
.
fmap
([](
const
std
::
string
&
fname
)
{
return
"migraphx::"
+
fname
;
});
g
.
add_point_op
(
"where"
,
"${function:where}(${0}, ${1}, ${2})"
);
g
.
add_point_op
(
"prelu"
,
"${function:where}(${0} < 0, ${0} * ${1}, ${0})"
);
g
.
add_point_op
(
"sign"
,
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))"
);
g
.
add_point_op
(
"equal"
,
"migraphx::abs(${0} == ${1})"
);
g
.
add_point_op
(
"less"
,
"migraphx::abs(${0} < ${1})"
);
g
.
add_point_op
(
"greater"
,
"migraphx::abs(${0} > ${1})"
);
g
.
add_point_op
(
"not"
,
"migraphx::abs(not ${0})"
);
// Add explict conversions
g
.
fresult
(
[](
const
shape
&
s
)
{
return
"migraphx::convert<"
+
shape
::
cpp_type
(
s
.
type
())
+
">"
;
});
auto
name
=
g
.
create_function
(
g
.
generate_module
(
*
pm
).
set_attributes
({
"__device__"
}).
set_generic_types
(
*
pm
));
std
::
string
lambda
=
"MIGRAPHX_LIFT("
+
name
+
")"
;
auto
op_names
=
get_op_names
(
*
pm
);
op_names
.
push_back
(
"kernel"
);
auto
op_name_string
=
join_strings
(
op_names
,
"_"
);
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
{{
"lambda"
,
lambda
},
{
"preamble"
,
g
.
str
()},
{
"kernel"
,
op_name_string
}}));
if
(
op
.
name
()
==
"contiguous"
)
{
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
{{
"lambda"
,
"[](auto x) { return x; }"
},
{
"kernel"
,
"contiguous_kernel"
}}));
}
else
{
assert
(
not
ins
->
module_inputs
().
empty
());
auto
*
pm
=
ins
->
module_inputs
().
front
();
run_passes
(
*
pm
,
{
eliminate_common_subexpression
{},
dead_code_elimination
{}});
cpp_generator
g
;
g
.
fmap
([](
const
std
::
string
&
fname
)
{
return
"migraphx::"
+
fname
;
});
g
.
add_point_op
(
"where"
,
"${function:where}(${0}, ${1}, ${2})"
);
g
.
add_point_op
(
"prelu"
,
"${function:where}(${0} < 0, ${0} * ${1}, ${0})"
);
g
.
add_point_op
(
"sign"
,
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))"
);
g
.
add_point_op
(
"equal"
,
"migraphx::abs(${0} == ${1})"
);
g
.
add_point_op
(
"less"
,
"migraphx::abs(${0} < ${1})"
);
g
.
add_point_op
(
"greater"
,
"migraphx::abs(${0} > ${1})"
);
g
.
add_point_op
(
"not"
,
"migraphx::abs(not ${0})"
);
// Add explict conversions
g
.
fresult
([](
const
shape
&
s
)
{
return
"migraphx::convert<"
+
shape
::
cpp_type
(
s
.
type
())
+
">"
;
});
auto
name
=
g
.
create_function
(
g
.
generate_module
(
*
pm
).
set_attributes
({
"__device__"
}).
set_generic_types
(
*
pm
));
std
::
string
lambda
=
"MIGRAPHX_LIFT("
+
name
+
")"
;
auto
op_names
=
get_op_names
(
*
pm
);
op_names
.
push_back
(
"kernel"
);
auto
op_name_string
=
join_strings
(
op_names
,
"_"
);
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
{{
"lambda"
,
lambda
},
{
"preamble"
,
g
.
str
()},
{
"kernel"
,
op_name_string
}}));
}
}
};
}
// namespace gpu
...
...
src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp
View file @
b75c83d8
...
...
@@ -49,7 +49,7 @@ constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op)
{
for
(;
first
!=
last
;
++
first
)
{
init
=
op
(
st
d
::
move
(
init
),
*
first
);
init
=
op
(
st
atic_cast
<
T
&&>
(
init
),
*
first
);
}
return
init
;
}
...
...
@@ -64,6 +64,20 @@ constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first)
return
d_first
;
}
template
<
class
InputIt
,
class
OutputIt
,
class
UnaryPredicate
>
constexpr
OutputIt
copy_if
(
InputIt
first
,
InputIt
last
,
OutputIt
d_first
,
UnaryPredicate
pred
)
{
for
(;
first
!=
last
;
++
first
)
{
if
(
pred
(
*
first
))
{
*
d_first
=
*
first
;
++
d_first
;
}
}
return
d_first
;
}
template
<
class
Iterator
,
class
Compare
>
constexpr
Iterator
is_sorted_until
(
Iterator
first
,
Iterator
last
,
Compare
comp
)
{
...
...
@@ -115,6 +129,24 @@ constexpr Iterator find(Iterator first, Iterator last, const T& value)
return
find_if
(
first
,
last
,
[
&
](
const
auto
&
x
)
{
return
x
==
value
;
});
}
template
<
class
InputIt
,
class
UnaryPredicate
>
constexpr
bool
any_of
(
InputIt
first
,
InputIt
last
,
UnaryPredicate
p
)
{
return
find_if
(
first
,
last
,
p
)
!=
last
;
}
template
<
class
InputIt
,
class
UnaryPredicate
>
constexpr
bool
none_of
(
InputIt
first
,
InputIt
last
,
UnaryPredicate
p
)
{
return
find_if
(
first
,
last
,
p
)
==
last
;
}
template
<
class
InputIt
,
class
UnaryPredicate
>
constexpr
bool
all_of
(
InputIt
first
,
InputIt
last
,
UnaryPredicate
p
)
{
return
none_of
(
first
,
last
,
[
=
](
auto
&&
x
)
{
return
not
p
(
x
);
});
}
template
<
class
Iterator1
,
class
Iterator2
>
constexpr
Iterator1
search
(
Iterator1
first
,
Iterator1
last
,
Iterator2
s_first
,
Iterator2
s_last
)
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
View file @
b75c83d8
...
...
@@ -41,8 +41,15 @@ struct implicit_conversion_op
template
<
index_int
N
,
class
U
>
constexpr
operator
vec
<
U
,
N
>
()
const
{
static_assert
(
vec_size
<
T
>
()
==
N
,
"Vector mismatch size"
);
return
__builtin_convertvector
(
x
,
vec
<
U
,
N
>
);
if
constexpr
(
vec_size
<
T
>
()
==
0
)
{
return
x
;
}
else
{
static_assert
(
vec_size
<
T
>
()
==
N
,
"Vector mismatch size"
);
return
__builtin_convertvector
(
x
,
vec
<
U
,
N
>
);
}
}
template
<
class
U
>
...
...
src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp
View file @
b75c83d8
...
...
@@ -44,7 +44,7 @@ struct shape
constexpr
auto
element_space
()
const
{
return
_c
<
Strides
{}.
dot
(
Lens
{}
-
1
)
+
1
>
;
}
constexpr
auto
packed
()
const
{
return
elements
()
==
element_space
();
}
constexpr
auto
packed
()
const
{
return
not
skips
()
and
elements
()
==
element_space
();
}
constexpr
auto
broadcasted
()
const
{
return
_c
<
Strides
{}.
product
()
==
0
>
;
}
constexpr
auto
transposed
()
const
{
...
...
@@ -53,16 +53,9 @@ struct shape
if
(
shape
{}.
broadcasted
())
{
index_array
s
{};
index_int
j
=
0
;
for
(
index_int
i
=
0
;
i
<
s
.
size
();
i
++
)
{
if
(
lstrides
[
i
]
!=
0
)
{
s
[
j
]
=
lstrides
[
i
];
j
++
;
}
}
return
not
is_sorted
(
s
.
begin
(),
s
.
begin
()
+
j
,
greater
{});
auto
out
=
copy_if
(
lstrides
.
begin
(),
lstrides
.
end
(),
s
.
begin
(),
[](
auto
x
)
{
return
x
!=
0
;
});
return
not
is_sorted
(
s
.
begin
(),
out
,
greater
{});
}
else
{
...
...
@@ -70,6 +63,13 @@ struct shape
}
});
}
constexpr
auto
skips
()
const
{
return
return_c
([]
{
auto
lstrides
=
Strides
{};
return
none_of
(
lstrides
.
begin
(),
lstrides
.
end
(),
[](
auto
x
)
{
return
x
==
1
;
});
});
}
constexpr
auto
standard
()
const
{
return
packed
()
and
not
transposed
();
}
...
...
@@ -86,26 +86,34 @@ struct shape
constexpr
index_int
index
(
index_int
i
)
const
{
if
(
this
->
standard
())
{
MIGRAPHX_ASSERT
(
i
==
compute_index
(
i
));
return
i
;
}
else
{
const
auto
rank
=
this
->
lens
.
size
();
index_int
s
=
1
;
index_int
result
=
0
;
for
(
index_int
j
=
0
;
j
<
rank
;
j
++
)
{
const
index_int
k
=
rank
-
j
-
1
;
const
index_int
stride
=
this
->
strides
[
k
];
const
index_int
len
=
this
->
lens
[
k
];
const
index_int
slen
=
s
*
len
;
const
index_int
idx
=
(
i
%
slen
)
/
s
;
result
+=
stride
*
idx
;
s
=
slen
;
}
return
result
;
return
compute_index
(
i
);
}
}
constexpr
index_int
compute_index
(
index_int
i
)
const
{
const
auto
rank
=
this
->
lens
.
size
();
index_int
s
=
1
;
index_int
result
=
0
;
for
(
index_int
j
=
0
;
j
<
rank
;
j
++
)
{
const
index_int
k
=
rank
-
j
-
1
;
const
index_int
stride
=
this
->
strides
[
k
];
const
index_int
len
=
this
->
lens
[
k
];
const
index_int
slen
=
s
*
len
;
const
index_int
idx
=
(
i
%
slen
)
/
s
;
result
+=
stride
*
idx
;
s
=
slen
;
}
return
result
;
}
/// Convert single index into a multi-index
constexpr
index_array
multi
(
index_int
idx
)
const
{
...
...
test/reduce_dims.cpp
View file @
b75c83d8
...
...
@@ -23,6 +23,7 @@
*/
#include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
migraphx
::
shape
make_shape
(
std
::
vector
<
std
::
size_t
>
lens
)
...
...
@@ -35,6 +36,21 @@ migraphx::shape make_shape(std::vector<std::size_t> lens, std::vector<std::size_
return
{
migraphx
::
shape
::
float_type
,
std
::
move
(
lens
),
std
::
move
(
strides
)};
}
bool
verify_shape
(
const
migraphx
::
shape
&
s1
,
const
migraphx
::
shape
&
s2
)
{
if
(
s1
.
elements
()
!=
s2
.
elements
())
return
false
;
return
migraphx
::
all_of
(
migraphx
::
range
(
s1
.
elements
()),
[
&
](
auto
i
)
{
return
s1
.
index
(
i
)
==
s2
.
index
(
i
);
});
}
template
<
class
Range1
,
class
Range2
>
bool
verify_shapes
(
const
Range1
&
r1
,
const
Range2
&
r2
)
{
return
migraphx
::
equal
(
r1
,
r2
,
[](
const
auto
&
s1
,
const
auto
&
s2
)
{
return
verify_shape
(
s1
,
s2
);
});
}
TEST_CASE
(
same_standard
)
{
auto
is
=
make_shape
({
64
,
3
,
7
,
7
});
...
...
@@ -42,7 +58,7 @@ TEST_CASE(same_standard)
std
::
vector
<
migraphx
::
shape
>
ishapes
=
{
is
,
is
,
is
};
std
::
vector
<
migraphx
::
shape
>
eshapes
=
{
os
,
os
,
os
};
auto
rshapes
=
migraphx
::
reduce_dims
(
ishapes
);
EXPECT
(
verify_shapes
(
ishapes
,
rshapes
));
EXPECT
(
eshapes
==
rshapes
);
}
...
...
@@ -53,7 +69,7 @@ TEST_CASE(same_broadcast1)
std
::
vector
<
migraphx
::
shape
>
ishapes
=
{
is
,
make_shape
({
64
,
3
,
7
,
7
},
{
0
,
1
,
0
,
0
}),
is
};
std
::
vector
<
migraphx
::
shape
>
eshapes
=
{
os
,
make_shape
({
64
,
3
,
7
*
7
},
{
0
,
1
,
0
}),
os
};
auto
rshapes
=
migraphx
::
reduce_dims
(
ishapes
);
EXPECT
(
verify_shapes
(
ishapes
,
rshapes
));
EXPECT
(
eshapes
==
rshapes
);
}
...
...
@@ -64,7 +80,7 @@ TEST_CASE(same_broadcast2)
std
::
vector
<
migraphx
::
shape
>
ishapes
=
{
is
,
make_shape
({
64
,
3
,
8
,
7
,
7
},
{
0
,
8
,
1
,
0
,
0
}),
is
};
std
::
vector
<
migraphx
::
shape
>
eshapes
=
{
os
,
make_shape
({
64
,
8
*
3
,
7
*
7
},
{
0
,
1
,
0
}),
os
};
auto
rshapes
=
migraphx
::
reduce_dims
(
ishapes
);
EXPECT
(
verify_shapes
(
ishapes
,
rshapes
));
EXPECT
(
eshapes
==
rshapes
);
}
...
...
@@ -75,7 +91,7 @@ TEST_CASE(same_transposed)
std
::
vector
<
migraphx
::
shape
>
ishapes
=
{
is
,
migraphx
::
reorder_shape
(
is
,
{
0
,
1
,
3
,
2
}),
is
};
std
::
vector
<
migraphx
::
shape
>
eshapes
=
{
os
,
migraphx
::
reorder_shape
(
os
,
{
0
,
2
,
1
}),
os
};
auto
rshapes
=
migraphx
::
reduce_dims
(
ishapes
);
EXPECT
(
verify_shapes
(
ishapes
,
rshapes
));
EXPECT
(
eshapes
==
rshapes
);
}
...
...
@@ -86,7 +102,7 @@ TEST_CASE(different_masked1)
std
::
vector
<
migraphx
::
shape
>
ishapes
=
{
is
,
make_shape
({
1
,
3
,
1
,
1
}),
is
};
std
::
vector
<
migraphx
::
shape
>
eshapes
=
{
os
,
make_shape
({
1
,
3
,
1
}),
os
};
auto
rshapes
=
migraphx
::
reduce_dims
(
ishapes
);
EXPECT
(
verify_shapes
(
ishapes
,
rshapes
));
EXPECT
(
eshapes
==
rshapes
);
}
...
...
@@ -98,7 +114,7 @@ TEST_CASE(different_masked2)
is
,
make_shape
({
1
,
3
,
1
,
1
}),
make_shape
({
64
,
1
,
7
,
7
})};
std
::
vector
<
migraphx
::
shape
>
eshapes
=
{
os
,
make_shape
({
1
,
3
,
1
}),
make_shape
({
64
,
1
,
7
*
7
})};
auto
rshapes
=
migraphx
::
reduce_dims
(
ishapes
);
EXPECT
(
verify_shapes
(
ishapes
,
rshapes
));
EXPECT
(
eshapes
==
rshapes
);
}
...
...
@@ -128,7 +144,7 @@ TEST_CASE(transposed1)
std
::
vector
<
migraphx
::
shape
>
eshapes
=
{
make_shape
({
8
,
28
,
4
,
56
*
56
}),
make_shape
({
8
,
28
,
4
,
56
*
56
},
{
351232
,
3136
,
87808
,
1
})};
auto
rshapes
=
migraphx
::
reduce_dims
(
ishapes
);
EXPECT
(
verify_shapes
(
ishapes
,
rshapes
));
EXPECT
(
eshapes
==
rshapes
);
}
...
...
@@ -137,6 +153,7 @@ TEST_CASE(non_packed_empty1)
std
::
vector
<
migraphx
::
shape
>
ishapes
=
{
make_shape
({
1
,
12
},
{
589824
,
64
})};
std
::
vector
<
migraphx
::
shape
>
eshapes
=
{
make_shape
({
12
},
{
64
})};
auto
rshapes
=
migraphx
::
reduce_dims
(
ishapes
);
EXPECT
(
verify_shapes
(
ishapes
,
rshapes
));
EXPECT
(
eshapes
==
rshapes
);
}
...
...
@@ -145,6 +162,7 @@ TEST_CASE(non_packed_empty2)
std
::
vector
<
migraphx
::
shape
>
ishapes
=
{
make_shape
({
12
,
1
},
{
64
,
589824
})};
std
::
vector
<
migraphx
::
shape
>
eshapes
=
{
make_shape
({
12
},
{
64
})};
auto
rshapes
=
migraphx
::
reduce_dims
(
ishapes
);
EXPECT
(
verify_shapes
(
ishapes
,
rshapes
));
EXPECT
(
eshapes
==
rshapes
);
}
...
...
@@ -155,6 +173,16 @@ TEST_CASE(single_dim)
EXPECT
(
ishapes
==
rshapes
);
}
TEST_CASE
(
step_broadcast_transpose
)
{
std
::
vector
<
migraphx
::
shape
>
ishapes
=
{
make_shape
({
1
,
2
,
2
,
1
},
{
0
,
0
,
3
,
6
}),
make_shape
({
1
,
2
,
2
,
1
},
{
4
,
2
,
1
,
1
})};
std
::
vector
<
migraphx
::
shape
>
eshapes
=
{
make_shape
({
2
,
2
},
{
0
,
3
}),
make_shape
({
2
,
2
},
{
2
,
1
})};
auto
rshapes
=
migraphx
::
reduce_dims
(
ishapes
);
EXPECT
(
verify_shapes
(
ishapes
,
rshapes
));
EXPECT
(
eshapes
==
rshapes
);
}
TEST_CASE
(
empty
)
{
auto
rshapes
=
migraphx
::
reduce_dims
({});
...
...
test/shape_test.cpp
View file @
b75c83d8
...
...
@@ -200,6 +200,15 @@ TEST_CASE(test_shape_broadcasted5)
EXPECT
(
s
.
broadcasted
());
}
TEST_CASE
(
test_shape_step_broadcasted
)
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
2
},
{
0
,
3
}};
EXPECT
(
not
s
.
standard
());
EXPECT
(
not
s
.
packed
());
EXPECT
(
not
s
.
transposed
());
EXPECT
(
s
.
broadcasted
());
}
TEST_CASE
(
test_shape_default_copy
)
{
migraphx
::
shape
s1
{};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment