Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7668ef6b
Commit
7668ef6b
authored
Nov 20, 2023
by
Paul
Browse files
Format
parent
5c4e15f2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
11 deletions
+11
-11
src/targets/gpu/jit/concat.cpp
src/targets/gpu/jit/concat.cpp
+11
-11
No files found.
src/targets/gpu/jit/concat.cpp
View file @
7668ef6b
...
@@ -86,7 +86,7 @@ struct concat_compiler : compiler<concat_compiler>
...
@@ -86,7 +86,7 @@ struct concat_compiler : compiler<concat_compiler>
{
{
const
auto
&
name
=
op_names
[
i
];
const
auto
&
name
=
op_names
[
i
];
auto
n
=
args
.
at
(
name
).
to
<
std
::
size_t
>
();
auto
n
=
args
.
at
(
name
).
to
<
std
::
size_t
>
();
auto
prefix
=
to_c_id
(
name
+
std
::
to_string
(
i
)
+
"_concat_x"
);
auto
prefix
=
to_c_id
(
name
+
std
::
to_string
(
i
)
+
"_concat_x"
);
transform
(
range
(
n
),
std
::
back_inserter
(
concat_params
),
[
&
](
auto
j
)
{
transform
(
range
(
n
),
std
::
back_inserter
(
concat_params
),
[
&
](
auto
j
)
{
return
"auto "
+
prefix
+
std
::
to_string
(
j
);
return
"auto "
+
prefix
+
std
::
to_string
(
j
);
});
});
...
@@ -112,7 +112,7 @@ struct concat_compiler : compiler<concat_compiler>
...
@@ -112,7 +112,7 @@ struct concat_compiler : compiler<concat_compiler>
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
{
auto
v
=
op
.
to_value
();
auto
v
=
op
.
to_value
();
if
(
op
.
name
()
==
"fused_concat"
)
if
(
op
.
name
()
==
"fused_concat"
)
{
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
mod_names_lookup
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
mod_names_lookup
;
transform
(
range
(
ins
->
module_inputs
().
size
()),
transform
(
range
(
ins
->
module_inputs
().
size
()),
...
@@ -134,7 +134,7 @@ struct concat_compiler : compiler<concat_compiler>
...
@@ -134,7 +134,7 @@ struct concat_compiler : compiler<concat_compiler>
ins
->
module_inputs
().
end
()
-
1
,
ins
->
module_inputs
().
end
()
-
1
,
std
::
back_inserter
(
mod_names
),
std
::
back_inserter
(
mod_names
),
[
&
](
module_ref
mod
)
{
return
mod_names_lookup
.
at
(
mod
->
name
());
});
[
&
](
module_ref
mod
)
{
return
mod_names_lookup
.
at
(
mod
->
name
());
});
v
[
"ops"
]
=
mod_names
;
v
[
"ops"
]
=
mod_names
;
module_ref
last_mod
=
ins
->
module_inputs
().
back
();
module_ref
last_mod
=
ins
->
module_inputs
().
back
();
v
[
"post"
]
=
"MIGRAPHX_LIFT("
+
mod_names_lookup
.
at
(
last_mod
->
name
())
+
")"
;
v
[
"post"
]
=
"MIGRAPHX_LIFT("
+
mod_names_lookup
.
at
(
last_mod
->
name
())
+
")"
;
std
::
unordered_map
<
std
::
string
,
std
::
size_t
>
mod_args
;
std
::
unordered_map
<
std
::
string
,
std
::
size_t
>
mod_args
;
...
@@ -145,7 +145,7 @@ struct concat_compiler : compiler<concat_compiler>
...
@@ -145,7 +145,7 @@ struct concat_compiler : compiler<concat_compiler>
const
auto
&
name
=
mod_names_lookup
.
at
(
mod
->
name
());
const
auto
&
name
=
mod_names_lookup
.
at
(
mod
->
name
());
return
std
::
make_pair
(
name
,
mod
->
get_parameter_names
().
size
());
return
std
::
make_pair
(
name
,
mod
->
get_parameter_names
().
size
());
});
});
v
[
"args"
]
=
mod_args
;
v
[
"args"
]
=
mod_args
;
auto
prefix_name
=
transform_accumulate
(
ins
->
module_inputs
().
begin
(),
auto
prefix_name
=
transform_accumulate
(
ins
->
module_inputs
().
begin
(),
ins
->
module_inputs
().
end
()
-
1
,
ins
->
module_inputs
().
end
()
-
1
,
std
::
string
{},
std
::
string
{},
...
@@ -159,21 +159,21 @@ struct concat_compiler : compiler<concat_compiler>
...
@@ -159,21 +159,21 @@ struct concat_compiler : compiler<concat_compiler>
v
[
"kernel"
]
=
prefix_name
+
"concat_"
+
v
[
"kernel"
]
=
prefix_name
+
"concat_"
+
generate_name_from_ops
(
*
(
ins
->
module_inputs
().
back
()))
+
"_kernel"
;
generate_name_from_ops
(
*
(
ins
->
module_inputs
().
back
()))
+
"_kernel"
;
}
}
else
if
(
op
.
name
()
==
"concat"
)
else
if
(
op
.
name
()
==
"concat"
)
{
{
auto
concat_inputs
=
ins
->
inputs
().
size
()
-
1
;
auto
concat_inputs
=
ins
->
inputs
().
size
()
-
1
;
if
(
not
ins
->
module_inputs
().
empty
())
if
(
not
ins
->
module_inputs
().
empty
())
{
{
auto
*
pm
=
ins
->
module_inputs
().
front
();
auto
*
pm
=
ins
->
module_inputs
().
front
();
concat_inputs
=
ins
->
inputs
().
size
()
-
pm
->
get_parameter_names
().
size
();
concat_inputs
=
ins
->
inputs
().
size
()
-
pm
->
get_parameter_names
().
size
();
v
[
"preamble"
]
=
generate_pointwise
(
*
pm
,
"post_concat"
);
v
[
"preamble"
]
=
generate_pointwise
(
*
pm
,
"post_concat"
);
v
[
"post"
]
=
"MIGRAPHX_LIFT(post_concat)"
;
v
[
"post"
]
=
"MIGRAPHX_LIFT(post_concat)"
;
v
[
"kernel"
]
=
"concat_"
+
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
v
[
"kernel"
]
=
"concat_"
+
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
}
}
std
::
vector
<
std
::
string
>
mod_names
(
concat_inputs
,
"op::id{}"
);
std
::
vector
<
std
::
string
>
mod_names
(
concat_inputs
,
"op::id{}"
);
v
[
"ops"
]
=
mod_names
;
v
[
"ops"
]
=
mod_names
;
std
::
unordered_map
<
std
::
string
,
std
::
size_t
>
mod_args
=
{{
"op::id{}"
,
1
}};
std
::
unordered_map
<
std
::
string
,
std
::
size_t
>
mod_args
=
{{
"op::id{}"
,
1
}};
v
[
"args"
]
=
mod_args
;
v
[
"args"
]
=
mod_args
;
}
}
return
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
v
);
return
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
v
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment