Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
82fee1e7
Commit
82fee1e7
authored
May 15, 2019
by
Shucai Xiao
Browse files
temp code backup
parent
420d2363
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
59 additions
and
35 deletions
+59
-35
src/include/migraphx/op/binary.hpp
src/include/migraphx/op/binary.hpp
+17
-11
src/include/migraphx/op/convert.hpp
src/include/migraphx/op/convert.hpp
+3
-1
src/include/migraphx/op/unary.hpp
src/include/migraphx/op/unary.hpp
+21
-16
src/quantization.cpp
src/quantization.cpp
+18
-7
No files found.
src/include/migraphx/op/binary.hpp
View file @
82fee1e7
...
@@ -28,23 +28,29 @@ struct binary : op_name<Derived>
...
@@ -28,23 +28,29 @@ struct binary : op_name<Derived>
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
output_shape
};
argument
result
{
output_shape
};
visit_all
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
auto
s1
=
args
[
0
].
get_shape
();
if
(
input1
.
get_shape
().
packed
()
and
input2
.
get_shape
().
packed
())
auto
s2
=
args
[
1
].
get_shape
();
{
if
(
s1
==
s2
and
s1
.
packed
())
std
::
transform
(
input1
.
begin
(),
{
input1
.
end
(),
shape
std_shape
{
s1
.
type
(),
s1
.
lens
()};
input2
.
begin
(),
auto
input1
=
make_view
(
std_shape
,
args
[
0
].
data
());
output
.
begin
(),
auto
input2
=
make_view
(
std_shape
,
args
[
1
].
data
());
static_cast
<
const
Derived
&>
(
*
this
).
apply
());
auto
output
=
make_view
(
std_shape
,
result
.
data
());
std
::
transform
(
input1
.
begin
(),
input1
.
end
(),
input2
.
begin
(),
output
.
begin
(),
static_cast
<
const
Derived
&>
(
*
this
).
apply
());
}
}
else
else
{
{
visit_all
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
output
(
idx
.
begin
(),
idx
.
end
())
=
static_cast
<
const
Derived
&>
(
*
this
).
apply
()(
output
(
idx
.
begin
(),
idx
.
end
())
=
static_cast
<
const
Derived
&>
(
*
this
).
apply
()(
input1
(
idx
.
begin
(),
idx
.
end
()),
input2
(
idx
.
begin
(),
idx
.
end
()));
input1
(
idx
.
begin
(),
idx
.
end
()),
input2
(
idx
.
begin
(),
idx
.
end
()));
});
});
}
}
);
}
);
}
return
result
;
return
result
;
}
}
...
...
src/include/migraphx/op/convert.hpp
View file @
82fee1e7
...
@@ -26,7 +26,9 @@ struct convert : unary<convert>
...
@@ -26,7 +26,9 @@ struct convert : unary<convert>
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
{
{
return
pack
(
f
(
self
.
target_type
,
"target_type"
));
return
pack
(
f
(
self
.
target_type
,
"target_type"
),
f
(
self
.
scale
,
"scale"
),
f
(
self
.
shift
,
"shift"
));
}
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
...
...
src/include/migraphx/op/unary.hpp
View file @
82fee1e7
...
@@ -27,26 +27,31 @@ struct unary : op_name<Derived>
...
@@ -27,26 +27,31 @@ struct unary : op_name<Derived>
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
output_shape
};
argument
result
{
output_shape
};
result
.
visit
([
&
](
auto
output
)
{
auto
in_shape
=
args
[
0
].
get_shape
();
args
[
0
].
visit
([
&
](
auto
input
)
{
if
(
in_shape
.
packed
())
if
(
input
.
get_shape
().
packed
())
{
{
shape
std_in_shape
{
in_shape
.
type
(),
in_shape
.
lens
()};
std
::
transform
(
input
.
begin
(),
shape
std_out_shape
{
output_shape
.
type
(),
output_shape
.
lens
()};
input
.
end
(),
auto
input
=
make_view
(
std_in_shape
,
args
[
0
].
cast
());
output
.
begin
(),
auto
output
=
make_view
(
std_out_shape
,
result
.
cast
());
static_cast
<
const
Derived
&>
(
*
this
).
apply
());
std
::
transform
(
input
.
begin
(),
input
.
end
(),
output
.
begin
(),
static_cast
<
const
Derived
&>
(
*
this
).
apply
());
}
else
{
result
.
visit
([
&
](
auto
output
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
output
(
idx
.
begin
(),
idx
.
end
())
=
static_cast
<
const
Derived
&>
(
*
this
).
apply
()(
input
(
idx
.
begin
(),
idx
.
end
()));
});
return
result
;
return
result
;
}
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
output
(
idx
.
begin
(),
idx
.
end
())
=
static_cast
<
const
Derived
&>
(
*
this
).
apply
()(
input
(
idx
.
begin
(),
idx
.
end
()));
});
});
return
result
;
});
});
}
);
}
return
result
;
return
result
;
}
}
...
...
src/quantization.cpp
View file @
82fee1e7
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/convert.hpp>
#include <migraphx/op/convert.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_convolution.hpp>
...
@@ -197,7 +198,17 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
...
@@ -197,7 +198,17 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
}
}
auto
op
=
ins
->
get_operator
();
auto
op
=
ins
->
get_operator
();
auto
ins_shape
=
compute_shape
(
op
,
converted_inputs
);
shape
ins_shape
{};
// just to compute the output shape
if
(
ins
->
name
()
==
"dot"
)
{
ins_shape
=
compute_shape
(
op
::
quant_dot
{},
converted_inputs
);
}
else
{
ins_shape
=
compute_shape
(
op
::
quant_convolution
{},
converted_inputs
);
}
if
(
ins_shape
.
type
()
!=
orig_type
)
if
(
ins_shape
.
type
()
!=
orig_type
)
{
{
// check the dead code case to avoid assert
// check the dead code case to avoid assert
...
@@ -239,17 +250,17 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
...
@@ -239,17 +250,17 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
ins
,
ins
,
op
::
quant_convolution
{
padding
,
stride
,
dilation
,
padding_mode
,
group
},
op
::
quant_convolution
{
padding
,
stride
,
dilation
,
padding_mode
,
group
},
converted_inputs
);
converted_inputs
);
auto
conv_lens
=
conv_res
->
get_shape
().
lens
();
auto
conv_s
=
conv_res
->
get_shape
();
auto
fl
=
prog
.
add_literal
(
literal
(
adjust_factor
));
std
::
vector
<
float
>
vec_fact
(
conv_s
.
elements
(),
adjust_factor
);
auto
adj_fact
=
prog
.
insert_instruction
(
ins
,
op
::
multibroadcast
{
conv_lens
},
fl
);
prog
.
replace_instruction
(
ins
,
adj_fact
);
auto
fl
=
prog
.
add_literal
(
literal
{
conv_s
,
vec_fact
});
auto
ad_res
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
conv_res
,
fl
);
prog
.
replace_instruction
(
ins
,
ad_res
);
}
}
else
else
{
{
MIGRAPHX_THROW
(
"INT8_QUANTIZE: does not support operator"
+
ins
->
name
());
MIGRAPHX_THROW
(
"INT8_QUANTIZE: does not support operator"
+
ins
->
name
());
}
}
prog
.
replace_instruction
(
ins
,
op
,
converted_inputs
);
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment