Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7e297b13
Commit
7e297b13
authored
Jun 13, 2022
by
Paul
Browse files
Merge
parents
86ea5e91
aa7ff911
Changes
765
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
345 additions
and
118 deletions
+345
-118
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+14
-1
src/include/migraphx/optional.hpp
src/include/migraphx/optional.hpp
+4
-1
src/include/migraphx/par_dfor.hpp
src/include/migraphx/par_dfor.hpp
+0
-1
src/include/migraphx/par_for.hpp
src/include/migraphx/par_for.hpp
+2
-2
src/include/migraphx/pass.hpp
src/include/migraphx/pass.hpp
+56
-23
src/include/migraphx/pass_manager.hpp
src/include/migraphx/pass_manager.hpp
+15
-11
src/include/migraphx/program.hpp
src/include/migraphx/program.hpp
+9
-1
src/include/migraphx/propagate_constant.hpp
src/include/migraphx/propagate_constant.hpp
+1
-1
src/include/migraphx/quantization.hpp
src/include/migraphx/quantization.hpp
+0
-22
src/include/migraphx/quantize_fp16.hpp
src/include/migraphx/quantize_fp16.hpp
+27
-0
src/include/migraphx/quantize_int8.hpp
src/include/migraphx/quantize_int8.hpp
+42
-0
src/include/migraphx/raw_data.hpp
src/include/migraphx/raw_data.hpp
+1
-2
src/include/migraphx/rewrite_batchnorm.hpp
src/include/migraphx/rewrite_batchnorm.hpp
+1
-1
src/include/migraphx/rewrite_pooling.hpp
src/include/migraphx/rewrite_pooling.hpp
+1
-1
src/include/migraphx/rewrite_rnn.hpp
src/include/migraphx/rewrite_rnn.hpp
+12
-13
src/include/migraphx/run_loop.hpp
src/include/migraphx/run_loop.hpp
+115
-0
src/include/migraphx/schedule.hpp
src/include/migraphx/schedule.hpp
+1
-1
src/include/migraphx/schedule_model.hpp
src/include/migraphx/schedule_model.hpp
+37
-31
src/include/migraphx/serialize.hpp
src/include/migraphx/serialize.hpp
+0
-1
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+7
-5
No files found.
src/include/migraphx/operators.hpp
View file @
7e297b13
...
...
@@ -35,12 +35,14 @@
#include <migraphx/op/flatten.hpp>
#include <migraphx/op/floor.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/op/gathernd.hpp>
#include <migraphx/op/get_tuple_elem.hpp>
#include <migraphx/op/greater.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/if_op.hpp>
#include <migraphx/op/im2col.hpp>
#include <migraphx/op/isnan.hpp>
#include <migraphx/op/leaky_relu.hpp>
#include <migraphx/op/less.hpp>
#include <migraphx/op/load.hpp>
...
...
@@ -49,6 +51,7 @@
#include <migraphx/op/logical_or.hpp>
#include <migraphx/op/logical_xor.hpp>
#include <migraphx/op/logsoftmax.hpp>
#include <migraphx/op/loop.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/max.hpp>
...
...
@@ -56,6 +59,8 @@
#include <migraphx/op/mul.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/neg.hpp>
#include <migraphx/op/nonmaxsuppression.hpp>
#include <migraphx/op/nonzero.hpp>
#include <migraphx/op/outline.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
...
...
@@ -78,10 +83,16 @@
#include <migraphx/op/rnn_last_hs_output.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/roialign.hpp>
#include <migraphx/op/round.hpp>
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/scatter.hpp>
#include <migraphx/op/scatter_add.hpp>
#include <migraphx/op/scatter_mul.hpp>
#include <migraphx/op/scatter_none.hpp>
#include <migraphx/op/scatternd_add.hpp>
#include <migraphx/op/scatternd_none.hpp>
#include <migraphx/op/scatternd_mul.hpp>
#include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sign.hpp>
#include <migraphx/op/sinh.hpp>
...
...
@@ -95,11 +106,13 @@
#include <migraphx/op/sub.hpp>
#include <migraphx/op/tanh.hpp>
#include <migraphx/op/tan.hpp>
#include <migraphx/op/topk.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/unary.hpp>
#include <migraphx/op/unary_not.hpp>
#include <migraphx/op/undefined.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/op/unsqueeze.hpp>
#include <migraphx/op/where.hpp>
#endif
src/include/migraphx/optional.hpp
View file @
7e297b13
...
...
@@ -3,7 +3,10 @@
#include <migraphx/config.hpp>
#if defined(__has_include) && !defined(CPPCHECK)
#if defined(CPPCHECK)
#define MIGRAPHX_HAS_OPTIONAL 1
#define MIGRAPHX_HAS_OPTIONAL_TS 1
#elif defined(__has_include)
#if __has_include(<optional>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_OPTIONAL 1
#else
...
...
src/include/migraphx/par_dfor.hpp
View file @
7e297b13
...
...
@@ -41,7 +41,6 @@ auto par_dfor(Ts... xs)
{
dfor
(
xs
...)(
f
);
}
};
}
...
...
src/include/migraphx/par_for.hpp
View file @
7e297b13
...
...
@@ -78,8 +78,8 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f)
template
<
class
F
>
void
par_for
(
std
::
size_t
n
,
std
::
size_t
min_grain
,
F
f
)
{
const
auto
threadsize
=
std
::
m
in
<
std
::
size_t
>
(
std
::
thread
::
hardware_concurrency
(),
n
/
min_grain
);
const
auto
threadsize
=
std
::
min
<
std
::
size_t
>
(
std
::
thread
::
hardware_concurrency
(),
n
/
std
::
m
ax
<
std
::
size_t
>
(
1
,
min_grain
)
)
;
par_for_impl
(
n
,
threadsize
,
f
);
}
...
...
src/include/migraphx/pass.hpp
View file @
7e297b13
...
...
@@ -8,12 +8,14 @@
#include <utility>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/rank.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
program
;
struct
module
;
struct
module_pass_manager
;
#ifdef DOXYGEN
...
...
@@ -24,6 +26,7 @@ struct pass
/// A unique name used to identify the pass
std
::
string
name
()
const
;
/// Run the pass on the module
void
apply
(
module_pass_manager
&
mpm
)
const
;
void
apply
(
module
&
m
)
const
;
/// Run the pass on the program
void
apply
(
program
&
p
)
const
;
...
...
@@ -31,17 +34,44 @@ struct pass
#else
/*
* Type-erased interface for:
*
* struct pass
* {
* std::string name() const;
* void apply(module & m) const;
* void apply(program & p) const;
* };
*
*/
module
&
get_module
(
module_pass_manager
&
mpm
);
namespace
detail
{
template
<
class
T
>
auto
module_pass_manager_apply
(
rank
<
1
>
,
const
T
&
x
,
module_pass_manager
&
mpm
)
->
decltype
(
x
.
apply
(
get_module
(
mpm
)))
{
return
x
.
apply
(
get_module
(
mpm
));
}
template
<
class
T
>
void
module_pass_manager_apply
(
rank
<
0
>
,
const
T
&
,
module_pass_manager
&
)
{
}
template
<
class
T
>
void
module_pass_manager_apply
(
const
T
&
x
,
module_pass_manager
&
mpm
)
{
module_pass_manager_apply
(
rank
<
1
>
{},
x
,
mpm
);
}
}
// namespace detail
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct
pass
{
//
std
::
string
name
()
const
;
// (optional)
void
apply
(
module_pass_manager
&
mpm
)
const
;
// (optional)
void
apply
(
program
&
p
)
const
;
};
#else
struct
pass
{
...
...
@@ -112,10 +142,10 @@ struct pass
return
(
*
this
).
private_detail_te_get_handle
().
name
();
}
void
apply
(
module
&
m
)
const
void
apply
(
module
_pass_manager
&
mp
m
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
(
*
this
).
private_detail_te_get_handle
().
apply
(
m
);
(
*
this
).
private_detail_te_get_handle
().
apply
(
m
pm
);
}
void
apply
(
program
&
p
)
const
...
...
@@ -137,22 +167,24 @@ struct pass
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
void
apply
(
module
&
m
)
const
=
0
;
virtual
void
apply
(
program
&
p
)
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
void
apply
(
module
_pass_manager
&
mp
m
)
const
=
0
;
virtual
void
apply
(
program
&
p
)
const
=
0
;
};
template
<
class
T
>
static
auto
private_detail_te_default_apply
(
char
,
T
&&
private_detail_te_self
,
module
&
m
)
->
decltype
(
private_detail_te_self
.
apply
(
m
))
static
auto
private_detail_te_default_apply
(
char
,
T
&&
private_detail_te_self
,
module_pass_manager
&
mpm
)
->
decltype
(
private_detail_te_self
.
apply
(
mpm
))
{
private_detail_te_self
.
apply
(
m
);
private_detail_te_self
.
apply
(
m
pm
);
}
template
<
class
T
>
static
void
private_detail_te_default_apply
(
float
,
T
&&
private_detail_te_self
,
module
&
m
)
static
void
private_detail_te_default_apply
(
float
,
T
&&
private_detail_te_self
,
module_pass_manager
&
mpm
)
{
migraphx
::
nop
(
private_detail_te_self
,
m
);
migraphx
::
detail
::
module_pass_manager_apply
(
private_detail_te_self
,
mp
m
);
}
template
<
class
T
>
...
...
@@ -198,10 +230,10 @@ struct pass
std
::
string
name
()
const
override
{
return
private_detail_te_value
.
name
();
}
void
apply
(
module
&
m
)
const
override
void
apply
(
module
_pass_manager
&
mp
m
)
const
override
{
private_detail_te_default_apply
(
char
(
0
),
private_detail_te_value
,
m
);
private_detail_te_default_apply
(
char
(
0
),
private_detail_te_value
,
mp
m
);
}
void
apply
(
program
&
p
)
const
override
...
...
@@ -274,6 +306,7 @@ inline const ValueType& any_cast(const pass& x)
throw
std
::
bad_cast
();
return
*
y
;
}
#endif
#endif
...
...
src/include/migraphx/pass_manager.hpp
100644 → 100755
View file @
7e297b13
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_PASS_MANAGER_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_PASS_MANAGER_HPP
#include <list>
#include <unordered_map>
#include <migraphx/operation.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/target.hpp>
#include <migraphx/tracer.hpp>
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
#include <algorithm>
#include <iostream>
#include <migraphx/pass.hpp>
#include <migraphx/tracer.hpp>
#include <vector>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
module_pass_manager
{
module_pass_manager
()
=
default
;
module_pass_manager
(
const
module_pass_manager
&
)
=
delete
;
virtual
module
&
get_module
()
=
0
;
virtual
module
*
create_module
(
const
std
::
string
&
name
)
=
0
;
virtual
void
run_pass
(
const
pass
&
p
)
=
0
;
protected:
virtual
~
module_pass_manager
()
{}
};
void
run_passes
(
module
&
mod
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
=
tracer
{});
void
run_passes
(
program
&
prog
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
=
tracer
{});
...
...
src/include/migraphx/program.hpp
View file @
7e297b13
...
...
@@ -23,6 +23,8 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL)
struct
program_impl
;
struct
marker
;
/**
* @brief Stores the instruction stream
*/
...
...
@@ -65,7 +67,10 @@ struct program
void
finalize
();
void
perf_report
(
std
::
ostream
&
os
,
std
::
size_t
n
,
parameter_map
params
)
const
;
void
perf_report
(
std
::
ostream
&
os
,
std
::
size_t
n
,
parameter_map
params
,
std
::
size_t
batch
=
1
)
const
;
void
mark
(
const
parameter_map
&
params
,
marker
&&
m
);
value
to_value
()
const
;
void
from_value
(
const
value
&
v
);
...
...
@@ -76,6 +81,9 @@ struct program
const
std
::
function
<
void
(
instruction_ref
,
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
)
>&
print_func
)
const
;
void
print
(
const
std
::
function
<
void
(
instruction_ref
ins
,
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
)
>&
print_func
)
const
;
void
print_graph
(
std
::
ostream
&
os
,
bool
brief
=
false
)
const
;
void
print_cpp
(
std
::
ostream
&
os
)
const
;
...
...
src/include/migraphx/propagate_constant.hpp
View file @
7e297b13
...
...
@@ -15,7 +15,7 @@ struct module;
struct
propagate_constant
{
std
::
string
name
()
const
{
return
"propagate_constant"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/quantization.hpp
View file @
7e297b13
...
...
@@ -17,32 +17,10 @@ struct program;
void
quantize_fp16
(
program
&
prog
,
const
std
::
vector
<
std
::
string
>&
ins_names
=
{
"all"
});
// insert the capture operator for the inputs of each operator to be quantized
// to int8
std
::
size_t
capture_arguments
(
program
&
prog
,
const
std
::
vector
<
std
::
string
>&
ins_names
,
const
std
::
function
<
void
(
std
::
size_t
,
std
::
vector
<
argument
>
)
>&
func
);
std
::
shared_ptr
<
std
::
vector
<
std
::
pair
<
float
,
float
>>>
capture_arguments_impl
(
program
&
prog
,
const
target
&
t
,
const
std
::
vector
<
std
::
string
>&
ins_names
);
template
<
class
T
>
std
::
shared_ptr
<
std
::
vector
<
std
::
pair
<
float
,
float
>>>
capture_arguments
(
program
&
prog
,
T
&&
t
,
const
std
::
vector
<
std
::
string
>&
ins_names
)
{
static_assert
(
std
::
is_same
<
std
::
remove_cv_t
<
std
::
remove_reference_t
<
T
>>
,
target
>
{}
&&
std
::
is_lvalue_reference
<
T
>
{},
"Dangling reference to target!"
);
return
capture_arguments_impl
(
prog
,
t
,
ins_names
);
}
void
quantize_int8
(
program
&
prog
,
const
target
&
t
,
const
std
::
vector
<
parameter_map
>&
calibration
,
const
std
::
vector
<
std
::
string
>&
ins_names
=
{
"dot"
,
"convolution"
});
void
quantize_int8_impl
(
program
&
prog
,
const
std
::
vector
<
std
::
pair
<
float
,
float
>>&
quant_params
,
const
std
::
vector
<
std
::
string
>&
ins_names
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/
remap
.hpp
→
src/include/migraphx/
quantize_fp16
.hpp
View file @
7e297b13
#ifndef MIGRAPHX_GUARD_RTGLIB_
REMAP
_HPP
#define MIGRAPHX_GUARD_RTGLIB_
REMAP
_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_
QUANTIZE_FP16
_HPP
#define MIGRAPHX_GUARD_RTGLIB_
QUANTIZE_FP16
_HPP
#include <string>
#include <
migraphx/instruction_ref.hpp
>
#include <
vector
>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
program
;
struct
module
;
/**
*
Decompose operators.
*
quantize a program to fp16
*/
struct
remap
struct
quantize_fp16_pass
{
std
::
string
name
()
const
{
return
"remap"
;
}
void
apply
(
module
&
p
)
const
;
std
::
vector
<
std
::
string
>
ins_names
=
{
"all"
};
std
::
string
name
()
const
{
return
"quantize_fp16"
;
}
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/quantize_int8.hpp
0 → 100644
View file @
7e297b13
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT8_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT8_HPP
#include <string>
#include <vector>
#include <functional>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
program
;
struct
module
;
/**
* capture inputs of operators to be quantized to int8
*/
struct
capture_arguments_pass
{
std
::
vector
<
std
::
string
>
ins_names
=
{
"dot"
,
"convolution"
};
std
::
function
<
void
(
std
::
size_t
,
std
::
vector
<
argument
>
)
>
f
{};
std
::
size_t
*
param_index
=
nullptr
;
std
::
string
name
()
const
{
return
"capture_arguments"
;
}
void
apply
(
module
&
m
)
const
;
};
/**
* quantize a program to int8
*/
struct
quantize_int8_pass
{
std
::
vector
<
std
::
string
>
ins_names
=
{
"dot"
,
"convolution"
};
std
::
vector
<
std
::
pair
<
float
,
float
>>
quant_params
;
std
::
string
name
()
const
{
return
"quantize_int8"
;
}
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/raw_data.hpp
View file @
7e297b13
...
...
@@ -207,8 +207,7 @@ auto visit_all_pack(const shape& s, V1&& v1)
template
<
class
T
,
class
...
Ts
>
auto
visit_all
(
T
&&
x
,
Ts
&&
...
xs
)
{
auto
&&
s
=
x
.
get_shape
();
// cppcheck-suppress redundantInitialization
auto
&&
s
=
x
.
get_shape
();
std
::
initializer_list
<
shape
::
type_t
>
types
=
{
xs
.
get_shape
().
type
()...};
if
(
!
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
MIGRAPHX_THROW
(
"Types must be the same"
);
...
...
src/include/migraphx/rewrite_batchnorm.hpp
View file @
7e297b13
...
...
@@ -16,7 +16,7 @@ struct module;
struct
rewrite_batchnorm
{
std
::
string
name
()
const
{
return
"rewrite_batchnorm"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/rewrite_pooling.hpp
View file @
7e297b13
...
...
@@ -15,7 +15,7 @@ struct module;
struct
rewrite_pooling
{
std
::
string
name
()
const
{
return
"rewrite_pooling"
;
}
void
apply
(
module
&
prog
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/rewrite_rnn.hpp
View file @
7e297b13
...
...
@@ -19,22 +19,22 @@ struct module;
struct
rewrite_rnn
{
std
::
string
name
()
const
{
return
"rewrite_rnn"
;
}
void
apply
(
module
&
prog
)
const
;
void
apply
(
module
&
m
)
const
;
private:
// for vanilla rnn operators
void
apply_vanilla_rnn
(
module
&
prog
,
instruction_ref
ins
)
const
;
void
apply_vanilla_rnn
(
module
&
m
,
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
vanilla_rnn_cell
(
bool
is_forward
,
module
&
prog
,
module
&
m
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
operation
&
actv_func
)
const
;
std
::
vector
<
operation
>
vanilla_rnn_actv_funcs
(
instruction_ref
ins
)
const
;
// for gru operators
void
apply_gru
(
module
&
prog
,
instruction_ref
ins
)
const
;
void
apply_gru
(
module
&
m
,
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
gru_cell
(
bool
is_forward
,
module
&
prog
,
module
&
m
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
int
linear_before_reset
,
...
...
@@ -44,9 +44,9 @@ struct rewrite_rnn
std
::
vector
<
operation
>
gru_actv_funcs
(
instruction_ref
ins
)
const
;
// for lstm operators
void
apply_lstm
(
module
&
prog
,
instruction_ref
ins
)
const
;
void
apply_lstm
(
module
&
m
,
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
lstm_cell
(
bool
is_forward
,
module
&
prog
,
module
&
m
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
const
operation
&
actv_func1
,
...
...
@@ -55,24 +55,23 @@ struct rewrite_rnn
std
::
vector
<
operation
>
lstm_actv_funcs
(
instruction_ref
ins
)
const
;
bool
is_variable_seq_lens
(
const
module
&
prog
,
instruction_ref
seq_lens
)
const
;
instruction_ref
replace_last_hs_output
(
module
&
prog
,
bool
is_variable_seq_lens
(
const
module
&
m
,
instruction_ref
seq_lens
)
const
;
instruction_ref
replace_last_hs_output
(
module
&
m
,
instruction_ref
ins
,
instruction_ref
seq_lens
,
instruction_ref
last_hs_output
,
op
::
rnn_direction
dirct
)
const
;
void
replace_last_cell_output
(
module
&
prog
,
void
replace_last_cell_output
(
module
&
m
,
instruction_ref
ins
,
instruction_ref
seq_lens
,
instruction_ref
cell_outputs
,
instruction_ref
last_cell_output
,
op
::
rnn_direction
dirct
)
const
;
std
::
size_t
get_seq_len
(
const
module
&
prog
,
instruction_ref
input
,
instruction_ref
seq_lens
)
const
;
std
::
size_t
get_seq_len
(
const
module
&
m
,
instruction_ref
input
,
instruction_ref
seq_lens
)
const
;
instruction_ref
pad_hidden_states
(
module
&
prog
,
instruction_ref
pad_hidden_states
(
module
&
m
,
instruction_ref
seq
,
instruction_ref
seq_lens
,
instruction_ref
hs
)
const
;
...
...
src/include/migraphx/run_loop.hpp
0 → 100644
View file @
7e297b13
#ifndef MIGRAPHX_GUARD_RTGLIB_RUN_LOOP_HPP
#define MIGRAPHX_GUARD_RTGLIB_RUN_LOOP_HPP
#include <migraphx/instruction_ref.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/context.hpp>
#include <migraphx/module.hpp>
#include <migraphx/config.hpp>
#include <migraphx/ranges.hpp>
#include <string>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
LoopModel
,
class
T
>
argument
run_loop
(
const
LoopModel
&
model
,
T
&
ctx
,
std
::
vector
<
argument
>
args
,
const
std
::
vector
<
module_ref
>&
mods
,
const
std
::
function
<
std
::
vector
<
argument
>
(
module_ref
&
,
const
std
::
unordered_map
<
std
::
string
,
argument
>&
)
>&
run
)
{
std
::
vector
<
std
::
vector
<
argument
>>
results
;
// process argu lists
auto
iter_num
=
args
.
at
(
0
).
at
<
int64_t
>
();
auto
cond
=
args
.
at
(
1
).
at
<
bool
>
();
auto
input_num
=
(
args
.
size
()
-
2
)
/
2
;
auto
dep_num
=
input_num
-
2
;
module_ref
mod
=
mods
.
at
(
0
);
auto
param_name_shapes
=
mod
->
get_parameter_shapes
();
auto
param_names
=
mod
->
get_parameter_names
();
std
::
vector
<
argument
>
dep0
(
args
.
begin
()
+
input_num
+
1
,
args
.
begin
()
+
2
*
input_num
);
std
::
vector
<
argument
>
dep1
(
args
.
begin
()
+
2
*
input_num
,
args
.
begin
()
+
2
*
input_num
+
1
);
auto
ins_outputs
=
args
.
back
().
get_sub_objects
();
dep1
.
insert
(
dep1
.
end
(),
ins_outputs
.
begin
(),
ins_outputs
.
begin
()
+
dep_num
);
std
::
array
<
std
::
vector
<
argument
>
,
2
>
loop_carry_deps
=
{
dep0
,
dep1
};
// loop iter argument
std
::
vector
<
argument
>
in_args
=
{
args
.
at
(
input_num
),
dep1
.
at
(
0
)};
in_args
.
insert
(
in_args
.
end
(),
args
.
begin
()
+
2
,
args
.
begin
()
+
input_num
);
std
::
vector
<
argument
>
out_args
=
dep0
;
out_args
.
insert
(
out_args
.
end
(),
ins_outputs
.
begin
()
+
dep_num
,
ins_outputs
.
end
());
std
::
vector
<
argument
>
scan_outputs
(
ins_outputs
.
begin
()
+
dep_num
,
ins_outputs
.
end
());
auto
out_param_indices
=
model
.
get_output_params
(
*
mod
);
int64_t
iter
=
0
;
for
(
iter
=
0
;
iter
<
iter_num
and
cond
;
++
iter
)
{
// copy iter num and cond to device memory
model
.
copy
(
ctx
,
iter
,
in_args
.
at
(
0
));
model
.
copy
(
ctx
,
cond
,
in_args
.
at
(
1
));
// wrap up the inputs and outputs
std
::
unordered_map
<
std
::
string
,
argument
>
params
;
int
input_index
=
0
;
for
(
const
auto
&
name
:
param_names
)
{
auto
ps
=
mod
->
get_parameter_shape
(
name
);
if
(
ps
==
shape
{})
{
continue
;
}
// it is an input parameter
if
(
not
contains
(
out_param_indices
,
name
))
{
params
[
name
]
=
in_args
.
at
(
input_index
++
);
}
else
{
auto
output_index
=
out_param_indices
[
name
];
if
(
output_index
>
dep_num
)
{
const
auto
&
arg
=
out_args
.
at
(
output_index
);
assert
((
iter
+
1
)
*
ps
.
bytes
()
<=
arg
.
get_shape
().
bytes
());
params
[
name
]
=
argument
(
ps
,
arg
.
data
()
+
iter
*
ps
.
bytes
());
}
else
{
params
[
name
]
=
out_args
.
at
(
output_index
);
}
}
}
auto
mod_args
=
run
(
mod
,
params
);
// copy back cond to be used next iteration
model
.
copy
(
ctx
,
mod_args
.
at
(
0
),
cond
);
// mod outputs are used as next loop input
std
::
copy
(
mod_args
.
begin
(),
mod_args
.
begin
()
+
dep_num
+
1
,
in_args
.
begin
()
+
1
);
const
auto
&
dep_out
=
loop_carry_deps
[(
iter
+
1
)
%
2
];
std
::
copy
(
dep_out
.
begin
(),
dep_out
.
end
(),
out_args
.
begin
());
std
::
vector
<
argument
>
mod_scan_outs
(
mod_args
.
begin
()
+
1
+
dep_num
,
mod_args
.
end
());
model
.
append
(
mod_scan_outs
,
scan_outputs
,
iter
);
}
out_args
.
erase
(
out_args
.
begin
());
std
::
copy
(
in_args
.
begin
()
+
2
,
in_args
.
end
(),
out_args
.
begin
());
model
.
set_zero
(
ctx
,
scan_outputs
,
iter
);
return
{
out_args
};
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/schedule.hpp
View file @
7e297b13
...
...
@@ -19,7 +19,7 @@ struct schedule
schedule_model
model
{};
bool
enable
=
true
;
std
::
string
name
()
const
{
return
"schedule"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/schedule_model.hpp
View file @
7e297b13
...
...
@@ -26,30 +26,35 @@ struct schedule_model
/// Get the number of concurrent instruction allowed
std
::
size_t
concurrency
()
const
;
/// Schedule a concurrent instruction
void
sched
(
module
&
p
,
instruction_ref
ins
,
std
::
size_t
n
)
const
;
void
sched
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
n
)
const
;
// Insert necessary waits before an instruction
void
wait
(
module
&
p
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
;
void
wait
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
;
// Insert necessary records after an instruction
void
record
(
module
&
p
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
;
void
record
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
;
/// Compute weights for an operation
std
::
size_t
weight
(
const
operation
&
op
)
const
;
};
#else
/*
* Type-erased interface for:
*
* struct schedule_model
* {
* std::size_t concurrency() const;
* void sched(module& p,instruction_ref ins,std::size_t n) const;
* void wait(module& p,instruction_ref ins,std::size_t wait_id) const;
* void record(module& p,instruction_ref ins,std::size_t wait_id) const;
* std::size_t weight(const operation& op) const;
* };
*
*/
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct
schedule_model
{
//
std
::
size_t
concurrency
()
const
;
//
void
sched
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
n
)
const
;
//
void
wait
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
;
//
void
record
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
;
//
std
::
size_t
weight
(
const
operation
&
op
)
const
;
};
#else
struct
schedule_model
{
...
...
@@ -120,22 +125,22 @@ struct schedule_model
return
(
*
this
).
private_detail_te_get_handle
().
concurrency
();
}
void
sched
(
module
&
p
,
instruction_ref
ins
,
std
::
size_t
n
)
const
void
sched
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
n
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
(
*
this
).
private_detail_te_get_handle
().
sched
(
p
,
ins
,
n
);
(
*
this
).
private_detail_te_get_handle
().
sched
(
m
,
ins
,
n
);
}
void
wait
(
module
&
p
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
void
wait
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
(
*
this
).
private_detail_te_get_handle
().
wait
(
p
,
ins
,
wait_id
);
(
*
this
).
private_detail_te_get_handle
().
wait
(
m
,
ins
,
wait_id
);
}
void
record
(
module
&
p
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
void
record
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
(
*
this
).
private_detail_te_get_handle
().
record
(
p
,
ins
,
wait_id
);
(
*
this
).
private_detail_te_get_handle
().
record
(
m
,
ins
,
wait_id
);
}
std
::
size_t
weight
(
const
operation
&
op
)
const
...
...
@@ -159,9 +164,9 @@ struct schedule_model
virtual
const
std
::
type_info
&
type
()
const
=
0
;
virtual
std
::
size_t
concurrency
()
const
=
0
;
virtual
void
sched
(
module
&
p
,
instruction_ref
ins
,
std
::
size_t
n
)
const
=
0
;
virtual
void
wait
(
module
&
p
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
=
0
;
virtual
void
record
(
module
&
p
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
=
0
;
virtual
void
sched
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
n
)
const
=
0
;
virtual
void
wait
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
=
0
;
virtual
void
record
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
=
0
;
virtual
std
::
size_t
weight
(
const
operation
&
op
)
const
=
0
;
};
...
...
@@ -195,22 +200,22 @@ struct schedule_model
std
::
size_t
concurrency
()
const
override
{
return
private_detail_te_value
.
concurrency
();
}
void
sched
(
module
&
p
,
instruction_ref
ins
,
std
::
size_t
n
)
const
override
void
sched
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
n
)
const
override
{
private_detail_te_value
.
sched
(
p
,
ins
,
n
);
private_detail_te_value
.
sched
(
m
,
ins
,
n
);
}
void
wait
(
module
&
p
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
override
void
wait
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
override
{
private_detail_te_value
.
wait
(
p
,
ins
,
wait_id
);
private_detail_te_value
.
wait
(
m
,
ins
,
wait_id
);
}
void
record
(
module
&
p
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
override
void
record
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
override
{
private_detail_te_value
.
record
(
p
,
ins
,
wait_id
);
private_detail_te_value
.
record
(
m
,
ins
,
wait_id
);
}
std
::
size_t
weight
(
const
operation
&
op
)
const
override
...
...
@@ -283,6 +288,7 @@ inline const ValueType& any_cast(const schedule_model& x)
throw
std
::
bad_cast
();
return
*
y
;
}
#endif
#endif
...
...
src/include/migraphx/serialize.hpp
View file @
7e297b13
...
...
@@ -50,7 +50,6 @@ auto to_value_impl(rank<2>, const T& x) -> decltype(x.begin(), x.end(), value{})
value
result
=
value
::
array
{};
for
(
auto
&&
y
:
x
)
{
auto
e
=
to_value
(
y
);
result
.
insert
(
to_value
(
y
));
}
return
result
;
...
...
src/include/migraphx/shape.hpp
100755 → 100644
View file @
7e297b13
...
...
@@ -35,7 +35,7 @@ struct shape
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t)
// clang-format on
// clang-format on
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
enum
type_t
...
...
@@ -131,6 +131,8 @@ struct shape
shape
with_lens
(
type_t
t
,
const
std
::
vector
<
std
::
size_t
>&
l
)
const
;
shape
with_lens
(
const
std
::
vector
<
std
::
size_t
>&
l
)
const
;
shape
with_type
(
type_t
t
)
const
;
friend
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
);
friend
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
&
x
);
...
...
@@ -186,8 +188,7 @@ struct shape
{
switch
(
t
)
{
case
tuple_type
:
{
case
tuple_type
:
{
tv
();
return
;
}
...
...
@@ -224,10 +225,11 @@ struct shape
const
std
::
vector
<
shape
>&
sub_shapes
()
const
;
std
::
size_t
element_space
()
const
;
private:
shape
(
std
::
shared_ptr
<
shape_impl
>
pimpl
);
std
::
shared_ptr
<
const
shape_impl
>
impl
;
std
::
size_t
element_space
()
const
;
};
void
migraphx_to_value
(
value
&
v
,
const
shape
&
s
);
...
...
Prev
1
…
7
8
9
10
11
12
13
14
15
…
39
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment