Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
038a4c52
Commit
038a4c52
authored
Aug 22, 2018
by
wsttiger
Browse files
Merged from master still debugging resnet
parents
06cc4f8f
905d4ab0
Changes
41
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
210 additions
and
105 deletions
+210
-105
.clang-tidy
.clang-tidy
+3
-1
CMakeLists.txt
CMakeLists.txt
+0
-3
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+2
-1
src/include/migraph/argument.hpp
src/include/migraph/argument.hpp
+5
-3
src/include/migraph/builtin.hpp
src/include/migraph/builtin.hpp
+15
-6
src/include/migraph/check_context.hpp
src/include/migraph/check_context.hpp
+2
-2
src/include/migraph/errors.hpp
src/include/migraph/errors.hpp
+2
-2
src/include/migraph/generate.hpp
src/include/migraph/generate.hpp
+25
-4
src/include/migraph/instruction.hpp
src/include/migraph/instruction.hpp
+20
-6
src/include/migraph/instruction_ref.hpp
src/include/migraph/instruction_ref.hpp
+1
-0
src/include/migraph/literal.hpp
src/include/migraph/literal.hpp
+20
-12
src/include/migraph/operation.hpp
src/include/migraph/operation.hpp
+21
-18
src/include/migraph/operators.hpp
src/include/migraph/operators.hpp
+49
-31
src/include/migraph/program.hpp
src/include/migraph/program.hpp
+9
-5
src/include/migraph/shape.hpp
src/include/migraph/shape.hpp
+5
-6
src/include/migraph/stringutils.hpp
src/include/migraph/stringutils.hpp
+2
-2
src/include/migraph/tensor_view.hpp
src/include/migraph/tensor_view.hpp
+2
-1
src/include/migraph/time.hpp
src/include/migraph/time.hpp
+19
-0
src/onnx/CMakeLists.txt
src/onnx/CMakeLists.txt
+5
-0
src/onnx/mnist.cpp
src/onnx/mnist.cpp
+3
-2
No files found.
.clang-tidy
View file @
038a4c52
CheckOptions:
CheckOptions:
- key: modernize-loop-convert.MinConfidence
value: risky
- key: modernize-loop-convert.NamingStyle
- key: modernize-loop-convert.NamingStyle
value: lower_case
value: lower_case
- key: readability-function-size.BranchThreshold
- key: readability-function-size.BranchThreshold
...
...
CMakeLists.txt
View file @
038a4c52
...
@@ -36,9 +36,7 @@ include(ROCMClangTidy)
...
@@ -36,9 +36,7 @@ include(ROCMClangTidy)
rocm_enable_clang_tidy
(
rocm_enable_clang_tidy
(
CHECKS
CHECKS
*
*
-cert-env33-c
-android-cloexec-fopen
-android-cloexec-fopen
-cert-msc50-cpp
-clang-analyzer-alpha.core.CastToStruct
-clang-analyzer-alpha.core.CastToStruct
-clang-analyzer-optin.performance.Padding
-clang-analyzer-optin.performance.Padding
-clang-diagnostic-deprecated-declarations
-clang-diagnostic-deprecated-declarations
...
@@ -72,7 +70,6 @@ rocm_enable_clang_tidy(
...
@@ -72,7 +70,6 @@ rocm_enable_clang_tidy(
-modernize-pass-by-value
-modernize-pass-by-value
-modernize-use-default-member-init
-modernize-use-default-member-init
-modernize-use-transparent-functors
-modernize-use-transparent-functors
-performance-unnecessary-value-param
-readability-braces-around-statements
-readability-braces-around-statements
-readability-else-after-return
-readability-else-after-return
-readability-named-parameter
-readability-named-parameter
...
...
src/eliminate_contiguous.cpp
View file @
038a4c52
...
@@ -5,10 +5,11 @@
...
@@ -5,10 +5,11 @@
#include <migraph/iterator_for.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/ranges.hpp>
#include <migraph/ranges.hpp>
#include <migraph/stringutils.hpp>
#include <migraph/stringutils.hpp>
#include <utility>
namespace
migraph
{
namespace
migraph
{
bool
try_compute_shape
(
operation
op
,
std
::
vector
<
instruction_ref
>
args
)
bool
try_compute_shape
(
const
operation
&
op
,
const
std
::
vector
<
instruction_ref
>
&
args
)
{
{
try
try
{
{
...
...
src/include/migraph/argument.hpp
View file @
038a4c52
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
#include <migraph/shape.hpp>
#include <migraph/shape.hpp>
#include <migraph/raw_data.hpp>
#include <migraph/raw_data.hpp>
#include <functional>
#include <functional>
#include <utility>
namespace
migraph
{
namespace
migraph
{
...
@@ -18,16 +19,17 @@ struct argument : raw_data<argument>
...
@@ -18,16 +19,17 @@ struct argument : raw_data<argument>
{
{
argument
()
{}
argument
()
{}
argument
(
shape
s
)
:
m_shape
(
s
)
argument
(
const
shape
&
s
)
:
m_shape
(
s
)
{
{
std
::
vector
<
char
>
buffer
(
s
.
bytes
());
std
::
vector
<
char
>
buffer
(
s
.
bytes
());
// TODO: Move vector
// TODO: Move vector
data
=
[
=
]()
mutable
{
return
buffer
.
data
();
};
data
=
[
=
]()
mutable
{
return
buffer
.
data
();
};
}
}
argument
(
shape
s
,
std
::
function
<
char
*
()
>
d
)
:
data
(
d
),
m_shape
(
s
)
{}
argument
(
shape
s
,
std
::
function
<
char
*
()
>
d
)
:
data
(
std
::
move
(
d
)
),
m_shape
(
s
td
::
move
(
s
)
)
{}
template
<
class
T
>
template
<
class
T
>
argument
(
shape
s
,
T
*
d
)
:
data
([
d
]
{
return
reinterpret_cast
<
char
*>
(
d
);
}),
m_shape
(
s
)
argument
(
shape
s
,
T
*
d
)
:
data
([
d
]
{
return
reinterpret_cast
<
char
*>
(
d
);
}),
m_shape
(
std
::
move
(
s
))
{
{
}
}
...
...
src/include/migraph/builtin.hpp
View file @
038a4c52
...
@@ -12,24 +12,33 @@ namespace builtin {
...
@@ -12,24 +12,33 @@ namespace builtin {
struct
literal
struct
literal
{
{
std
::
string
name
()
const
{
return
"@literal"
;
}
std
::
string
name
()
const
{
return
"@literal"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
MIGRAPH_THROW
(
"builtin"
);
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
)
const
{
MIGRAPH_THROW
(
"builtin"
);
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
{
MIGRAPH_THROW
(
"builtin"
);
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPH_THROW
(
"builtin"
);
}
};
};
struct
outline
struct
outline
{
{
shape
s
;
shape
s
;
std
::
string
name
()
const
{
return
"@outline"
;
}
std
::
string
name
()
const
{
return
"@outline"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
return
s
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
)
const
{
return
s
;
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
{
MIGRAPH_THROW
(
"builtin"
);
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPH_THROW
(
"builtin"
);
}
};
};
struct
param
struct
param
{
{
std
::
string
parameter
;
std
::
string
parameter
;
std
::
string
name
()
const
{
return
"@param"
;
}
std
::
string
name
()
const
{
return
"@param"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
MIGRAPH_THROW
(
"builtin"
);
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
)
const
{
MIGRAPH_THROW
(
"builtin"
);
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
{
MIGRAPH_THROW
(
"builtin"
);
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPH_THROW
(
"builtin"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
param
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
param
&
op
)
{
{
os
<<
op
.
name
()
<<
":"
<<
op
.
parameter
;
os
<<
op
.
name
()
<<
":"
<<
op
.
parameter
;
...
...
src/include/migraph/check_context.hpp
View file @
038a4c52
...
@@ -11,8 +11,8 @@ struct check_context
...
@@ -11,8 +11,8 @@ struct check_context
struct
op
struct
op
{
{
std
::
string
name
()
const
{
return
"check_context"
;
}
std
::
string
name
()
const
{
return
"check_context"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
return
{};
}
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
)
const
{
return
{};
}
argument
compute
(
context
&
ctx
,
shape
,
std
::
vector
<
argument
>
)
const
argument
compute
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
argument
>
&
)
const
{
{
T
*
x
=
any_cast
<
T
>
(
&
ctx
);
T
*
x
=
any_cast
<
T
>
(
&
ctx
);
if
(
x
==
nullptr
)
if
(
x
==
nullptr
)
...
...
src/include/migraph/errors.hpp
View file @
038a4c52
...
@@ -10,7 +10,7 @@ namespace migraph {
...
@@ -10,7 +10,7 @@ namespace migraph {
/// Represents exceptions that can be thrown by migraphlib
/// Represents exceptions that can be thrown by migraphlib
struct
exception
:
std
::
runtime_error
struct
exception
:
std
::
runtime_error
{
{
exception
(
std
::
string
msg
=
""
)
:
std
::
runtime_error
(
msg
)
{}
exception
(
const
std
::
string
&
msg
=
""
)
:
std
::
runtime_error
(
msg
)
{}
};
};
/**
/**
...
@@ -20,7 +20,7 @@ struct exception : std::runtime_error
...
@@ -20,7 +20,7 @@ struct exception : std::runtime_error
* @param message Custom message for the error
* @param message Custom message for the error
* @return Exceptions
* @return Exceptions
*/
*/
inline
exception
make_exception
(
std
::
string
context
,
std
::
string
message
=
""
)
inline
exception
make_exception
(
const
std
::
string
&
context
,
const
std
::
string
&
message
=
""
)
{
{
return
{
context
+
": "
+
message
};
return
{
context
+
": "
+
message
};
}
}
...
...
src/include/migraph/generate.hpp
View file @
038a4c52
...
@@ -8,12 +8,33 @@
...
@@ -8,12 +8,33 @@
namespace
migraph
{
namespace
migraph
{
template
<
class
T
>
template
<
class
T
>
std
::
vector
<
T
>
generate_tensor_data
(
migraph
::
shape
s
,
std
::
mt19937
::
result_type
seed
=
0
)
struct
xorshf96_generator
{
unsigned
long
max
=
31
;
unsigned
long
x
=
123456789
;
unsigned
long
y
=
362436069
;
unsigned
long
z
=
521288629
;
constexpr
T
operator
()()
noexcept
{
x
^=
x
<<
16U
;
x
^=
x
>>
5U
;
x
^=
x
<<
1U
;
unsigned
long
t
=
x
;
x
=
y
;
y
=
z
;
z
=
t
^
x
^
y
;
return
z
%
max
;
}
};
template
<
class
T
>
std
::
vector
<
T
>
generate_tensor_data
(
const
migraph
::
shape
&
s
,
std
::
mt19937
::
result_type
)
{
{
std
::
vector
<
T
>
result
(
s
.
elements
());
std
::
vector
<
T
>
result
(
s
.
elements
());
std
::
mt19937
engine
{
seed
};
std
::
generate
(
result
.
begin
(),
result
.
end
(),
xorshf96_generator
<
T
>
{});
std
::
uniform_real_distribution
<>
dist
;
std
::
generate
(
result
.
begin
(),
result
.
end
(),
[
&
]
{
return
dist
(
engine
);
});
return
result
;
return
result
;
}
}
...
...
src/include/migraph/instruction.hpp
View file @
038a4c52
...
@@ -8,10 +8,11 @@
...
@@ -8,10 +8,11 @@
#include <migraph/operation.hpp>
#include <migraph/operation.hpp>
#include <migraph/erase.hpp>
#include <migraph/erase.hpp>
#include <string>
#include <string>
#include <utility>
namespace
migraph
{
namespace
migraph
{
shape
compute_shape
(
operation
op
,
std
::
vector
<
instruction_ref
>
args
);
shape
compute_shape
(
const
operation
&
op
,
const
std
::
vector
<
instruction_ref
>
&
args
);
struct
instruction
struct
instruction
{
{
...
@@ -25,14 +26,14 @@ struct instruction
...
@@ -25,14 +26,14 @@ struct instruction
instruction
(
literal
l
)
:
op
(
builtin
::
literal
{}),
result
(
l
.
get_shape
()),
lit
(
std
::
move
(
l
))
{}
instruction
(
literal
l
)
:
op
(
builtin
::
literal
{}),
result
(
l
.
get_shape
()),
lit
(
std
::
move
(
l
))
{}
// internal
// internal
void
replace
(
operation
o
,
shape
r
,
std
::
vector
<
instruction_ref
>
args
)
void
replace
(
operation
o
,
const
shape
&
r
,
std
::
vector
<
instruction_ref
>
args
)
{
{
op
=
o
;
op
=
std
::
move
(
o
)
;
replace
(
std
::
move
(
r
)
);
replace
(
r
);
replace
(
std
::
move
(
args
));
replace
(
std
::
move
(
args
));
}
}
void
replace
(
shape
r
)
void
replace
(
const
shape
&
r
)
{
{
if
(
r
!=
result
)
if
(
r
!=
result
)
{
{
...
@@ -155,7 +156,7 @@ inline void replace_argument(instruction_ref ins, instruction_ref old, instructi
...
@@ -155,7 +156,7 @@ inline void replace_argument(instruction_ref ins, instruction_ref old, instructi
// TODO: Move to a cpp file
// TODO: Move to a cpp file
// TODO: Use const ref for vector
// TODO: Use const ref for vector
inline
shape
compute_shape
(
operation
op
,
std
::
vector
<
instruction_ref
>
args
)
inline
shape
compute_shape
(
const
operation
&
op
,
const
std
::
vector
<
instruction_ref
>
&
args
)
{
{
std
::
vector
<
shape
>
shapes
(
args
.
size
());
std
::
vector
<
shape
>
shapes
(
args
.
size
());
std
::
transform
(
std
::
transform
(
...
@@ -165,4 +166,17 @@ inline shape compute_shape(operation op, std::vector<instruction_ref> args)
...
@@ -165,4 +166,17 @@ inline shape compute_shape(operation op, std::vector<instruction_ref> args)
}
// namespace migraph
}
// namespace migraph
namespace
std
{
template
<
>
struct
hash
<
migraph
::
instruction_ref
>
{
using
argument_type
=
migraph
::
instruction_ref
;
using
result_type
=
std
::
size_t
;
result_type
operator
()(
const
argument_type
&
x
)
const
noexcept
{
return
std
::
hash
<
migraph
::
instruction
*>
{}(
&*
x
);
}
};
}
// namespace std
#endif
#endif
src/include/migraph/instruction_ref.hpp
View file @
038a4c52
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#define MIGRAPH_GUARD_INSTRUCTION_REF_HPP
#define MIGRAPH_GUARD_INSTRUCTION_REF_HPP
#include <list>
#include <list>
#include <functional>
namespace
migraph
{
namespace
migraph
{
...
...
src/include/migraph/literal.hpp
View file @
038a4c52
...
@@ -7,6 +7,8 @@
...
@@ -7,6 +7,8 @@
#include <migraph/tensor_view.hpp>
#include <migraph/tensor_view.hpp>
#include <migraph/raw_data.hpp>
#include <migraph/raw_data.hpp>
#include <memory>
namespace
migraph
{
namespace
migraph
{
/**
/**
...
@@ -18,51 +20,57 @@ struct literal : raw_data<literal>
...
@@ -18,51 +20,57 @@ struct literal : raw_data<literal>
literal
()
{}
literal
()
{}
template
<
class
T
>
template
<
class
T
>
literal
(
T
x
)
:
buffer
(
sizeof
(
T
)
,
0
),
m_shape
(
shape
::
get_type
<
T
>
{})
literal
(
T
x
)
:
buffer
(
std
::
make_unique
<
char
[]
>
(
sizeof
(
T
)
)
),
m_shape
(
shape
::
get_type
<
T
>
{})
{
{
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
*
(
reinterpret_cast
<
T
*>
(
buffer
.
data
()))
=
x
;
*
(
reinterpret_cast
<
T
*>
(
buffer
.
get
()))
=
x
;
}
}
template
<
class
T
>
template
<
class
T
>
literal
(
shape
s
,
const
std
::
vector
<
T
>&
x
)
:
buffer
(
s
.
bytes
(),
0
),
m_shape
(
s
)
literal
(
const
shape
&
s
,
const
std
::
vector
<
T
>&
x
)
:
buffer
(
std
::
make_unique
<
char
[]
>
(
s
.
bytes
())),
m_shape
(
s
)
{
{
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
fill
(
x
.
begin
(),
x
.
end
());
fill
(
x
.
begin
(),
x
.
end
());
}
}
template
<
class
T
>
template
<
class
T
>
literal
(
shape
s
,
const
std
::
initializer_list
<
T
>&
x
)
:
buffer
(
s
.
bytes
(),
0
),
m_shape
(
s
)
literal
(
const
shape
&
s
,
const
std
::
initializer_list
<
T
>&
x
)
:
buffer
(
std
::
make_unique
<
char
[]
>
(
s
.
bytes
())),
m_shape
(
s
)
{
{
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
fill
(
x
.
begin
(),
x
.
end
());
fill
(
x
.
begin
(),
x
.
end
());
}
}
template
<
class
Iterator
>
template
<
class
Iterator
>
literal
(
shape
s
,
Iterator
start
,
Iterator
end
)
:
buffer
(
s
.
bytes
(),
0
),
m_shape
(
s
)
literal
(
const
shape
&
s
,
Iterator
start
,
Iterator
end
)
:
buffer
(
std
::
make_unique
<
char
[]
>
(
s
.
bytes
())),
m_shape
(
s
)
{
{
fill
(
start
,
end
);
fill
(
start
,
end
);
}
}
literal
(
shape
s
,
const
char
*
x
)
:
buffer
(
x
,
x
+
s
.
bytes
()),
m_shape
(
s
)
{}
literal
(
const
shape
&
s
,
const
char
*
x
)
:
buffer
(
std
::
make_unique
<
char
[]
>
(
s
.
bytes
())),
m_shape
(
s
)
{
std
::
copy
(
x
,
x
+
s
.
bytes
(),
buffer
.
get
());
}
/// Whether data is available
/// Whether data is available
bool
empty
()
const
{
return
this
->
buffer
.
empty
()
;
}
bool
empty
()
const
{
return
this
->
buffer
==
nullptr
;
}
/// Provides a raw pointer to the data
/// Provides a raw pointer to the data
const
char
*
data
()
const
{
return
this
->
buffer
.
data
();
}
const
char
*
data
()
const
{
return
this
->
buffer
.
get
();
}
const
shape
&
get_shape
()
const
{
return
this
->
m_shape
;
}
const
shape
&
get_shape
()
const
{
return
this
->
m_shape
;
}
/// Convert the data to an argument
/// Convert the data to an argument
argument
get_argument
()
const
argument
get_argument
()
const
{
{
auto
b
=
buffer
;
std
::
vector
<
char
>
b
(
buffer
.
get
(),
buffer
.
get
()
+
m_shape
.
bytes
())
;
return
{
m_shape
,
[
b
]()
mutable
{
return
b
.
data
();
}};
return
{
m_shape
,
[
b
]()
mutable
{
return
b
.
data
();
}};
}
}
private:
private:
std
::
vecto
r
<
char
>
buffer
;
std
::
shared_pt
r
<
char
>
buffer
;
shape
m_shape
;
shape
m_shape
;
template
<
class
Iterator
>
template
<
class
Iterator
>
...
@@ -70,13 +78,13 @@ struct literal : raw_data<literal>
...
@@ -70,13 +78,13 @@ struct literal : raw_data<literal>
{
{
if
(
m_shape
.
standard
())
if
(
m_shape
.
standard
())
{
{
m_shape
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
data
()));
});
m_shape
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
get
()));
});
}
}
else
else
{
{
auto
it
=
start
;
auto
it
=
start
;
m_shape
.
visit_type
([
&
](
auto
as
)
{
m_shape
.
visit_type
([
&
](
auto
as
)
{
auto
output
=
make_view
(
m_shape
,
as
.
from
(
buffer
.
data
()));
auto
output
=
make_view
(
m_shape
,
as
.
from
(
buffer
.
get
()));
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
it
++
;
it
++
;
output
(
idx
.
begin
(),
idx
.
end
())
=
*
it
;
output
(
idx
.
begin
(),
idx
.
end
())
=
*
it
;
...
...
src/include/migraph/operation.hpp
View file @
038a4c52
...
@@ -25,7 +25,7 @@ struct operation
...
@@ -25,7 +25,7 @@ struct operation
/// This is used to compute the resulting shape from an operation. If an
/// This is used to compute the resulting shape from an operation. If an
/// operation cannot be run with input shapes, then it should throw an
/// operation cannot be run with input shapes, then it should throw an
/// exception.
/// exception.
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
;
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
input
)
const
;
/**
/**
* @brief This performs the operation's computation
* @brief This performs the operation's computation
*
*
...
@@ -37,7 +37,7 @@ struct operation
...
@@ -37,7 +37,7 @@ struct operation
* @return Return an `argument` of the result computation. The `shape` of `argument` should be
* @return Return an `argument` of the result computation. The `shape` of `argument` should be
* the same the `output` shape.
* the same the `output` shape.
*/
*/
argument
compute
(
context
&
ctx
,
shape
output
,
std
::
vector
<
argument
>
input
)
const
;
argument
compute
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
argument
>
&
input
)
const
;
/// An optional stream operator to print the operation. When this is not
/// An optional stream operator to print the operation. When this is not
/// implemented, it will just print the operation's name.
/// implemented, it will just print the operation's name.
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
);
...
@@ -56,7 +56,8 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
...
@@ -56,7 +56,8 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
}
// namespace operation_stream
}
// namespace operation_stream
template
<
class
T
>
template
<
class
T
>
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
input
)
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
{
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
}
...
@@ -67,8 +68,8 @@ argument compute_op(const T& x, context& ctx, shape output_shape, std::vector<ar
...
@@ -67,8 +68,8 @@ argument compute_op(const T& x, context& ctx, shape output_shape, std::vector<ar
* struct operation
* struct operation
* {
* {
* std::string name() const;
* std::string name() const;
* shape compute_shape(std::vector<shape> input) const;
* shape compute_shape(
const
std::vector<shape>
&
input) const;
* argument compute(context& ctx,shape output,std::vector<argument> input) const;
* argument compute(context& ctx,
const
shape
&
output,
const
std::vector<argument>
&
input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* };
* };
*
*
...
@@ -137,17 +138,16 @@ struct operation
...
@@ -137,17 +138,16 @@ struct operation
return
(
*
this
).
private_detail_te_get_handle
().
name
();
return
(
*
this
).
private_detail_te_get_handle
().
name
();
}
}
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
input
)
const
{
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
compute_shape
(
std
::
move
(
input
)
)
;
return
(
*
this
).
private_detail_te_get_handle
().
compute_shape
(
input
);
}
}
argument
compute
(
context
&
ctx
,
shape
output
,
std
::
vector
<
argument
>
input
)
const
argument
compute
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
argument
>
&
input
)
const
{
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
return
(
*
this
).
private_detail_te_get_handle
().
compute
(
ctx
,
output
,
input
);
ctx
,
std
::
move
(
output
),
std
::
move
(
input
));
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
)
...
@@ -163,10 +163,11 @@ struct operation
...
@@ -163,10 +163,11 @@ struct operation
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
=
0
;
virtual
shape
compute_shape
(
const
std
::
vector
<
shape
>&
input
)
const
=
0
;
virtual
argument
compute
(
context
&
ctx
,
shape
output
,
std
::
vector
<
argument
>
input
)
const
=
0
;
virtual
argument
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
compute
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
=
0
;
virtual
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
=
0
;
};
};
template
<
typename
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedT
>
...
@@ -199,16 +200,18 @@ struct operation
...
@@ -199,16 +200,18 @@ struct operation
std
::
string
name
()
const
override
{
return
private_detail_te_value
.
name
();
}
std
::
string
name
()
const
override
{
return
private_detail_te_value
.
name
();
}
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
override
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
input
)
const
override
{
{
return
private_detail_te_value
.
compute_shape
(
std
::
move
(
input
)
)
;
return
private_detail_te_value
.
compute_shape
(
input
);
}
}
argument
compute
(
context
&
ctx
,
shape
output
,
std
::
vector
<
argument
>
input
)
const
override
argument
compute
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
input
)
const
override
{
{
return
compute_op
(
private_detail_te_value
,
ctx
,
std
::
move
(
output
)
,
std
::
move
(
input
)
)
;
return
compute_op
(
private_detail_te_value
,
ctx
,
output
,
input
);
}
}
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
std
::
ostream
&
operator_shift_left
(
std
::
ostream
&
os
)
const
override
...
...
src/include/migraph/operators.hpp
View file @
038a4c52
...
@@ -7,12 +7,13 @@
...
@@ -7,12 +7,13 @@
#include <migraph/stringutils.hpp>
#include <migraph/stringutils.hpp>
#include <migraph/streamutils.hpp>
#include <migraph/streamutils.hpp>
#include <cmath>
#include <cmath>
#include <utility>
namespace
migraph
{
namespace
migraph
{
struct
not_computable
struct
not_computable
{
{
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>
&
)
const
{
{
MIGRAPH_THROW
(
"not computable"
);
MIGRAPH_THROW
(
"not computable"
);
}
}
...
@@ -41,7 +42,7 @@ struct batch_norm_inference
...
@@ -41,7 +42,7 @@ struct batch_norm_inference
return
inputs
.
front
();
return
inputs
.
front
();
}
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>
&
)
const
{
{
MIGRAPH_THROW
(
"not computable"
);
MIGRAPH_THROW
(
"not computable"
);
}
}
...
@@ -114,7 +115,7 @@ struct convolution
...
@@ -114,7 +115,7 @@ struct convolution
}
}
}
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>
&
)
const
{
{
MIGRAPH_THROW
(
"not computable"
);
MIGRAPH_THROW
(
"not computable"
);
}
}
...
@@ -145,8 +146,8 @@ struct pooling
...
@@ -145,8 +146,8 @@ struct pooling
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
input
=
inputs
.
at
(
0
);
auto
t
=
input
.
type
();
auto
t
=
input
.
type
();
//
assert(lengths[0] < (input.lens()[2] + 2 * padding[0]));
assert
(
lengths
[
0
]
<
=
(
input
.
lens
()[
2
]
+
2
*
padding
[
0
]));
//
assert(lengths[1] < (input.lens()[3] + 2 * padding[1]));
assert
(
lengths
[
1
]
<
=
(
input
.
lens
()[
3
]
+
2
*
padding
[
1
]));
return
{
t
,
return
{
t
,
{
{
...
@@ -175,7 +176,7 @@ struct pooling
...
@@ -175,7 +176,7 @@ struct pooling
}};
}};
}
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>
&
)
const
{
{
MIGRAPH_THROW
(
"not computable"
);
MIGRAPH_THROW
(
"not computable"
);
}
}
...
@@ -201,7 +202,7 @@ struct activation
...
@@ -201,7 +202,7 @@ struct activation
return
inputs
.
front
();
return
inputs
.
front
();
}
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>
&
)
const
{
{
MIGRAPH_THROW
(
"not computable"
);
MIGRAPH_THROW
(
"not computable"
);
}
}
...
@@ -244,7 +245,14 @@ struct transpose
...
@@ -244,7 +245,14 @@ struct transpose
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
output_shape
,
std
::
move
(
args
.
front
().
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
transpose
&
op
)
{
os
<<
op
.
name
()
<<
"["
;
os
<<
"dims={"
<<
stream_range
(
op
.
dims
)
<<
"}"
;
os
<<
"]"
;
return
os
;
}
}
};
};
...
@@ -262,7 +270,7 @@ struct contiguous
...
@@ -262,7 +270,7 @@ struct contiguous
}
}
return
{
t
,
lens
};
return
{
t
,
lens
};
}
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>
&
)
const
{
{
MIGRAPH_THROW
(
"not computable"
);
MIGRAPH_THROW
(
"not computable"
);
}
}
...
@@ -309,13 +317,13 @@ struct reshape
...
@@ -309,13 +317,13 @@ struct reshape
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
output_shape
,
std
::
move
(
args
.
front
().
data
)};
return
{
std
::
move
(
output_shape
)
,
std
::
move
(
args
.
front
().
data
)};
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
reshape
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
reshape
&
op
)
{
{
os
<<
op
.
name
()
<<
"["
;
os
<<
op
.
name
()
<<
"["
;
os
<<
"dims={"
<<
stream_range
(
op
.
dims
)
<<
"}
,
"
;
os
<<
"dims={"
<<
stream_range
(
op
.
dims
)
<<
"}"
;
os
<<
"]"
;
os
<<
"]"
;
return
os
;
return
os
;
}
}
...
@@ -339,7 +347,7 @@ struct gemm
...
@@ -339,7 +347,7 @@ struct gemm
return
{
t
,
{
a
.
lens
()[
0
],
b
.
lens
()[
1
]}};
return
{
t
,
{
a
.
lens
()[
0
],
b
.
lens
()[
1
]}};
}
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>
&
)
const
{
{
MIGRAPH_THROW
(
"not computable"
);
MIGRAPH_THROW
(
"not computable"
);
}
}
...
@@ -359,7 +367,7 @@ struct unary
...
@@ -359,7 +367,7 @@ struct unary
check_shapes
{
inputs
}.
has
(
1
);
check_shapes
{
inputs
}.
has
(
1
);
return
inputs
.
at
(
0
);
return
inputs
.
at
(
0
);
}
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>
&
)
const
{
{
MIGRAPH_THROW
(
"not computable"
);
MIGRAPH_THROW
(
"not computable"
);
}
}
...
@@ -439,26 +447,26 @@ struct flatten
...
@@ -439,26 +447,26 @@ struct flatten
check_shapes
{
inputs
}.
has
(
1
);
check_shapes
{
inputs
}.
has
(
1
);
auto
&&
lens
=
inputs
.
front
().
lens
();
auto
&&
lens
=
inputs
.
front
().
lens
();
if
(
axis
==
0
)
if
(
axis
>
lens
.
size
())
{
return
{
inputs
.
at
(
0
).
type
(),
{
1
,
inputs
.
at
(
0
).
elements
()}};
}
else
if
(
axis
<
lens
.
size
())
{
auto
x
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
auto
y
=
std
::
accumulate
(
lens
.
begin
()
+
axis
,
lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
return
{
inputs
.
at
(
0
).
type
(),
{
x
,
y
}};
}
else
{
{
MIGRAPH_THROW
(
"axis for flatten must be less than tensor rank"
);
MIGRAPH_THROW
(
"axis for flatten must be less than tensor rank"
);
}
}
auto
x
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
auto
y
=
std
::
accumulate
(
lens
.
begin
()
+
axis
,
lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
return
{
inputs
.
at
(
0
).
type
(),
{
x
,
y
}};
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
output_shape
,
std
::
move
(
args
.
front
().
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
front
().
data
)};
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
flatten
&
op
)
{
os
<<
op
.
name
()
<<
"["
;
os
<<
"axis="
<<
op
.
axis
;
os
<<
"]"
;
return
os
;
}
}
};
};
struct
broadcast
struct
broadcast
...
@@ -491,7 +499,14 @@ struct broadcast
...
@@ -491,7 +499,14 @@ struct broadcast
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
output_shape
,
std
::
move
(
args
.
at
(
1
).
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
1
).
data
)};
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
broadcast
&
op
)
{
os
<<
op
.
name
()
<<
"["
;
os
<<
"axis="
<<
op
.
axis
;
os
<<
"]"
;
return
os
;
}
}
};
};
...
@@ -503,7 +518,7 @@ struct binary
...
@@ -503,7 +518,7 @@ struct binary
check_shapes
{
inputs
}.
has
(
2
).
same_type
().
same_dims
();
check_shapes
{
inputs
}.
has
(
2
).
same_type
().
same_dims
();
return
inputs
.
at
(
0
);
return
inputs
.
at
(
0
);
}
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>
&
)
const
{
{
MIGRAPH_THROW
(
"not computable"
);
MIGRAPH_THROW
(
"not computable"
);
}
}
...
@@ -533,12 +548,15 @@ struct outline
...
@@ -533,12 +548,15 @@ struct outline
{
{
shape
s
;
shape
s
;
std
::
string
name
()
const
{
return
"outline"
;
}
std
::
string
name
()
const
{
return
"outline"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
0
);
check_shapes
{
inputs
,
*
this
}.
has
(
0
);
return
s
;
return
s
;
}
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
{
return
{
s
,
nullptr
};
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
return
{
s
,
nullptr
};
}
};
};
}
// namespace migraph
}
// namespace migraph
...
...
src/include/migraph/program.hpp
View file @
038a4c52
...
@@ -34,7 +34,7 @@ struct program
...
@@ -34,7 +34,7 @@ struct program
{
{
return
add_instruction
(
op
,
{
args
...});
return
add_instruction
(
op
,
{
args
...});
}
}
instruction_ref
add_instruction
(
operation
op
,
std
::
vector
<
instruction_ref
>
args
);
instruction_ref
add_instruction
(
const
operation
&
op
,
std
::
vector
<
instruction_ref
>
args
);
template
<
class
...
Ts
>
template
<
class
...
Ts
>
instruction_ref
insert_instruction
(
instruction_ref
ins
,
operation
op
,
Ts
...
args
)
instruction_ref
insert_instruction
(
instruction_ref
ins
,
operation
op
,
Ts
...
args
)
...
@@ -42,15 +42,16 @@ struct program
...
@@ -42,15 +42,16 @@ struct program
return
insert_instruction
(
ins
,
op
,
{
args
...});
return
insert_instruction
(
ins
,
op
,
{
args
...});
}
}
instruction_ref
instruction_ref
insert_instruction
(
instruction_ref
ins
,
operation
op
,
std
::
vector
<
instruction_ref
>
args
);
insert_instruction
(
instruction_ref
ins
,
const
operation
&
op
,
std
::
vector
<
instruction_ref
>
args
);
template
<
class
...
Ts
>
template
<
class
...
Ts
>
instruction_ref
replace_instruction
(
instruction_ref
ins
,
operation
op
,
Ts
...
args
)
instruction_ref
replace_instruction
(
instruction_ref
ins
,
operation
op
,
Ts
...
args
)
{
{
return
replace_instruction
(
ins
,
op
,
{
args
...});
return
replace_instruction
(
ins
,
op
,
{
args
...});
}
}
instruction_ref
instruction_ref
replace_instruction
(
instruction_ref
ins
,
replace_instruction
(
instruction_ref
ins
,
operation
op
,
std
::
vector
<
instruction_ref
>
args
);
const
operation
&
op
,
std
::
vector
<
instruction_ref
>
args
);
instruction_ref
replace_instruction
(
instruction_ref
ins
,
instruction_ref
rep
);
instruction_ref
replace_instruction
(
instruction_ref
ins
,
instruction_ref
rep
);
...
@@ -67,7 +68,7 @@ struct program
...
@@ -67,7 +68,7 @@ struct program
instruction_ref
add_literal
(
literal
l
);
instruction_ref
add_literal
(
literal
l
);
instruction_ref
add_outline
(
shape
s
);
instruction_ref
add_outline
(
const
shape
&
s
);
instruction_ref
add_parameter
(
std
::
string
name
,
shape
s
);
instruction_ref
add_parameter
(
std
::
string
name
,
shape
s
);
...
@@ -79,6 +80,7 @@ struct program
...
@@ -79,6 +80,7 @@ struct program
bool
has_instruction
(
instruction_ref
ins
)
const
;
bool
has_instruction
(
instruction_ref
ins
)
const
;
std
::
size_t
size
()
const
;
instruction_ref
begin
()
const
;
instruction_ref
begin
()
const
;
instruction_ref
end
()
const
;
instruction_ref
end
()
const
;
...
@@ -88,6 +90,8 @@ struct program
...
@@ -88,6 +90,8 @@ struct program
void
compile
(
const
target
&
t
);
void
compile
(
const
target
&
t
);
void
perf_report
(
std
::
ostream
&
os
,
std
::
size_t
n
,
parameter_map
params
)
const
;
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
);
friend
bool
operator
==
(
const
program
&
x
,
const
program
&
y
);
friend
bool
operator
==
(
const
program
&
x
,
const
program
&
y
);
friend
bool
operator
!=
(
const
program
&
x
,
const
program
&
y
)
{
return
!
(
x
==
y
);
}
friend
bool
operator
!=
(
const
program
&
x
,
const
program
&
y
)
{
return
!
(
x
==
y
);
}
...
...
src/include/migraph/shape.hpp
View file @
038a4c52
...
@@ -5,11 +5,14 @@
...
@@ -5,11 +5,14 @@
#include <cassert>
#include <cassert>
#include <ostream>
#include <ostream>
#include <numeric>
#include <numeric>
#include <memory>
#include <migraph/errors.hpp>
#include <migraph/errors.hpp>
namespace
migraph
{
namespace
migraph
{
struct
shape_impl
;
struct
shape
struct
shape
{
{
...
@@ -136,7 +139,7 @@ struct shape
...
@@ -136,7 +139,7 @@ struct shape
template
<
class
Visitor
>
template
<
class
Visitor
>
void
visit_type
(
Visitor
v
)
const
void
visit_type
(
Visitor
v
)
const
{
{
switch
(
this
->
m_
type
)
switch
(
this
->
type
()
)
{
{
#define MIGRAPH_SHAPE_VISITOR_CASE(x, t) \
#define MIGRAPH_SHAPE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return;
case x: v(as<t>()); return;
...
@@ -147,12 +150,8 @@ struct shape
...
@@ -147,12 +150,8 @@ struct shape
}
}
private:
private:
type_t
m_type
;
std
::
shared_ptr
<
const
shape_impl
>
impl
;
std
::
vector
<
std
::
size_t
>
m_lens
;
std
::
vector
<
std
::
size_t
>
m_strides
;
bool
m_standard
;
void
calculate_strides
();
std
::
size_t
element_space
()
const
;
std
::
size_t
element_space
()
const
;
std
::
string
type_string
()
const
;
std
::
string
type_string
()
const
;
};
};
...
...
src/include/migraph/stringutils.hpp
View file @
038a4c52
...
@@ -29,7 +29,7 @@ inline bool ends_with(const std::string& value, const std::string& suffix)
...
@@ -29,7 +29,7 @@ inline bool ends_with(const std::string& value, const std::string& suffix)
}
}
template
<
class
Strings
>
template
<
class
Strings
>
inline
std
::
string
join_strings
(
Strings
strings
,
std
::
string
delim
)
inline
std
::
string
join_strings
(
Strings
strings
,
const
std
::
string
&
delim
)
{
{
auto
it
=
strings
.
begin
();
auto
it
=
strings
.
begin
();
if
(
it
==
strings
.
end
())
if
(
it
==
strings
.
end
())
...
@@ -57,7 +57,7 @@ inline bool starts_with(const std::string& value, const std::string& prefix)
...
@@ -57,7 +57,7 @@ inline bool starts_with(const std::string& value, const std::string& prefix)
return
std
::
equal
(
prefix
.
begin
(),
prefix
.
end
(),
value
.
begin
());
return
std
::
equal
(
prefix
.
begin
(),
prefix
.
end
(),
value
.
begin
());
}
}
inline
std
::
string
remove_prefix
(
std
::
string
s
,
std
::
string
prefix
)
inline
std
::
string
remove_prefix
(
std
::
string
s
,
const
std
::
string
&
prefix
)
{
{
if
(
starts_with
(
s
,
prefix
))
if
(
starts_with
(
s
,
prefix
))
return
s
.
substr
(
prefix
.
length
());
return
s
.
substr
(
prefix
.
length
());
...
...
src/include/migraph/tensor_view.hpp
View file @
038a4c52
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#include <migraph/requires.hpp>
#include <migraph/requires.hpp>
#include <iostream>
#include <iostream>
#include <utility>
namespace
migraph
{
namespace
migraph
{
...
@@ -14,7 +15,7 @@ struct tensor_view
...
@@ -14,7 +15,7 @@ struct tensor_view
{
{
using
value_type
=
T
;
using
value_type
=
T
;
tensor_view
()
:
m_data
(
nullptr
)
{}
tensor_view
()
:
m_data
(
nullptr
)
{}
tensor_view
(
shape
s
,
T
*
d
)
:
m_data
(
d
),
m_shape
(
s
)
{}
tensor_view
(
shape
s
,
T
*
d
)
:
m_data
(
d
),
m_shape
(
s
td
::
move
(
s
)
)
{}
const
shape
&
get_shape
()
const
{
return
this
->
m_shape
;
}
const
shape
&
get_shape
()
const
{
return
this
->
m_shape
;
}
...
...
src/include/migraph/time.hpp
0 → 100644
View file @
038a4c52
#ifndef MIGRAPH_GUARD_RTGLIB_TIME_HPP
#define MIGRAPH_GUARD_RTGLIB_TIME_HPP
#include <chrono>
namespace
migraph
{
template
<
class
Duration
,
class
F
>
auto
time
(
F
f
)
{
auto
start
=
std
::
chrono
::
steady_clock
::
now
();
f
();
auto
finish
=
std
::
chrono
::
steady_clock
::
now
();
return
std
::
chrono
::
duration_cast
<
Duration
>
(
finish
-
start
).
count
();
}
}
// namespace migraph
#endif
src/onnx/CMakeLists.txt
View file @
038a4c52
...
@@ -16,6 +16,7 @@ add_executable(read_onnx read_onnx.cpp)
...
@@ -16,6 +16,7 @@ add_executable(read_onnx read_onnx.cpp)
rocm_clang_tidy_check
(
read_onnx
)
rocm_clang_tidy_check
(
read_onnx
)
target_link_libraries
(
read_onnx migraph_onnx
)
target_link_libraries
(
read_onnx migraph_onnx
)
add_executable
(
mnist mnist.cpp
)
add_executable
(
mnist mnist.cpp
)
rocm_clang_tidy_check
(
mnist
)
rocm_clang_tidy_check
(
mnist
)
target_link_libraries
(
mnist migraph_cpu migraph_gpu migraph_onnx
)
target_link_libraries
(
mnist migraph_cpu migraph_gpu migraph_onnx
)
...
@@ -28,4 +29,8 @@ if(MIGRAPH_ENABLE_GPU)
...
@@ -28,4 +29,8 @@ if(MIGRAPH_ENABLE_GPU)
add_executable
(
verify_onnx verify_onnx.cpp
)
add_executable
(
verify_onnx verify_onnx.cpp
)
rocm_clang_tidy_check
(
verify_onnx
)
rocm_clang_tidy_check
(
verify_onnx
)
target_link_libraries
(
verify_onnx migraph_onnx migraph_cpu migraph_gpu
)
target_link_libraries
(
verify_onnx migraph_onnx migraph_cpu migraph_gpu
)
add_executable
(
perf_onnx perf_onnx.cpp
)
rocm_clang_tidy_check
(
perf_onnx
)
target_link_libraries
(
perf_onnx migraph_onnx migraph_cpu migraph_gpu
)
endif
()
endif
()
src/onnx/mnist.cpp
View file @
038a4c52
...
@@ -21,7 +21,8 @@ auto reverse_int(unsigned int i)
...
@@ -21,7 +21,8 @@ auto reverse_int(unsigned int i)
(
static_cast
<
unsigned
int
>
(
c3
)
<<
8u
)
+
c4
;
(
static_cast
<
unsigned
int
>
(
c3
)
<<
8u
)
+
c4
;
};
};
std
::
vector
<
float
>
read_mnist_images
(
std
::
string
full_path
,
int
&
number_of_images
,
int
&
image_size
)
std
::
vector
<
float
>
read_mnist_images
(
const
std
::
string
&
full_path
,
int
&
number_of_images
,
int
&
image_size
)
{
{
using
uchar
=
unsigned
char
;
using
uchar
=
unsigned
char
;
...
@@ -64,7 +65,7 @@ std::vector<float> read_mnist_images(std::string full_path, int& number_of_image
...
@@ -64,7 +65,7 @@ std::vector<float> read_mnist_images(std::string full_path, int& number_of_image
}
}
}
}
std
::
vector
<
int32_t
>
read_mnist_labels
(
std
::
string
full_path
,
int
&
number_of_labels
)
std
::
vector
<
int32_t
>
read_mnist_labels
(
const
std
::
string
&
full_path
,
int
&
number_of_labels
)
{
{
using
uchar
=
unsigned
char
;
using
uchar
=
unsigned
char
;
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment