Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c65ab678
Unverified
Commit
c65ab678
authored
Aug 07, 2023
by
Charlie Lin
Committed by
GitHub
Aug 07, 2023
Browse files
Change check_shapes to templated class (#2011)
parent
ae4cdf5a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
37 additions
and
24 deletions
+37
-24
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+27
-22
src/targets/cpu/gemm.cpp
src/targets/cpu/gemm.cpp
+5
-1
src/targets/cpu/include/migraphx/cpu/dnnl.hpp
src/targets/cpu/include/migraphx/cpu/dnnl.hpp
+5
-1
No files found.
src/include/migraphx/check_shapes.hpp
View file @
c65ab678
...
@@ -34,21 +34,37 @@
...
@@ -34,21 +34,37 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
// Check that deduced type is incrementable, dereferencable, and comparable
template
<
class
,
class
=
void
>
struct
is_iterator
{
};
template
<
class
T
>
struct
is_iterator
<
T
,
std
::
void_t
<
decltype
(
++
std
::
declval
<
T
&>
()),
decltype
(
*
std
::
declval
<
T
&>
()),
decltype
(
std
::
declval
<
T
&>
()
==
std
::
declval
<
T
&>
())
>>
:
std
::
true_type
{
};
template
<
class
Iterator
>
struct
check_shapes
struct
check_shapes
{
{
const
shape
*
begin
;
static_assert
(
is_iterator
<
Iterator
>
{},
"CHECK_SHAPES: Deduced type must be an iterator"
);
const
shape
*
end
;
Iterator
begin
;
Iterator
end
;
std
::
string
name
;
std
::
string
name
;
bool
dynamic_allowed
;
bool
dynamic_allowed
;
check_shapes
(
const
shape
*
b
,
const
shape
*
e
,
const
std
::
string
&
n
,
const
bool
d
=
false
)
check_shapes
(
Iterator
b
,
Iterator
e
,
const
std
::
string
&
n
,
const
bool
d
=
false
)
:
begin
(
b
),
end
(
e
),
name
(
n
),
dynamic_allowed
(
d
)
:
begin
(
b
),
end
(
e
),
name
(
n
),
dynamic_allowed
(
d
)
{
{
check_dynamic
();
check_dynamic
();
}
}
template
<
class
Op
>
template
<
class
Op
>
check_shapes
(
const
shape
*
b
,
const
shape
*
e
,
const
Op
&
op
,
const
bool
d
=
false
)
check_shapes
(
Iterator
b
,
Iterator
e
,
const
Op
&
op
,
const
bool
d
=
false
)
:
begin
(
b
),
end
(
e
),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
:
begin
(
b
),
end
(
e
),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
{
{
check_dynamic
();
check_dynamic
();
...
@@ -56,7 +72,7 @@ struct check_shapes
...
@@ -56,7 +72,7 @@ struct check_shapes
template
<
class
Op
>
template
<
class
Op
>
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
Op
&
op
,
const
bool
d
=
false
)
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
Op
&
op
,
const
bool
d
=
false
)
:
begin
(
s
.
data
()),
end
(
s
.
data
()
+
s
.
size
()),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
:
begin
(
s
.
begin
()),
end
(
s
.
end
()),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
{
{
check_dynamic
();
check_dynamic
();
}
}
...
@@ -81,8 +97,6 @@ struct check_shapes
...
@@ -81,8 +97,6 @@ struct check_shapes
{
{
if
(
begin
==
end
)
if
(
begin
==
end
)
return
0
;
return
0
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
return
end
-
begin
;
return
end
-
begin
;
}
}
...
@@ -131,8 +145,6 @@ struct check_shapes
...
@@ -131,8 +145,6 @@ struct check_shapes
*/
*/
const
check_shapes
&
only_dims
(
std
::
size_t
n
)
const
const
check_shapes
&
only_dims
(
std
::
size_t
n
)
const
{
{
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
begin
!=
end
)
if
(
begin
!=
end
)
{
{
if
(
begin
->
max_lens
().
size
()
!=
n
)
if
(
begin
->
max_lens
().
size
()
!=
n
)
...
@@ -148,8 +160,6 @@ struct check_shapes
...
@@ -148,8 +160,6 @@ struct check_shapes
*/
*/
const
check_shapes
&
max_ndims
(
std
::
size_t
n
)
const
const
check_shapes
&
max_ndims
(
std
::
size_t
n
)
const
{
{
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
begin
!=
end
)
if
(
begin
!=
end
)
{
{
if
(
begin
->
max_lens
().
size
()
>
n
)
if
(
begin
->
max_lens
().
size
()
>
n
)
...
@@ -166,8 +176,6 @@ struct check_shapes
...
@@ -166,8 +176,6 @@ struct check_shapes
*/
*/
const
check_shapes
&
min_ndims
(
std
::
size_t
n
)
const
const
check_shapes
&
min_ndims
(
std
::
size_t
n
)
const
{
{
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
begin
!=
end
)
if
(
begin
!=
end
)
{
{
if
(
begin
->
max_lens
().
size
()
<
n
)
if
(
begin
->
max_lens
().
size
()
<
n
)
...
@@ -330,8 +338,6 @@ struct check_shapes
...
@@ -330,8 +338,6 @@ struct check_shapes
{
{
if
(
begin
==
end
)
if
(
begin
==
end
)
return
true
;
return
true
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
auto
&&
key
=
f
(
*
begin
);
auto
&&
key
=
f
(
*
begin
);
return
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
f
(
s
)
==
key
;
});
return
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
f
(
s
)
==
key
;
});
}
}
...
@@ -341,8 +347,6 @@ struct check_shapes
...
@@ -341,8 +347,6 @@ struct check_shapes
{
{
if
(
begin
==
end
)
if
(
begin
==
end
)
return
true
;
return
true
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
return
std
::
all_of
(
begin
,
end
,
p
);
return
std
::
all_of
(
begin
,
end
,
p
);
}
}
...
@@ -351,17 +355,13 @@ struct check_shapes
...
@@ -351,17 +355,13 @@ struct check_shapes
{
{
if
(
begin
==
end
)
if
(
begin
==
end
)
return
false
;
return
false
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
return
std
::
any_of
(
begin
,
end
,
p
);
return
std
::
any_of
(
begin
,
end
,
p
);
}
}
const
shape
*
get
(
long
i
)
const
Iterator
get
(
long
i
)
const
{
{
if
(
i
>=
size
())
if
(
i
>=
size
())
MIGRAPHX_THROW
(
prefix
()
+
"Accessing shape out of bounds"
);
MIGRAPHX_THROW
(
prefix
()
+
"Accessing shape out of bounds"
);
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
i
<
0
)
if
(
i
<
0
)
return
end
-
i
;
return
end
-
i
;
return
begin
+
i
;
return
begin
+
i
;
...
@@ -394,6 +394,11 @@ struct check_shapes
...
@@ -394,6 +394,11 @@ struct check_shapes
}
}
};
};
// Deduction guide for std::vector constructor
template
<
class
Op
>
check_shapes
(
const
std
::
vector
<
shape
>&
,
const
Op
&
,
bool
d
=
false
)
->
check_shapes
<
std
::
vector
<
shape
>::
const_iterator
>
;
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/targets/cpu/gemm.cpp
View file @
c65ab678
...
@@ -43,7 +43,11 @@ struct dnnl_gemm : dnnl_extend_op<dnnl_gemm, dnnl::matmul, op::dot>
...
@@ -43,7 +43,11 @@ struct dnnl_gemm : dnnl_extend_op<dnnl_gemm, dnnl::matmul, op::dot>
MIGRAPHX_DNNL_PREFIX
(
ARG_BIAS
)};
MIGRAPHX_DNNL_PREFIX
(
ARG_BIAS
)};
}
}
void
required
(
const
check_shapes
&
cs
)
const
{
cs
.
not_broadcasted
();
}
template
<
class
T
>
void
required
(
const
check_shapes
<
T
>&
cs
)
const
{
cs
.
not_broadcasted
();
}
dnnl
::
matmul
::
desc
get_desc
(
const
std
::
unordered_map
<
int
,
dnnl
::
memory
::
desc
>&
m
)
const
dnnl
::
matmul
::
desc
get_desc
(
const
std
::
unordered_map
<
int
,
dnnl
::
memory
::
desc
>&
m
)
const
{
{
...
...
src/targets/cpu/include/migraphx/cpu/dnnl.hpp
View file @
c65ab678
...
@@ -400,7 +400,11 @@ struct dnnl_extend_op : dnnl_op<Derived, Primitive>
...
@@ -400,7 +400,11 @@ struct dnnl_extend_op : dnnl_op<Derived, Primitive>
}
}
// dnnl has some issues with non-packed inputs
// dnnl has some issues with non-packed inputs
void
required
(
const
check_shapes
&
cs
)
const
{
cs
.
packed_or_broadcasted
();
}
template
<
class
T
>
void
required
(
const
check_shapes
<
T
>&
cs
)
const
{
cs
.
packed_or_broadcasted
();
}
std
::
string
name
()
const
{
return
"dnnl::"
+
op
.
name
();
}
std
::
string
name
()
const
{
return
"dnnl::"
+
op
.
name
();
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment