Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c65ab678
Unverified
Commit
c65ab678
authored
Aug 07, 2023
by
Charlie Lin
Committed by
GitHub
Aug 07, 2023
Browse files
Change check_shapes to templated class (#2011)
parent
ae4cdf5a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
37 additions
and
24 deletions
+37
-24
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+27
-22
src/targets/cpu/gemm.cpp
src/targets/cpu/gemm.cpp
+5
-1
src/targets/cpu/include/migraphx/cpu/dnnl.hpp
src/targets/cpu/include/migraphx/cpu/dnnl.hpp
+5
-1
No files found.
src/include/migraphx/check_shapes.hpp
View file @
c65ab678
...
...
@@ -34,21 +34,37 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
// Check that deduced type is incrementable, dereferencable, and comparable
template
<
class
,
class
=
void
>
struct
is_iterator
{
};
template
<
class
T
>
struct
is_iterator
<
T
,
std
::
void_t
<
decltype
(
++
std
::
declval
<
T
&>
()),
decltype
(
*
std
::
declval
<
T
&>
()),
decltype
(
std
::
declval
<
T
&>
()
==
std
::
declval
<
T
&>
())
>>
:
std
::
true_type
{
};
template
<
class
Iterator
>
struct
check_shapes
{
const
shape
*
begin
;
const
shape
*
end
;
static_assert
(
is_iterator
<
Iterator
>
{},
"CHECK_SHAPES: Deduced type must be an iterator"
);
Iterator
begin
;
Iterator
end
;
std
::
string
name
;
bool
dynamic_allowed
;
check_shapes
(
const
shape
*
b
,
const
shape
*
e
,
const
std
::
string
&
n
,
const
bool
d
=
false
)
check_shapes
(
Iterator
b
,
Iterator
e
,
const
std
::
string
&
n
,
const
bool
d
=
false
)
:
begin
(
b
),
end
(
e
),
name
(
n
),
dynamic_allowed
(
d
)
{
check_dynamic
();
}
template
<
class
Op
>
check_shapes
(
const
shape
*
b
,
const
shape
*
e
,
const
Op
&
op
,
const
bool
d
=
false
)
check_shapes
(
Iterator
b
,
Iterator
e
,
const
Op
&
op
,
const
bool
d
=
false
)
:
begin
(
b
),
end
(
e
),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
{
check_dynamic
();
...
...
@@ -56,7 +72,7 @@ struct check_shapes
template
<
class
Op
>
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
Op
&
op
,
const
bool
d
=
false
)
:
begin
(
s
.
data
()),
end
(
s
.
data
()
+
s
.
size
()),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
:
begin
(
s
.
begin
()),
end
(
s
.
end
()),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
{
check_dynamic
();
}
...
...
@@ -81,8 +97,6 @@ struct check_shapes
{
if
(
begin
==
end
)
return
0
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
return
end
-
begin
;
}
...
...
@@ -131,8 +145,6 @@ struct check_shapes
*/
const
check_shapes
&
only_dims
(
std
::
size_t
n
)
const
{
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
begin
!=
end
)
{
if
(
begin
->
max_lens
().
size
()
!=
n
)
...
...
@@ -148,8 +160,6 @@ struct check_shapes
*/
const
check_shapes
&
max_ndims
(
std
::
size_t
n
)
const
{
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
begin
!=
end
)
{
if
(
begin
->
max_lens
().
size
()
>
n
)
...
...
@@ -166,8 +176,6 @@ struct check_shapes
*/
const
check_shapes
&
min_ndims
(
std
::
size_t
n
)
const
{
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
begin
!=
end
)
{
if
(
begin
->
max_lens
().
size
()
<
n
)
...
...
@@ -330,8 +338,6 @@ struct check_shapes
{
if
(
begin
==
end
)
return
true
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
auto
&&
key
=
f
(
*
begin
);
return
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
f
(
s
)
==
key
;
});
}
...
...
@@ -341,8 +347,6 @@ struct check_shapes
{
if
(
begin
==
end
)
return
true
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
return
std
::
all_of
(
begin
,
end
,
p
);
}
...
...
@@ -351,17 +355,13 @@ struct check_shapes
{
if
(
begin
==
end
)
return
false
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
return
std
::
any_of
(
begin
,
end
,
p
);
}
const
shape
*
get
(
long
i
)
const
Iterator
get
(
long
i
)
const
{
if
(
i
>=
size
())
MIGRAPHX_THROW
(
prefix
()
+
"Accessing shape out of bounds"
);
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
i
<
0
)
return
end
-
i
;
return
begin
+
i
;
...
...
@@ -394,6 +394,11 @@ struct check_shapes
}
};
// Deduction guide for std::vector constructor
template
<
class
Op
>
check_shapes
(
const
std
::
vector
<
shape
>&
,
const
Op
&
,
bool
d
=
false
)
->
check_shapes
<
std
::
vector
<
shape
>::
const_iterator
>
;
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/targets/cpu/gemm.cpp
View file @
c65ab678
...
...
@@ -43,7 +43,11 @@ struct dnnl_gemm : dnnl_extend_op<dnnl_gemm, dnnl::matmul, op::dot>
MIGRAPHX_DNNL_PREFIX
(
ARG_BIAS
)};
}
void
required
(
const
check_shapes
&
cs
)
const
{
cs
.
not_broadcasted
();
}
template
<
class
T
>
void
required
(
const
check_shapes
<
T
>&
cs
)
const
{
cs
.
not_broadcasted
();
}
dnnl
::
matmul
::
desc
get_desc
(
const
std
::
unordered_map
<
int
,
dnnl
::
memory
::
desc
>&
m
)
const
{
...
...
src/targets/cpu/include/migraphx/cpu/dnnl.hpp
View file @
c65ab678
...
...
@@ -400,7 +400,11 @@ struct dnnl_extend_op : dnnl_op<Derived, Primitive>
}
// dnnl has some issues with non-packed inputs
void
required
(
const
check_shapes
&
cs
)
const
{
cs
.
packed_or_broadcasted
();
}
template
<
class
T
>
void
required
(
const
check_shapes
<
T
>&
cs
)
const
{
cs
.
packed_or_broadcasted
();
}
std
::
string
name
()
const
{
return
"dnnl::"
+
op
.
name
();
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment