Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
399eacef
Commit
399eacef
authored
Jan 25, 2023
by
Paul
Browse files
Improve making the args
parent
8423d683
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
12 additions
and
9 deletions
+12
-9
src/cpp_generator.cpp
src/cpp_generator.cpp
+3
-3
src/fuse_reduce.cpp
src/fuse_reduce.cpp
+7
-4
src/include/migraphx/cpp_generator.hpp
src/include/migraphx/cpp_generator.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
+1
-1
No files found.
src/cpp_generator.cpp
View file @
399eacef
...
@@ -106,10 +106,10 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module
...
@@ -106,10 +106,10 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module
return
*
this
;
return
*
this
;
}
}
cpp_generator
::
function
&
cpp_generator
::
function
::
add_generic_param
(
const
std
::
string
&
name
)
cpp_generator
::
function
&
cpp_generator
::
function
::
add_generic_param
(
const
std
::
string
&
p
name
)
{
{
params
.
push_back
({
name
,
"T"
+
name
});
params
.
push_back
({
p
name
,
"T"
+
p
name
});
tparams
.
push_back
(
"class T"
+
name
);
tparams
.
push_back
(
"class T"
+
p
name
);
return
*
this
;
return
*
this
;
}
}
...
...
src/fuse_reduce.cpp
View file @
399eacef
...
@@ -135,7 +135,7 @@ insert_module_in_submodule(module_ref sm,
...
@@ -135,7 +135,7 @@ insert_module_in_submodule(module_ref sm,
}
}
static
std
::
vector
<
instruction_ref
>
static
std
::
vector
<
instruction_ref
>
find_inputs
(
module_ref
sm
,
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
find_inputs
(
module_ref
sm
,
const
module
&
parent
,
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
{
{
std
::
vector
<
instruction_ref
>
result
;
std
::
vector
<
instruction_ref
>
result
;
std
::
map
<
std
::
string
,
instruction_ref
>
names
;
std
::
map
<
std
::
string
,
instruction_ref
>
names
;
...
@@ -145,6 +145,8 @@ find_inputs(module_ref sm, const std::unordered_map<instruction_ref, instruction
...
@@ -145,6 +145,8 @@ find_inputs(module_ref sm, const std::unordered_map<instruction_ref, instruction
continue
;
continue
;
if
(
param
->
name
()
!=
"@param"
)
if
(
param
->
name
()
!=
"@param"
)
continue
;
continue
;
if
(
not
parent
.
has_instruction
(
input
))
continue
;
auto
v
=
param
->
get_operator
().
to_value
();
auto
v
=
param
->
get_operator
().
to_value
();
auto
name
=
v
.
at
(
"parameter"
).
to
<
std
::
string
>
();
auto
name
=
v
.
at
(
"parameter"
).
to
<
std
::
string
>
();
names
[
name
]
=
input
;
names
[
name
]
=
input
;
...
@@ -152,6 +154,7 @@ find_inputs(module_ref sm, const std::unordered_map<instruction_ref, instruction
...
@@ -152,6 +154,7 @@ find_inputs(module_ref sm, const std::unordered_map<instruction_ref, instruction
std
::
transform
(
names
.
begin
(),
names
.
end
(),
std
::
back_inserter
(
result
),
[](
const
auto
&
p
)
{
std
::
transform
(
names
.
begin
(),
names
.
end
(),
std
::
back_inserter
(
result
),
[](
const
auto
&
p
)
{
return
p
.
second
;
return
p
.
second
;
});
});
assert
(
result
.
size
()
==
sm
->
get_parameter_shapes
().
size
());
return
result
;
return
result
;
}
}
...
@@ -211,7 +214,7 @@ struct find_pointwise_reduce
...
@@ -211,7 +214,7 @@ struct find_pointwise_reduce
// Insert fused_reduce
// Insert fused_reduce
insert_module_in_submodule
(
rm
,
reduce
,
map_ins
);
insert_module_in_submodule
(
rm
,
reduce
,
map_ins
);
auto
new_inputs
=
find_inputs
(
rm
,
map_ins
);
auto
new_inputs
=
find_inputs
(
rm
,
mpm
.
get_module
(),
map_ins
);
mpm
.
get_module
().
replace_instruction
(
reduce
,
reduce
->
get_operator
(),
new_inputs
,
{
rm
});
mpm
.
get_module
().
replace_instruction
(
reduce
,
reduce
->
get_operator
(),
new_inputs
,
{
rm
});
}
}
};
};
...
@@ -266,7 +269,7 @@ struct find_reduce_pointwise
...
@@ -266,7 +269,7 @@ struct find_reduce_pointwise
auto
out
=
insert_ins_in_submodule
(
rm
,
pw
,
map_ins
);
auto
out
=
insert_ins_in_submodule
(
rm
,
pw
,
map_ins
);
rm
->
replace_return
(
out
);
rm
->
replace_return
(
out
);
auto
new_inputs
=
find_inputs
(
rm
,
map_ins
);
auto
new_inputs
=
find_inputs
(
rm
,
mpm
.
get_module
(),
map_ins
);
mpm
.
get_module
().
replace_instruction
(
pw
,
reduce
->
get_operator
(),
new_inputs
,
{
rm
});
mpm
.
get_module
().
replace_instruction
(
pw
,
reduce
->
get_operator
(),
new_inputs
,
{
rm
});
}
}
};
};
...
@@ -300,7 +303,7 @@ struct find_reduce_reduce
...
@@ -300,7 +303,7 @@ struct find_reduce_reduce
auto
out
=
insert_module_in_submodule
(
rm
,
reduce1
,
map_ins
);
auto
out
=
insert_module_in_submodule
(
rm
,
reduce1
,
map_ins
);
rm
->
replace_return
(
out
);
rm
->
replace_return
(
out
);
auto
new_inputs
=
find_inputs
(
rm
,
map_ins
);
auto
new_inputs
=
find_inputs
(
rm
,
mpm
.
get_module
(),
map_ins
);
mpm
.
get_module
().
replace_instruction
(
reduce1
,
reduce1
->
get_operator
(),
new_inputs
,
{
rm
});
mpm
.
get_module
().
replace_instruction
(
reduce1
,
reduce1
->
get_operator
(),
new_inputs
,
{
rm
});
}
}
};
};
...
...
src/include/migraphx/cpp_generator.hpp
View file @
399eacef
...
@@ -77,7 +77,7 @@ struct cpp_generator
...
@@ -77,7 +77,7 @@ struct cpp_generator
function
&
set_types
(
const
module
&
m
);
function
&
set_types
(
const
module
&
m
);
function
&
set_types
(
const
module
&
m
,
const
std
::
function
<
std
::
string
(
shape
)
>&
parse
);
function
&
set_types
(
const
module
&
m
,
const
std
::
function
<
std
::
string
(
shape
)
>&
parse
);
function
&
set_generic_types
(
const
module
&
m
);
function
&
set_generic_types
(
const
module
&
m
);
function
&
add_generic_param
(
const
std
::
string
&
name
);
function
&
add_generic_param
(
const
std
::
string
&
p
name
);
};
};
cpp_generator
();
cpp_generator
();
...
...
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
View file @
399eacef
...
@@ -180,7 +180,7 @@ struct index
...
@@ -180,7 +180,7 @@ struct index
}
}
else
else
{
{
static_assert
(
max_stride_iterations
(
n
,
stride
)
<
64
);
//
static_assert(max_stride_iterations(n, stride) < 64);
sequence
(
max_stride_iterations
(
n
,
stride
),
[
&
](
auto
...
ks
)
{
sequence
(
max_stride_iterations
(
n
,
stride
),
[
&
](
auto
...
ks
)
{
fold
([
&
](
auto
d
,
auto
k
)
{
fold
([
&
](
auto
d
,
auto
k
)
{
auto
i
=
start
+
stride
*
k
;
auto
i
=
start
+
stride
*
k
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment