gaoqiong / MIGraphX

Commit 32b1f924
Authored Feb 03, 2023 by charlie
Parent: 0b0a6d4f

    progress

1 changed file with 26 additions and 64 deletions

src/include/migraphx/op/select_module.hpp  (+26, -64)
@@ -33,100 +33,62 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace op {
 
+// Make this work just for exact matches
+// can get rid of the other attributes and just check all the parameters are the same
+// GPU version of this might have to deal with output parameters
+// see loop op for how the output parameters are dealt with there
+// Can have multiple inputs but only one output?
 struct select_module
 {
     // output shape of the dynamic model
     shape output_dyn_shape;
-    int input_batch_index  = -1;
-    int output_batch_index = -1;
-    std::string dyn_batch_param_name;
 
     template <class Self, class F>
     static auto reflect(Self& self, F f)
     {
-        return pack(f(self.output_dyn_shape, "output_dyn_shape"),
-                    f(self.input_batch_index, "input_batch_index"),
-                    f(self.output_batch_index, "output_batch_index"),
-                    f(self.dyn_batch_param_name, "dyn_batch_param_name"));
+        return pack(f(self.output_dyn_shape, "output_dyn_shape"));
     }
 
     std::string name() const { return "select_module"; }
 
-    // runs once during model compilation with dynamic shape input
-    // may run on each model evaluation with static shape input
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this, true}.has(1);
-        auto s0 = inputs.at(0);
-        if(s0.dynamic())
-        {
-            // should we check that the submodules have the same parameters here?
-            // check that no more than one parameter is non-fixed?
-            // would need to use version of compute_shape with the parameter list
-            return shape{output_dyn_shape};
-        }
-        else
-        {
-            auto batch_size = s0.lens().at(input_batch_index);
-            auto dds        = output_dyn_shape.dyn_dims();
-            dds.at(output_batch_index) = {batch_size, batch_size};
-            std::vector<std::size_t> dims;
-            if(std::all_of(dds.begin(), dds.end(), [](auto dd) { return dd.is_fixed(); }))
-            {
-                std::transform(
-                    dds.begin(), dds.end(), std::back_inserter(dims), [](auto d) { return d.max; });
-                return {output_dyn_shape.type(), dims};
-            }
-            else
-            {
-                MIGRAPHX_THROW("SELECT_MODULE: more than one input dimension was non-fixed");
-            }
-        }
+        check_shapes{inputs, *this, true};
+        return shape{output_dyn_shape};
     }
 
-    argument compute(const dyn_output& dyn_out,
+    argument compute(const shape&,
                      const std::vector<argument>& args,
                      const std::vector<module_ref>& submodule_list,
                      const std::function<std::vector<argument>(
                          module_ref&, const std::unordered_map<std::string, argument>&)>& run) const
     {
-        std::vector<module_ref> modules_to_run;
-        for(const auto& mod : submodule_list)
-        {
-            // find submodule with the same parameter shape as the input data
-            auto p_shape = mod->get_parameter_shape(dyn_batch_param_name);
-            if(p_shape == args.at(0).get_shape())
-            {
-                modules_to_run.push_back(mod);
-                break;
-            }
-        }
-        // TODO if an exact match is not found, assemble module list from binary base
-        if(modules_to_run.empty())
+        // find submodule with parameter shapes exactly the same as the input arguments
+        // assuming arguments are in the same order as the parameters
+        auto module_to_run = std::find_if(
+            submodule_list.begin(), submodule_list.end(), [&](module_ref mr) {
+                auto param_names = mr.get_parameter_names();
+                std::equal(args.cbegin(), args.cend(), param_names.cbegin(), [&](auto a, auto p_name) {
+                    return a.get_shape() == mr.get_parameter_shape(p_name);
+                });
+            });
+        if(module_to_run == submodule_list.end())
         {
-            MIGRAPHX_THROW("SELECT_MODULE: no compatible submodules found for input shape: " +
-                           migraphx::to_string(args.at(0).get_shape()));
+            MIGRAPHX_THROW("SELECT_MODULE: no compatible submodules found for given input shapes");
        }
-        std::set<std::string> pnames;
-        for(const auto& mod : modules_to_run)
-        {
-            // TODO If all the modules have the same parameters, this would only need to run once
-            auto names = mod->get_parameter_names();
-            pnames.insert(names.begin(), names.end());
-        }
+        auto param_names = module_to_run.get_parameter_names();
         assert(pnames.size() <= args.size());
         std::unordered_map<std::string, argument> params;
-        std::transform(pnames.begin(),
-                       pnames.end(),
+        std::transform(param_names.begin(),
+                       param_names.end(),
                        args.begin(),
                        std::inserter(params, params.end()),
                        [](auto&& name, auto&& arg) { return std::make_pair(name, arg); });
-        // TODO run multiple modules and split the parameter data to each batch size
-        auto results = run(modules_to_run.at(0), params);
-        return results.at(0);
+        auto results = run(module_to_run, params);
+        return argument{results};
     }
 };
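For readers following the change, the sketch below isolates the selection logic that the added compute() body is moving toward: walk the submodule list and pick the one whose parameter shapes equal the incoming argument shapes, in order, failing when nothing matches. Everything in it (mock_shape, mock_module, select_exact_match) is a hypothetical stand-in written for illustration rather than a MIGraphX API, and unlike the work-in-progress lambda in the diff the predicate returns the result of std::equal, so the example compiles and runs.

// Minimal sketch of exact-match submodule selection; all types are mock stand-ins.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// For this sketch a shape is just an ordered list of dimensions.
using mock_shape = std::vector<std::size_t>;

// Each "submodule" owns an ordered list of (parameter name, parameter shape) pairs.
struct mock_module
{
    std::string name;
    std::vector<std::pair<std::string, mock_shape>> parameters;
};

// Pick the submodule whose parameter shapes equal the argument shapes exactly,
// assuming the arguments arrive in the same order as the parameters.
const mock_module* select_exact_match(const std::vector<mock_module>& submodules,
                                      const std::vector<mock_shape>& arg_shapes)
{
    auto it = std::find_if(submodules.begin(), submodules.end(), [&](const mock_module& m) {
        if(m.parameters.size() != arg_shapes.size())
            return false;
        return std::equal(arg_shapes.begin(),
                          arg_shapes.end(),
                          m.parameters.begin(),
                          [](const mock_shape& a, const auto& p) { return a == p.second; });
    });
    return it == submodules.end() ? nullptr : &*it;
}

int main()
{
    // Two compiled variants of the same model: batch 1 and batch 4.
    std::vector<mock_module> submodules = {
        {"batch_1", {{"x", {1, 3, 224, 224}}}},
        {"batch_4", {{"x", {4, 3, 224, 224}}}},
    };

    // An incoming batch-4 argument selects the batch-4 submodule.
    const mock_module* chosen = select_exact_match(submodules, {{4, 3, 224, 224}});
    assert(chosen != nullptr);
    std::cout << "selected submodule: " << chosen->name << "\n";

    // No submodule matches batch 2; the sketch returns nullptr where the real
    // operator throws "SELECT_MODULE: no compatible submodules found".
    assert(select_exact_match(submodules, {{2, 3, 224, 224}}) == nullptr);
    return 0;
}

In the operator itself, the chosen submodule's parameter names are then zipped with the input arguments into an unordered_map and handed to the run callback, as the std::transform call in the diff shows.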