Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
eb0168fb
Commit
eb0168fb
authored
Aug 19, 2019
by
Paul
Browse files
Fill parameters with 1s from the driver
parent
0628e570
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
101 additions
and
7 deletions
+101
-7
src/driver/argument_parser.hpp
src/driver/argument_parser.hpp
+49
-4
src/driver/main.cpp
src/driver/main.cpp
+11
-1
src/driver/perf.cpp
src/driver/perf.cpp
+17
-0
src/driver/perf.hpp
src/driver/perf.hpp
+1
-0
src/generate.cpp
src/generate.cpp
+11
-0
src/include/migraphx/generate.hpp
src/include/migraphx/generate.hpp
+10
-0
src/include/migraphx/requires.hpp
src/include/migraphx/requires.hpp
+2
-2
No files found.
src/driver/argument_parser.hpp
View file @
eb0168fb
...
@@ -28,10 +28,34 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -28,10 +28,34 @@ inline namespace MIGRAPHX_INLINE_NS {
#define MIGRAPHX_DRIVER_STATIC static
#define MIGRAPHX_DRIVER_STATIC static
#endif
#endif
template
<
class
T
>
using
bare
=
std
::
remove_cv_t
<
std
::
remove_reference_t
<
T
>>
;
namespace
detail
{
template
<
class
T
>
auto
is_container
(
int
,
T
&&
x
)
->
decltype
(
x
.
insert
(
x
.
end
(),
*
x
.
begin
()),
std
::
true_type
{});
template
<
class
T
>
std
::
false_type
is_container
(
float
,
T
&&
);
}
// namespace detail
template
<
class
T
>
struct
is_container
:
decltype
(
detail
::
is_container
(
int
(
0
),
std
::
declval
<
T
>
()))
{
};
template
<
class
T
>
using
is_multi_value
=
std
::
integral_constant
<
bool
,
(
is_container
<
T
>
{}
and
not
std
::
is_convertible
<
T
,
std
::
string
>
{})
>
;
template
<
class
T
>
template
<
class
T
>
struct
value_parser
struct
value_parser
{
{
template
<
MIGRAPHX_REQUIRES
(
not
std
::
is_enum
<
T
>{})
>
template
<
MIGRAPHX_REQUIRES
(
not
std
::
is_enum
<
T
>{}
and
not
is_multi_value
<
T
>
{}
)
>
static
T
apply
(
const
std
::
string
&
x
)
static
T
apply
(
const
std
::
string
&
x
)
{
{
T
result
;
T
result
;
...
@@ -43,7 +67,7 @@ struct value_parser
...
@@ -43,7 +67,7 @@ struct value_parser
return
result
;
return
result
;
}
}
template
<
MIGRAPHX_REQUIRES
(
std
::
is_enum
<
T
>{})
>
template
<
MIGRAPHX_REQUIRES
(
std
::
is_enum
<
T
>{}
and
not
is_multi_value
<
T
>
{}
)
>
static
T
apply
(
const
std
::
string
&
x
)
static
T
apply
(
const
std
::
string
&
x
)
{
{
std
::
ptrdiff_t
i
;
std
::
ptrdiff_t
i
;
...
@@ -54,6 +78,15 @@ struct value_parser
...
@@ -54,6 +78,15 @@ struct value_parser
throw
std
::
runtime_error
(
"Failed to parse: "
+
x
);
throw
std
::
runtime_error
(
"Failed to parse: "
+
x
);
return
static_cast
<
T
>
(
i
);
return
static_cast
<
T
>
(
i
);
}
}
template
<
MIGRAPHX_REQUIRES
(
is_multi_value
<
T
>{}
and
not
std
::
is_enum
<
T
>
{})
>
static
T
apply
(
const
std
::
string
&
x
)
{
T
result
;
using
value_type
=
typename
T
::
value_type
;
result
.
insert
(
result
.
end
(),
value_parser
<
value_type
>::
apply
(
x
));
return
result
;
}
};
};
struct
argument_parser
struct
argument_parser
...
@@ -69,6 +102,18 @@ struct argument_parser
...
@@ -69,6 +102,18 @@ struct argument_parser
unsigned
nargs
=
1
;
unsigned
nargs
=
1
;
};
};
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_multi_value
<
T
>{})
>
std
::
string
as_string_value
(
const
T
&
x
)
{
return
to_string_range
(
x
);
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
not
is_multi_value
<
T
>{})
>
std
::
string
as_string_value
(
const
T
&
x
)
{
return
to_string
(
x
);
}
template
<
class
T
,
class
...
Fs
>
template
<
class
T
,
class
...
Fs
>
void
operator
()(
T
&
x
,
const
std
::
vector
<
std
::
string
>&
flags
,
Fs
...
fs
)
void
operator
()(
T
&
x
,
const
std
::
vector
<
std
::
string
>&
flags
,
Fs
...
fs
)
{
{
...
@@ -81,7 +126,7 @@ struct argument_parser
...
@@ -81,7 +126,7 @@ struct argument_parser
argument
&
arg
=
arguments
.
back
();
argument
&
arg
=
arguments
.
back
();
arg
.
type
=
migraphx
::
get_type_name
<
T
>
();
arg
.
type
=
migraphx
::
get_type_name
<
T
>
();
arg
.
default_value
=
to
_string
(
x
);
arg
.
default_value
=
as
_string
_value
(
x
);
migraphx
::
each_args
([
&
](
auto
f
)
{
f
(
x
,
arg
);
},
fs
...);
migraphx
::
each_args
([
&
](
auto
f
)
{
f
(
x
,
arg
);
},
fs
...);
}
}
...
@@ -127,7 +172,7 @@ struct argument_parser
...
@@ -127,7 +172,7 @@ struct argument_parser
MIGRAPHX_DRIVER_STATIC
auto
append
()
MIGRAPHX_DRIVER_STATIC
auto
append
()
{
{
return
write_action
([](
auto
&
,
auto
&
x
,
auto
&
params
)
{
return
write_action
([](
auto
&
,
auto
&
x
,
auto
&
params
)
{
using
type
=
typename
decltype
(
params
)
::
value_type
;
using
type
=
typename
bare
<
decltype
(
params
)
>
::
value_type
;
std
::
transform
(
params
.
begin
(),
std
::
transform
(
params
.
begin
(),
params
.
end
(),
params
.
end
(),
std
::
inserter
(
x
,
x
.
end
()),
std
::
inserter
(
x
,
x
.
end
()),
...
...
src/driver/main.cpp
View file @
eb0168fb
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/eliminate_pad.hpp>
...
@@ -80,11 +81,13 @@ struct compiler
...
@@ -80,11 +81,13 @@ struct compiler
{
{
loader
l
;
loader
l
;
bool
gpu
=
true
;
bool
gpu
=
true
;
std
::
vector
<
std
::
string
>
fill1
;
void
parse
(
argument_parser
&
ap
)
void
parse
(
argument_parser
&
ap
)
{
{
l
.
parse
(
ap
);
l
.
parse
(
ap
);
ap
(
gpu
,
{
"--gpu"
},
ap
.
help
(
"Compile on the gpu"
),
ap
.
set_value
(
true
));
ap
(
gpu
,
{
"--gpu"
},
ap
.
help
(
"Compile on the gpu"
),
ap
.
set_value
(
true
));
ap
(
gpu
,
{
"--cpu"
},
ap
.
help
(
"Compile on the cpu"
),
ap
.
set_value
(
false
));
ap
(
gpu
,
{
"--cpu"
},
ap
.
help
(
"Compile on the cpu"
),
ap
.
set_value
(
false
));
ap
(
fill1
,
{
"--fill1"
},
ap
.
help
(
"Fill parameter with 1s"
),
ap
.
append
());
}
}
program
compile
()
program
compile
()
...
@@ -94,7 +97,14 @@ struct compiler
...
@@ -94,7 +97,14 @@ struct compiler
return
p
;
return
p
;
}
}
auto
params
(
const
program
&
p
)
{
return
create_param_map
(
p
,
gpu
);
}
auto
params
(
const
program
&
p
)
{
program
::
parameter_map
m
;
for
(
auto
&&
s
:
fill1
)
m
[
s
]
=
fill_argument
(
p
.
get_parameter_shape
(
s
),
1
);
fill_param_map
(
m
,
p
,
gpu
);
return
m
;
}
};
};
struct
read
:
command
<
read
>
struct
read
:
command
<
read
>
...
...
src/driver/perf.cpp
View file @
eb0168fb
...
@@ -11,6 +11,23 @@ namespace migraphx {
...
@@ -11,6 +11,23 @@ namespace migraphx {
namespace
driver
{
namespace
driver
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
program
::
parameter_map
fill_param_map
(
program
::
parameter_map
&
m
,
const
program
&
p
,
bool
gpu
)
{
for
(
auto
&&
x
:
p
.
get_parameter_shapes
())
{
argument
&
arg
=
m
[
x
.
first
];
if
(
arg
.
empty
())
arg
=
generate_argument
(
x
.
second
);
#ifdef HAVE_GPU
if
(
gpu
)
arg
=
gpu
::
to_gpu
(
arg
);
#else
(
void
)
gpu
;
#endif
}
return
m
;
}
program
::
parameter_map
create_param_map
(
const
program
&
p
,
bool
gpu
)
program
::
parameter_map
create_param_map
(
const
program
&
p
,
bool
gpu
)
{
{
program
::
parameter_map
m
;
program
::
parameter_map
m
;
...
...
src/driver/perf.hpp
View file @
eb0168fb
...
@@ -7,6 +7,7 @@ namespace migraphx {
...
@@ -7,6 +7,7 @@ namespace migraphx {
namespace
driver
{
namespace
driver
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
program
::
parameter_map
fill_param_map
(
program
::
parameter_map
&
m
,
const
program
&
p
,
bool
gpu
);
program
::
parameter_map
create_param_map
(
const
program
&
p
,
bool
gpu
=
true
);
program
::
parameter_map
create_param_map
(
const
program
&
p
,
bool
gpu
=
true
);
void
compile_program
(
program
&
p
,
bool
gpu
=
true
);
void
compile_program
(
program
&
p
,
bool
gpu
=
true
);
...
...
src/generate.cpp
View file @
eb0168fb
...
@@ -3,6 +3,17 @@
...
@@ -3,6 +3,17 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
argument
fill_argument
(
shape
s
,
unsigned
long
value
)
{
argument
result
;
s
.
visit_type
([
&
](
auto
as
)
{
using
type
=
typename
decltype
(
as
)
::
type
;
auto
v
=
fill_tensor_data
<
type
>
(
s
,
value
);
result
=
{
s
,
[
v
]()
mutable
{
return
reinterpret_cast
<
char
*>
(
v
.
data
());
}};
});
return
result
;
}
argument
generate_argument
(
shape
s
,
unsigned
long
seed
)
argument
generate_argument
(
shape
s
,
unsigned
long
seed
)
{
{
argument
result
;
argument
result
;
...
...
src/include/migraphx/generate.hpp
View file @
eb0168fb
...
@@ -87,6 +87,16 @@ std::vector<T> generate_tensor_data(const migraphx::shape& s, unsigned long seed
...
@@ -87,6 +87,16 @@ std::vector<T> generate_tensor_data(const migraphx::shape& s, unsigned long seed
return
result
;
return
result
;
}
}
template
<
class
T
>
std
::
vector
<
T
>
fill_tensor_data
(
const
migraphx
::
shape
&
s
,
unsigned
long
value
=
0
)
{
std
::
vector
<
T
>
result
(
s
.
elements
());
std
::
generate
(
result
.
begin
(),
result
.
end
(),
[
=
]{
return
value
;
});
return
result
;
}
argument
fill_argument
(
shape
s
,
unsigned
long
value
=
0
);
argument
generate_argument
(
shape
s
,
unsigned
long
seed
=
0
);
argument
generate_argument
(
shape
s
,
unsigned
long
seed
=
0
);
literal
generate_literal
(
shape
s
,
unsigned
long
seed
=
0
);
literal
generate_literal
(
shape
s
,
unsigned
long
seed
=
0
);
...
...
src/include/migraphx/requires.hpp
View file @
eb0168fb
...
@@ -24,8 +24,8 @@ using bool_c = std::integral_constant<bool, B>;
...
@@ -24,8 +24,8 @@ using bool_c = std::integral_constant<bool, B>;
#define MIGRAPHX_REQUIRES(...) class = void
#define MIGRAPHX_REQUIRES(...) class = void
#else
#else
#define MIGRAPHX_REQUIRES(...) \
#define MIGRAPHX_REQUIRES(...) \
bool
MIGRAPHX_REQUIRES_VAR() =
true
, \
long
MIGRAPHX_REQUIRES_VAR() =
__LINE__
, \
typename std::enable_if<(MIGRAPHX_REQUIRES_VAR() && (migraphx::and_<__VA_ARGS__>{})), \
typename std::enable_if<(MIGRAPHX_REQUIRES_VAR()
== __LINE__
&& (migraphx::and_<__VA_ARGS__>{})), \
int>::type = 0
int>::type = 0
#endif
#endif
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment