Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
399eacef
Commit
399eacef
authored
Jan 25, 2023
by
Paul
Browse files
Improve making the args
parent
8423d683
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
12 additions
and
9 deletions
+12
-9
src/cpp_generator.cpp
src/cpp_generator.cpp
+3
-3
src/fuse_reduce.cpp
src/fuse_reduce.cpp
+7
-4
src/include/migraphx/cpp_generator.hpp
src/include/migraphx/cpp_generator.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
+1
-1
No files found.
src/cpp_generator.cpp
View file @
399eacef
...
...
@@ -106,10 +106,10 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module
return
*
this
;
}
cpp_generator
::
function
&
cpp_generator
::
function
::
add_generic_param
(
const
std
::
string
&
name
)
cpp_generator
::
function
&
cpp_generator
::
function
::
add_generic_param
(
const
std
::
string
&
p
name
)
{
params
.
push_back
({
name
,
"T"
+
name
});
tparams
.
push_back
(
"class T"
+
name
);
params
.
push_back
({
p
name
,
"T"
+
p
name
});
tparams
.
push_back
(
"class T"
+
p
name
);
return
*
this
;
}
...
...
src/fuse_reduce.cpp
View file @
399eacef
...
...
@@ -135,7 +135,7 @@ insert_module_in_submodule(module_ref sm,
}
static
std
::
vector
<
instruction_ref
>
find_inputs
(
module_ref
sm
,
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
find_inputs
(
module_ref
sm
,
const
module
&
parent
,
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
{
std
::
vector
<
instruction_ref
>
result
;
std
::
map
<
std
::
string
,
instruction_ref
>
names
;
...
...
@@ -145,6 +145,8 @@ find_inputs(module_ref sm, const std::unordered_map<instruction_ref, instruction
continue
;
if
(
param
->
name
()
!=
"@param"
)
continue
;
if
(
not
parent
.
has_instruction
(
input
))
continue
;
auto
v
=
param
->
get_operator
().
to_value
();
auto
name
=
v
.
at
(
"parameter"
).
to
<
std
::
string
>
();
names
[
name
]
=
input
;
...
...
@@ -152,6 +154,7 @@ find_inputs(module_ref sm, const std::unordered_map<instruction_ref, instruction
std
::
transform
(
names
.
begin
(),
names
.
end
(),
std
::
back_inserter
(
result
),
[](
const
auto
&
p
)
{
return
p
.
second
;
});
assert
(
result
.
size
()
==
sm
->
get_parameter_shapes
().
size
());
return
result
;
}
...
...
@@ -211,7 +214,7 @@ struct find_pointwise_reduce
// Insert fused_reduce
insert_module_in_submodule
(
rm
,
reduce
,
map_ins
);
auto
new_inputs
=
find_inputs
(
rm
,
map_ins
);
auto
new_inputs
=
find_inputs
(
rm
,
mpm
.
get_module
(),
map_ins
);
mpm
.
get_module
().
replace_instruction
(
reduce
,
reduce
->
get_operator
(),
new_inputs
,
{
rm
});
}
};
...
...
@@ -266,7 +269,7 @@ struct find_reduce_pointwise
auto
out
=
insert_ins_in_submodule
(
rm
,
pw
,
map_ins
);
rm
->
replace_return
(
out
);
auto
new_inputs
=
find_inputs
(
rm
,
map_ins
);
auto
new_inputs
=
find_inputs
(
rm
,
mpm
.
get_module
(),
map_ins
);
mpm
.
get_module
().
replace_instruction
(
pw
,
reduce
->
get_operator
(),
new_inputs
,
{
rm
});
}
};
...
...
@@ -300,7 +303,7 @@ struct find_reduce_reduce
auto
out
=
insert_module_in_submodule
(
rm
,
reduce1
,
map_ins
);
rm
->
replace_return
(
out
);
auto
new_inputs
=
find_inputs
(
rm
,
map_ins
);
auto
new_inputs
=
find_inputs
(
rm
,
mpm
.
get_module
(),
map_ins
);
mpm
.
get_module
().
replace_instruction
(
reduce1
,
reduce1
->
get_operator
(),
new_inputs
,
{
rm
});
}
};
...
...
src/include/migraphx/cpp_generator.hpp
View file @
399eacef
...
...
@@ -77,7 +77,7 @@ struct cpp_generator
function
&
set_types
(
const
module
&
m
);
function
&
set_types
(
const
module
&
m
,
const
std
::
function
<
std
::
string
(
shape
)
>&
parse
);
function
&
set_generic_types
(
const
module
&
m
);
function
&
add_generic_param
(
const
std
::
string
&
name
);
function
&
add_generic_param
(
const
std
::
string
&
p
name
);
};
cpp_generator
();
...
...
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
View file @
399eacef
...
...
@@ -180,7 +180,7 @@ struct index
}
else
{
static_assert
(
max_stride_iterations
(
n
,
stride
)
<
64
);
//
static_assert(max_stride_iterations(n, stride) < 64);
sequence
(
max_stride_iterations
(
n
,
stride
),
[
&
](
auto
...
ks
)
{
fold
([
&
](
auto
d
,
auto
k
)
{
auto
i
=
start
+
stride
*
k
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment