Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
41ed1924
Commit
41ed1924
authored
Aug 28, 2019
by
Shucai Xiao
Browse files
refine int8 quantization APIs
parent
318dbc15
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
29 additions
and
37 deletions
+29
-37
src/include/migraphx/quantization.hpp
src/include/migraphx/quantization.hpp
+9
-2
src/include/migraphx/target.hpp
src/include/migraphx/target.hpp
+4
-10
src/quantization.cpp
src/quantization.cpp
+2
-9
src/targets/cpu/include/migraphx/cpu/target.hpp
src/targets/cpu/include/migraphx/cpu/target.hpp
+3
-1
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+2
-1
test/type_conversion.cpp
test/type_conversion.cpp
+2
-1
tools/include/target.hpp
tools/include/target.hpp
+7
-13
No files found.
src/include/migraphx/quantization.hpp
View file @
41ed1924
...
...
@@ -22,10 +22,17 @@ void quantize(program& prog);
std
::
size_t
capture_arguments
(
program
&
prog
,
const
std
::
vector
<
std
::
string
>&
ins_names
,
const
std
::
function
<
void
(
std
::
size_t
,
std
::
vector
<
argument
>
)
>&
func
);
std
::
shared_ptr
<
std
::
vector
<
std
::
pair
<
float
,
float
>>>
capture_arguments
(
program
&
prog
,
const
target
&
t
,
const
std
::
vector
<
std
::
string
>&
ins_names
);
capture_arguments_impl
(
program
&
prog
,
const
target
&
t
,
const
std
::
vector
<
std
::
string
>&
ins_names
=
{
"dot"
});
template
<
class
T
>
std
::
shared_ptr
<
std
::
vector
<
std
::
pair
<
float
,
float
>>>
capture_arguments
(
program
&
prog
,
const
target
&
t
);
T
&&
t
,
const
std
::
vector
<
std
::
string
>&
ins_names
=
{
"dot"
})
{
static_assert
(
std
::
is_same
<
std
::
remove_cv_t
<
std
::
remove_reference_t
<
T
>>
,
target
>
{}
&&
std
::
is_lvalue_reference
<
T
>
{},
"Dangling reference to target!"
);
return
capture_arguments_impl
(
prog
,
t
,
ins_names
);
}
void
quantize_int8
(
program
&
prog
,
const
target
&
t
,
...
...
src/include/migraphx/target.hpp
View file @
41ed1924
...
...
@@ -72,7 +72,6 @@ argument target_allocate(rank<0>, T& x, const shape&)
{
std
::
string
name
=
x
.
name
();
MIGRAPHX_THROW
(
"Not computable: "
+
name
);
return
argument
{};
}
template
<
class
T
>
...
...
@@ -88,12 +87,9 @@ auto copy_to_target(rank<1>, T& x, const argument& arg) -> decltype(x.copy_to(ar
}
template
<
class
T
>
argument
copy_to_target
(
rank
<
0
>
,
T
&
x
,
const
argument
&
)
argument
copy_to_target
(
rank
<
0
>
,
T
&
,
const
argument
&
arg
)
{
std
::
string
name
=
x
.
name
();
MIGRAPHX_THROW
(
"Not computable: "
+
name
);
return
argument
{};
return
arg
;
}
template
<
class
T
>
...
...
@@ -109,11 +105,9 @@ auto copy_from_target(rank<1>, T& x, const argument& arg) -> decltype(x.copy_fro
}
template
<
class
T
>
argument
copy_from_target
(
rank
<
0
>
,
T
&
x
,
const
argument
&
)
argument
copy_from_target
(
rank
<
0
>
,
T
&
,
const
argument
&
arg
)
{
std
::
string
name
=
x
.
name
();
MIGRAPHX_THROW
(
"Not computable: "
+
name
);
return
argument
{};
return
arg
;
}
template
<
class
T
>
...
...
src/quantization.cpp
View file @
41ed1924
...
...
@@ -502,7 +502,7 @@ std::size_t capture_arguments(program& prog,
}
std
::
shared_ptr
<
std
::
vector
<
std
::
pair
<
float
,
float
>>>
capture_arguments
(
program
&
prog
,
const
target
&
t
,
const
std
::
vector
<
std
::
string
>&
ins_names
)
capture_arguments
_impl
(
program
&
prog
,
const
target
&
t
,
const
std
::
vector
<
std
::
string
>&
ins_names
)
{
std
::
shared_ptr
<
std
::
vector
<
std
::
pair
<
float
,
float
>>>
int8_quant_params
=
std
::
make_shared
<
std
::
vector
<
std
::
pair
<
float
,
float
>>>
();
...
...
@@ -515,7 +515,7 @@ capture_arguments(program& prog, const target& t, const std::vector<std::string>
// scale and shift is need for only int8 type, and we do not
// consider shift, so set shift to 0
std
::
vector
<
float
>
vec_val
;
a
uto
&&
arg
=
t
.
copy_from
(
args
.
front
());
a
rgument
arg
=
t
.
copy_from
(
args
.
front
());
arg
.
visit
([
&
](
auto
output
)
{
vec_val
.
assign
(
output
.
begin
(),
output
.
end
());
});
auto
max_val
=
*
std
::
max_element
(
vec_val
.
begin
(),
vec_val
.
end
());
auto
min_val
=
*
std
::
min_element
(
vec_val
.
begin
(),
vec_val
.
end
());
...
...
@@ -534,12 +534,5 @@ capture_arguments(program& prog, const target& t, const std::vector<std::string>
return
int8_quant_params
;
}
std
::
shared_ptr
<
std
::
vector
<
std
::
pair
<
float
,
float
>>>
capture_arguments
(
program
&
prog
,
const
target
&
t
)
{
std
::
vector
<
std
::
string
>
ins_names
=
{
"dot"
,
"convolution"
};
return
capture_arguments
(
prog
,
t
,
ins_names
);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/cpu/include/migraphx/cpu/target.hpp
View file @
41ed1924
...
...
@@ -17,7 +17,9 @@ struct target
migraphx
::
context
get_context
()
const
{
return
context
{};
}
argument
copy_to
(
const
argument
&
arg
)
const
{
return
std
::
move
(
arg
);
}
argument
copy_from
(
const
argument
&
arg
)
const
{
return
std
::
move
(
arg
);
}
argument
copy_from
(
const
argument
&
arg
)
const
{
return
arg
;
}
argument
allocate
(
const
shape
&
s
)
const
;
};
...
...
test/cpu_ops_test.cpp
View file @
41ed1924
...
...
@@ -2067,7 +2067,8 @@ TEST_CASE(op_capture)
p
.
add_instruction
(
migraphx
::
op
::
dot
{},
pa
,
ps
);
migraphx
::
program
capture_p
=
p
;
migraphx
::
capture_arguments
(
capture_p
);
migraphx
::
target
t
=
migraphx
::
cpu
::
target
{};
migraphx
::
capture_arguments
(
capture_p
,
t
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
capture_p
.
compile
(
migraphx
::
cpu
::
target
{});
...
...
test/type_conversion.cpp
View file @
41ed1924
...
...
@@ -248,7 +248,8 @@ TEST_CASE(op_capture)
{
auto
p
=
create_program_float
();
auto
op_capture_p
=
create_program_op
();
migraphx
::
capture_arguments
(
p
);
migraphx
::
target
t
=
migraphx
::
cpu
::
target
{};
migraphx
::
capture_arguments
(
p
,
t
);
EXPECT
(
p
==
op_capture_p
);
}
}
...
...
tools/include/target.hpp
View file @
41ed1924
...
...
@@ -72,7 +72,6 @@ argument target_allocate(rank<0>, T& x, const shape&)
{
std
::
string
name
=
x
.
name
();
MIGRAPHX_THROW
(
"Not computable: "
+
name
);
return
argument
{};
}
template
<
class
T
>
...
...
@@ -88,12 +87,9 @@ auto copy_to_target(rank<1>, T& x, const argument& arg) -> decltype(x.copy_to(ar
}
template
<
class
T
>
argument
copy_to_target
(
rank
<
0
>
,
T
&
x
,
const
argument
&
)
argument
copy_to_target
(
rank
<
0
>
,
T
&
,
const
argument
&
arg
)
{
std
::
string
name
=
x
.
name
();
MIGRAPHX_THROW
(
"Not computable: "
+
name
);
return
argument
{};
return
arg
;
}
template
<
class
T
>
...
...
@@ -109,11 +105,9 @@ auto copy_from_target(rank<1>, T& x, const argument& arg) -> decltype(x.copy_fro
}
template
<
class
T
>
argument
copy_from_target
(
rank
<
0
>
,
T
&
x
,
const
argument
&
)
argument
copy_from_target
(
rank
<
0
>
,
T
&
,
const
argument
&
arg
)
{
std
::
string
name
=
x
.
name
();
MIGRAPHX_THROW
(
"Not computable: "
+
name
);
return
argument
{};
return
arg
;
}
template
<
class
T
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment