Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
16e5b5d0
Commit
16e5b5d0
authored
Feb 28, 2022
by
Shucai Xiao
Browse files
Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into layernorm_half2
parents
5f37917f
c1b56607
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
26 additions
and
1 deletion
+26
-1
CMakeLists.txt
CMakeLists.txt
+1
-1
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+3
-0
src/shape.cpp
src/shape.cpp
+11
-0
test/shape_test.cpp
test/shape_test.cpp
+11
-0
No files found.
CMakeLists.txt
View file @
16e5b5d0
...
...
@@ -42,7 +42,7 @@ find_package(nlohmann_json 3.8.0 REQUIRED)
include
(
ROCMSetupVersion
)
rocm_setup_version
(
VERSION 2.
1
)
rocm_setup_version
(
VERSION 2.
2
)
set
(
MIGRAPHX_SO_VERSION
${
PROJECT_VERSION_MAJOR
}
.
${
PROJECT_VERSION_MINOR
}
)
option
(
BUILD_SHARED_LIBS
"Build as a shared library"
ON
)
...
...
src/include/migraphx/shape.hpp
View file @
16e5b5d0
...
...
@@ -131,6 +131,8 @@ struct shape
shape
with_lens
(
type_t
t
,
const
std
::
vector
<
std
::
size_t
>&
l
)
const
;
shape
with_lens
(
const
std
::
vector
<
std
::
size_t
>&
l
)
const
;
shape
with_type
(
type_t
t
)
const
;
friend
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
);
friend
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
&
x
);
...
...
@@ -226,6 +228,7 @@ struct shape
std
::
size_t
element_space
()
const
;
private:
shape
(
std
::
shared_ptr
<
shape_impl
>
pimpl
);
std
::
shared_ptr
<
const
shape_impl
>
impl
;
};
...
...
src/shape.cpp
100755 → 100644
View file @
16e5b5d0
...
...
@@ -86,6 +86,8 @@ struct shape_impl
return
std
::
accumulate
(
m_lens
.
begin
(),
m_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
}
std
::
shared_ptr
<
shape_impl
>
copy
()
const
{
return
std
::
make_shared
<
shape_impl
>
(
*
this
);
}
};
const
std
::
vector
<
shape
::
type_t
>&
shape
::
types
()
...
...
@@ -135,6 +137,8 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
shape
::
shape
(
const
std
::
vector
<
shape
>&
subs
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
subs
))
{}
shape
::
shape
(
std
::
shared_ptr
<
shape_impl
>
pimpl
)
:
impl
(
std
::
move
(
pimpl
))
{}
shape
shape
::
from_permutation
(
type_t
t
,
const
std
::
vector
<
std
::
size_t
>&
l
,
const
std
::
vector
<
int64_t
>&
perm
)
...
...
@@ -294,6 +298,13 @@ shape shape::with_lens(const std::vector<std::size_t>& l) const
return
this
->
with_lens
(
this
->
type
(),
l
);
}
shape
shape
::
with_type
(
type_t
t
)
const
{
auto
c
=
impl
->
copy
();
c
->
m_type
=
t
;
return
{
c
};
}
std
::
size_t
shape
::
element_space
()
const
{
return
impl
->
element_space
();
}
std
::
string
shape
::
type_string
()
const
{
return
name
(
this
->
type
());
}
...
...
test/shape_test.cpp
View file @
16e5b5d0
...
...
@@ -608,4 +608,15 @@ TEST_CASE(cpp_type_name)
EXPECT
(
test
::
throws
([
&
]
{
migraphx
::
shape
::
cpp_type
(
migraphx
::
shape
::
tuple_type
);
}));
}
TEST_CASE
(
test_with_type
)
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
2
},
{
1
,
0
}};
EXPECT
(
s
.
type
()
==
migraphx
::
shape
::
float_type
);
auto
new_s
=
s
.
with_type
(
migraphx
::
shape
::
half_type
);
EXPECT
(
s
.
type
()
==
migraphx
::
shape
::
float_type
);
EXPECT
(
s
.
type
()
!=
new_s
.
type
());
EXPECT
(
s
.
lens
()
==
new_s
.
lens
());
EXPECT
(
s
.
strides
()
==
new_s
.
strides
());
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment