Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
cd4ab535
Commit
cd4ab535
authored
Jun 20, 2023
by
Khalique Ahmed
Browse files
manual merge
parents
3891ee58
a0fa3742
Changes
279
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1009 additions
and
162 deletions
+1009
-162
src/shape.cpp
src/shape.cpp
+93
-51
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+290
-19
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+1
-1
src/split_single_dyn_dim.cpp
src/split_single_dyn_dim.cpp
+165
-0
src/targets/cpu/include/migraphx/cpu/parallel.hpp
src/targets/cpu/include/migraphx/cpu/parallel.hpp
+1
-1
src/targets/cpu/target.cpp
src/targets/cpu/target.cpp
+0
-1
src/targets/fpga/include/migraphx/fpga/vitis_ai_adapter.hpp
src/targets/fpga/include/migraphx/fpga/vitis_ai_adapter.hpp
+1
-1
src/targets/fpga/subgraph.cpp
src/targets/fpga/subgraph.cpp
+1
-2
src/targets/fpga/vitis_ai_adapter.cpp
src/targets/fpga/vitis_ai_adapter.cpp
+1
-1
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+12
-3
src/targets/gpu/compile_gen.cpp
src/targets/gpu/compile_gen.cpp
+150
-5
src/targets/gpu/compile_hip.cpp
src/targets/gpu/compile_hip.cpp
+62
-15
src/targets/gpu/compile_hip_code_object.cpp
src/targets/gpu/compile_hip_code_object.cpp
+10
-5
src/targets/gpu/compile_ops.cpp
src/targets/gpu/compile_ops.cpp
+157
-11
src/targets/gpu/compiler.cpp
src/targets/gpu/compiler.cpp
+23
-12
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
+10
-10
src/targets/gpu/device/multinomial.cpp
src/targets/gpu/device/multinomial.cpp
+12
-11
src/targets/gpu/device/scatter.cpp
src/targets/gpu/device/scatter.cpp
+16
-12
src/targets/gpu/device_name.cpp
src/targets/gpu/device_name.cpp
+3
-1
src/targets/gpu/driver/CMakeLists.txt
src/targets/gpu/driver/CMakeLists.txt
+1
-0
No files found.
src/shape.cpp
View file @
cd4ab535
...
@@ -74,13 +74,23 @@ struct shape_impl
...
@@ -74,13 +74,23 @@ struct shape_impl
shape_impl
(
shape
::
type_t
t
,
shape_impl
(
shape
::
type_t
t
,
std
::
vector
<
std
::
size_t
>
mins
,
std
::
vector
<
std
::
size_t
>
mins
,
std
::
vector
<
std
::
size_t
>
maxes
,
std
::
vector
<
std
::
size_t
>
maxes
,
std
::
vector
<
std
::
size_t
>
opt
s
)
std
::
vector
<
std
::
set
<
std
::
size_t
>
>
opt
imals_list
)
:
m_type
(
t
)
:
m_type
(
t
)
{
{
assert
(
mins
.
size
()
==
maxes
.
size
()
and
maxes
.
size
()
==
opts
.
size
());
if
(
optimals_list
.
empty
())
for
(
size_t
i
=
0
;
i
<
mins
.
size
();
++
i
)
{
{
m_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
mins
[
i
],
maxes
[
i
],
opts
[
i
]});
for
(
size_t
i
=
0
;
i
<
mins
.
size
();
++
i
)
{
m_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
mins
[
i
],
maxes
[
i
]});
}
}
else
{
assert
(
mins
.
size
()
==
maxes
.
size
()
and
maxes
.
size
()
==
optimals_list
.
size
());
for
(
size_t
i
=
0
;
i
<
mins
.
size
();
++
i
)
{
m_dyn_dims
.
push_back
(
shape
::
dynamic_dimension
{
mins
[
i
],
maxes
[
i
],
optimals_list
[
i
]});
}
}
}
}
}
...
@@ -147,7 +157,7 @@ struct shape_impl
...
@@ -147,7 +157,7 @@ struct shape_impl
std
::
transform
(
m_dyn_dims
.
cbegin
(),
std
::
transform
(
m_dyn_dims
.
cbegin
(),
m_dyn_dims
.
cend
(),
m_dyn_dims
.
cend
(),
ret
.
begin
(),
ret
.
begin
(),
[](
shape
::
dynamic_dimension
x
)
{
return
x
.
min
;
});
[](
const
shape
::
dynamic_dimension
&
x
)
{
return
x
.
min
;
});
return
ret
;
return
ret
;
}
}
...
@@ -157,19 +167,20 @@ struct shape_impl
...
@@ -157,19 +167,20 @@ struct shape_impl
std
::
transform
(
m_dyn_dims
.
cbegin
(),
std
::
transform
(
m_dyn_dims
.
cbegin
(),
m_dyn_dims
.
cend
(),
m_dyn_dims
.
cend
(),
ret
.
begin
(),
ret
.
begin
(),
[](
shape
::
dynamic_dimension
x
)
{
return
x
.
max
;
});
[](
const
shape
::
dynamic_dimension
&
x
)
{
return
x
.
max
;
});
return
ret
;
return
ret
;
}
}
std
::
vector
<
std
::
size_t
>
opt_lens
()
const
std
::
vector
<
std
::
set
<
std
::
size_t
>
>
opt_lens
()
const
{
{
std
::
vector
<
std
::
size_t
>
ret
(
m_dyn_dims
.
size
());
std
::
vector
<
std
::
set
<
std
::
size_t
>
>
ret
(
m_dyn_dims
.
size
());
std
::
transform
(
m_dyn_dims
.
cbegin
(),
std
::
transform
(
m_dyn_dims
.
cbegin
(),
m_dyn_dims
.
cend
(),
m_dyn_dims
.
cend
(),
ret
.
begin
(),
ret
.
begin
(),
[](
shape
::
dynamic_dimension
x
)
{
return
x
.
opt
;
});
[](
const
shape
::
dynamic_dimension
&
x
)
{
return
x
.
opt
imals
;
});
return
ret
;
return
ret
;
}
}
// Does the shape skip over elements?
// Does the shape skip over elements?
bool
skips
()
const
bool
skips
()
const
{
{
...
@@ -240,8 +251,9 @@ shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims)
...
@@ -240,8 +251,9 @@ shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims)
shape
::
shape
(
type_t
t
,
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
mins
,
std
::
vector
<
std
::
size_t
>
mins
,
std
::
vector
<
std
::
size_t
>
maxes
,
std
::
vector
<
std
::
size_t
>
maxes
,
std
::
vector
<
std
::
size_t
>
opts
)
std
::
vector
<
std
::
set
<
std
::
size_t
>>
optimals_list
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
t
,
std
::
move
(
mins
),
std
::
move
(
maxes
),
std
::
move
(
opts
)))
:
impl
(
std
::
make_shared
<
shape_impl
>
(
t
,
std
::
move
(
mins
),
std
::
move
(
maxes
),
std
::
move
(
optimals_list
)))
{
{
}
}
...
@@ -349,29 +361,26 @@ std::size_t shape::index(std::size_t i) const
...
@@ -349,29 +361,26 @@ std::size_t shape::index(std::size_t i) const
}
}
}
}
std
::
vector
<
std
::
size_t
>
shape
::
multi
(
std
::
size_t
i
)
const
std
::
vector
<
std
::
size_t
>
shape
::
multi
(
std
::
size_t
i
dx
)
const
{
{
assert
(
this
->
standard
());
assert
(
idx
<
elements
());
std
::
vector
<
std
::
size_t
>
indices
(
lens
().
size
());
std
::
vector
<
std
::
size_t
>
indices
(
lens
().
size
());
multi_copy
(
i
,
indices
.
data
(),
indices
.
data
()
+
lens
().
size
());
multi_copy
(
idx
,
indices
.
data
(),
indices
.
data
()
+
lens
().
size
());
return
indices
;
return
indices
;
}
}
void
shape
::
multi_copy
(
std
::
size_t
i
,
std
::
size_t
*
start
,
const
std
::
size_t
*
end
)
const
void
shape
::
multi_copy
(
std
::
size_t
i
dx
,
std
::
size_t
*
start
,
const
std
::
size_t
*
end
)
const
{
{
assert
(
this
->
standard
())
;
size_t
tidx
=
idx
;
(
void
)
end
;
(
void
)
end
;
assert
(
idx
<
elements
());
assert
(
lens
().
size
()
<=
(
end
-
start
));
assert
(
lens
().
size
()
<=
(
end
-
start
));
std
::
transform
(
strides
().
begin
(),
for
(
size_t
ii
=
lens
().
size
()
-
1
;
ii
>
0
;
ii
--
)
strides
().
end
(),
{
lens
().
begin
(),
*
(
start
+
ii
)
=
tidx
%
lens
()[
ii
];
start
,
tidx
=
tidx
/
lens
()[
ii
];
[
&
](
std
::
size_t
stride
,
std
::
size_t
len
)
{
}
assert
(
len
>
0
and
stride
>
0
);
*
start
=
tidx
;
return
(
i
/
stride
)
%
len
;
});
}
}
bool
shape
::
packed
()
const
bool
shape
::
packed
()
const
...
@@ -469,12 +478,44 @@ shape shape::with_type(type_t t) const
...
@@ -469,12 +478,44 @@ shape shape::with_type(type_t t) const
shape
shape
::
to_dynamic
()
const
shape
shape
::
to_dynamic
()
const
{
{
if
(
not
sub_shapes
().
empty
())
{
std
::
vector
<
shape
>
subs
;
std
::
transform
(
sub_shapes
().
cbegin
(),
sub_shapes
().
cend
(),
std
::
back_inserter
(
subs
),
[](
auto
s
)
{
return
s
.
to_dynamic
();
});
return
{
subs
};
}
if
(
this
->
dynamic
())
if
(
this
->
dynamic
())
{
{
return
*
this
;
return
*
this
;
}
}
std
::
vector
<
std
::
size_t
>
zeroes
(
this
->
ndim
(),
0
);
return
{
type
(),
lens
(),
lens
(),
{}};
return
{
type
(),
lens
(),
lens
(),
zeroes
};
}
shape
shape
::
to_static
(
std
::
size_t
x
)
const
{
if
(
not
sub_shapes
().
empty
())
{
std
::
vector
<
shape
>
subs
;
std
::
transform
(
sub_shapes
().
cbegin
(),
sub_shapes
().
cend
(),
std
::
back_inserter
(
subs
),
[
&
](
auto
s
)
{
return
s
.
to_static
(
x
);
});
return
{
subs
};
}
if
(
not
this
->
dynamic
())
{
return
*
this
;
}
auto
static_lens
=
this
->
max_lens
();
std
::
transform
(
static_lens
.
begin
(),
static_lens
.
end
(),
this
->
dyn_dims
().
cbegin
(),
static_lens
.
begin
(),
[
&
](
auto
sl
,
auto
dd
)
{
return
dd
.
is_fixed
()
?
sl
:
x
;
});
return
{
type
(),
static_lens
};
}
}
std
::
size_t
shape
::
element_space
()
const
{
return
impl
->
element_space
();
}
std
::
size_t
shape
::
element_space
()
const
{
return
impl
->
element_space
();
}
...
@@ -506,23 +547,22 @@ std::vector<std::size_t> shape::max_lens() const
...
@@ -506,23 +547,22 @@ std::vector<std::size_t> shape::max_lens() const
return
this
->
dynamic
()
?
impl
->
max_lens
()
:
this
->
lens
();
return
this
->
dynamic
()
?
impl
->
max_lens
()
:
this
->
lens
();
}
}
std
::
vector
<
std
::
size_t
>
shape
::
opt_lens
()
const
std
::
vector
<
std
::
set
<
std
::
size_t
>>
shape
::
opt_lens
()
const
{
return
impl
->
opt_lens
();
}
{
return
this
->
dynamic
()
?
impl
->
opt_lens
()
:
this
->
lens
();
}
bool
shape
::
dynamic_dimension
::
is_fixed
()
const
{
return
this
->
min
==
this
->
max
;
}
bool
shape
::
dynamic_dimension
::
is_fixed
()
const
{
return
this
->
min
==
this
->
max
;
}
bool
shape
::
dynamic_dimension
::
has_optimal
()
const
{
return
opt
!=
0
;
}
bool
shape
::
dynamic_dimension
::
has_optimal
()
const
{
return
not
optimals
.
empty
()
;
}
shape
::
dynamic_dimension
&
shape
::
dynamic_dimension
::
operator
+=
(
const
std
::
size_t
&
x
)
shape
::
dynamic_dimension
&
shape
::
dynamic_dimension
::
operator
+=
(
const
std
::
size_t
&
x
)
{
{
this
->
min
+=
x
;
this
->
min
+=
x
;
this
->
max
+=
x
;
this
->
max
+=
x
;
if
(
this
->
opt
!=
0
)
std
::
set
<
std
::
size_t
>
new_optimals
;
{
std
::
transform
(
this
->
optimals
.
begin
(),
this
->
opt
+=
x
;
this
->
optimals
.
end
(),
};
std
::
inserter
(
new_optimals
,
new_optimals
.
begin
()),
[
&
x
](
const
auto
&
opt
)
{
return
(
opt
+
x
);
});
this
->
optimals
=
new_optimals
;
return
*
this
;
return
*
this
;
}
}
...
@@ -532,19 +572,23 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const std::size_t
...
@@ -532,19 +572,23 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const std::size_t
assert
(
this
->
max
>=
x
);
assert
(
this
->
max
>=
x
);
this
->
min
-=
x
;
this
->
min
-=
x
;
this
->
max
-=
x
;
this
->
max
-=
x
;
if
(
this
->
opt
!=
0
)
std
::
set
<
std
::
size_t
>
new_optimals
;
{
std
::
transform
(
this
->
optimals
.
begin
(),
assert
(
this
->
opt
>=
x
);
this
->
optimals
.
end
(),
this
->
opt
-=
x
;
std
::
inserter
(
new_optimals
,
new_optimals
.
begin
()),
}
[
&
x
](
const
auto
&
opt
)
{
assert
(
opt
>=
x
);
return
(
opt
-
x
);
});
this
->
optimals
=
new_optimals
;
return
*
this
;
return
*
this
;
}
}
bool
operator
==
(
const
shape
::
dynamic_dimension
&
x
,
const
shape
::
dynamic_dimension
&
y
)
bool
operator
==
(
const
shape
::
dynamic_dimension
&
x
,
const
shape
::
dynamic_dimension
&
y
)
{
{
// don't check opt if both are fixed
// don't check opt
imals
if both are fixed
return
(
x
.
min
==
y
.
min
and
x
.
max
==
y
.
max
and
return
(
x
.
min
==
y
.
min
and
x
.
max
==
y
.
max
and
((
x
.
is_fixed
()
and
y
.
is_fixed
())
or
(
x
.
opt
==
y
.
opt
)));
((
x
.
is_fixed
()
and
y
.
is_fixed
())
or
(
x
.
opt
imals
==
y
.
opt
imals
)));
}
}
bool
operator
!=
(
const
shape
::
dynamic_dimension
&
x
,
const
shape
::
dynamic_dimension
&
y
)
bool
operator
!=
(
const
shape
::
dynamic_dimension
&
x
,
const
shape
::
dynamic_dimension
&
y
)
...
@@ -553,7 +597,7 @@ bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimensio
...
@@ -553,7 +597,7 @@ bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimensio
}
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
::
dynamic_dimension
&
x
)
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
::
dynamic_dimension
&
x
)
{
{
os
<<
"["
<<
x
.
min
<<
", "
<<
x
.
max
<<
", "
<<
x
.
opt
<<
"]"
;
os
<<
"[
"
<<
x
.
min
<<
", "
<<
x
.
max
<<
",
{
"
<<
migraphx
::
to_string_range
(
x
.
optimals
)
<<
"
}
]"
;
return
os
;
return
os
;
}
}
...
@@ -662,12 +706,10 @@ void migraphx_from_value(const value& v, shape& s)
...
@@ -662,12 +706,10 @@ void migraphx_from_value(const value& v, shape& s)
{
{
auto
v_dd
=
v
.
at
(
"dynamic_dimensions"
);
auto
v_dd
=
v
.
at
(
"dynamic_dimensions"
);
std
::
vector
<
shape
::
dynamic_dimension
>
dyn_dims
(
v
.
at
(
"dynamic_dimensions"
).
size
());
std
::
vector
<
shape
::
dynamic_dimension
>
dyn_dims
(
v
.
at
(
"dynamic_dimensions"
).
size
());
std
::
transform
(
v_dd
.
begin
(),
v_dd
.
end
(),
dyn_dims
.
begin
(),
[](
migraphx
::
value
x
)
{
std
::
transform
(
auto
x_min
=
x
.
at
(
"min"
).
template
to
<
size_t
>();
v_dd
.
begin
(),
v_dd
.
end
(),
dyn_dims
.
begin
(),
[](
const
migraphx
::
value
&
x
)
{
auto
x_max
=
x
.
at
(
"max"
).
template
to
<
size_t
>();
return
from_value
<
shape
::
dynamic_dimension
>
(
x
);
auto
x_opt
=
x
.
at
(
"opt"
).
template
to
<
size_t
>();
});
return
shape
::
dynamic_dimension
{
x_min
,
x_max
,
x_opt
};
});
s
=
shape
{
shape
::
parse_type
(
t
),
dyn_dims
};
s
=
shape
{
shape
::
parse_type
(
t
),
dyn_dims
};
}
}
...
...
src/simplify_algebra.cpp
View file @
cd4ab535
...
@@ -39,6 +39,8 @@
...
@@ -39,6 +39,8 @@
#include <migraphx/algorithm.hpp>
#include <migraphx/algorithm.hpp>
#include <unordered_set>
#include <unordered_set>
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES
)
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -52,8 +54,9 @@ auto op_lit_broadcast(std::string op, std::string x, std::string y)
...
@@ -52,8 +54,9 @@ auto op_lit_broadcast(std::string op, std::string x, std::string y)
auto
conv_const_weights
()
auto
conv_const_weights
()
{
{
return
match
::
name
(
"convolution"
)(
match
::
used_once
(),
return
match
::
name
(
"convolution"
)(
match
::
args
(
match
::
any
(),
match
::
is_constant
().
bind
(
"w"
)));
match
::
used_once
(),
match
::
args
(
match
::
none_of
(
match
::
is_constant
()),
match
::
is_constant
().
bind
(
"w"
)));
}
}
auto
reduction
()
{
return
match
::
name_contains
(
"reduce"
);
}
auto
reduction
()
{
return
match
::
name_contains
(
"reduce"
);
}
...
@@ -203,7 +206,137 @@ struct find_mul_slice_conv
...
@@ -203,7 +206,137 @@ struct find_mul_slice_conv
}
}
};
};
// a * (x + b) => a * x + a * b
struct
find_mul_dot
{
auto
matcher
()
const
{
auto
is_dot_const_inputs
=
match
::
name
(
"dot"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
is_constant
()));
return
match
::
name
(
"mul"
)(
match
::
either_arg
(
0
,
1
)(
is_dot_const_inputs
.
bind
(
"dot"
),
match
::
name
(
"broadcast"
,
"multibroadcast"
).
bind
(
"c"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
dot_ins
=
r
.
instructions
[
"dot"
];
auto
a_ins
=
dot_ins
->
inputs
()[
0
];
auto
b_ins
=
dot_ins
->
inputs
()[
1
];
auto
c_ins
=
r
.
instructions
[
"c"
];
const
auto
&
c_strides
=
c_ins
->
get_shape
().
strides
();
// There should only be one stride that is not zero
if
(
std
::
count_if
(
c_strides
.
begin
(),
c_strides
.
end
(),
[](
auto
s
)
{
return
s
!=
0
;
})
>
1
)
return
;
auto
add_mul_const
=
[
&
](
instruction_ref
x_ins
)
{
if
(
not
x_ins
->
can_eval
())
return
m
.
end
();
auto
broadcast_v
=
c_ins
->
get_operator
().
to_value
();
broadcast_v
[
"out_lens"
]
=
x_ins
->
get_shape
().
lens
();
auto
cb_ins
=
m
.
insert_instruction
(
ins
,
make_op
(
c_ins
->
name
(),
broadcast_v
),
c_ins
->
inputs
());
return
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
x_ins
,
cb_ins
);
};
if
(
c_strides
.
back
()
==
1
)
{
b_ins
=
add_mul_const
(
b_ins
);
}
else
if
(
c_strides
[
c_strides
.
size
()
-
2
]
==
1
)
{
a_ins
=
add_mul_const
(
a_ins
);
}
else
if
(
c_ins
->
get_shape
().
scalar
())
{
if
(
a_ins
->
can_eval
())
a_ins
=
add_mul_const
(
a_ins
);
else
b_ins
=
add_mul_const
(
b_ins
);
}
else
{
return
;
}
if
(
contains
({
a_ins
,
b_ins
},
m
.
end
()))
return
;
m
.
replace_instruction
(
ins
,
make_op
(
"dot"
),
a_ins
,
b_ins
);
}
};
struct
find_dot_mul
{
auto
matcher
()
const
{
auto
const_broadcast
=
match
::
name
(
"broadcast"
,
"multibroadcast"
)(
match
::
is_constant
());
auto
mul
=
match
::
name
(
"mul"
)(
match
::
used_once
(),
match
::
either_arg
(
0
,
1
)(
const_broadcast
.
bind
(
"d"
),
match
::
none_of
(
match
::
is_constant
()).
bind
(
"z"
)));
return
match
::
name
(
"dot"
)(
match
::
either_arg
(
0
,
1
)(
mul
,
match
::
is_constant
().
bind
(
"c"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
a_ins
=
ins
->
inputs
()[
0
];
auto
b_ins
=
ins
->
inputs
()[
1
];
auto
d_ins
=
r
.
instructions
[
"d"
];
auto
c_ins
=
r
.
instructions
[
"c"
];
auto
z_ins
=
r
.
instructions
[
"z"
];
const
auto
&
d_strides
=
d_ins
->
get_shape
().
strides
();
// There should only be one stride that is not zero
if
(
std
::
count_if
(
d_strides
.
begin
(),
d_strides
.
end
(),
[](
auto
s
)
{
return
s
!=
0
;
})
>
1
)
return
;
if
(
not
d_ins
->
get_shape
().
scalar
())
{
if
(
d_strides
.
back
()
==
1
and
not
b_ins
->
can_eval
())
return
;
if
(
d_strides
[
d_strides
.
size
()
-
2
]
==
1
and
not
a_ins
->
can_eval
())
return
;
}
auto
broadcast_v
=
d_ins
->
get_operator
().
to_value
();
auto
c_lens
=
c_ins
->
get_shape
().
lens
();
std
::
vector
<
int64_t
>
permutation
(
c_lens
.
size
());
std
::
iota
(
permutation
.
begin
(),
permutation
.
end
(),
0
);
std
::
swap
(
permutation
.
back
(),
permutation
[
permutation
.
size
()
-
2
]);
c_lens
=
reorder_dims
(
c_lens
,
permutation
);
broadcast_v
[
"out_lens"
]
=
c_lens
;
auto
db_ins
=
m
.
insert_instruction
(
ins
,
make_op
(
d_ins
->
name
(),
broadcast_v
),
d_ins
->
inputs
());
auto
db_transpose_ins
=
m
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
permutation
}}),
db_ins
);
auto
cd_ins
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
c_ins
,
db_transpose_ins
);
if
(
c_ins
==
b_ins
)
{
a_ins
=
z_ins
;
b_ins
=
cd_ins
;
}
else
{
a_ins
=
cd_ins
;
b_ins
=
z_ins
;
}
m
.
replace_instruction
(
ins
,
make_op
(
"dot"
),
a_ins
,
b_ins
);
}
};
// ******************************
// a * (x + b) => a * x + a * b
// ******************************
// When a * (x + b) is followed by another add of constant, then the
// additional add can be const folded. Also, better fusions can be applied
// when the add comes after.
struct
find_mul_add
struct
find_mul_add
{
{
auto
matcher
()
const
auto
matcher
()
const
...
@@ -268,6 +401,32 @@ struct find_dot_add
...
@@ -268,6 +401,32 @@ struct find_dot_add
}
}
};
};
struct
find_conv_add
{
auto
matcher
()
const
{
auto
add
=
match
::
name
(
"add"
)(
match
::
either_arg
(
0
,
1
)(
match
::
any
().
bind
(
"x"
),
match
::
any_of
(
match
::
is_constant
()).
bind
(
"a"
)),
match
::
used_once
());
return
match
::
name
(
"convolution"
)(
match
::
used_once
(),
match
::
args
(
add
,
match
::
is_constant
().
bind
(
"w"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
a_ins
=
r
.
instructions
[
"a"
];
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
w_ins
=
r
.
instructions
[
"w"
];
auto
conv1
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
a_ins
,
w_ins
);
auto
conv2
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
x_ins
,
w_ins
);
m
.
replace_instruction
(
ins
,
make_op
(
"add"
),
conv1
,
conv2
);
}
};
struct
find_add_lit_broadcast
struct
find_add_lit_broadcast
{
{
auto
matcher
()
const
auto
matcher
()
const
...
@@ -329,30 +488,123 @@ struct find_inner_broadcast
...
@@ -329,30 +488,123 @@ struct find_inner_broadcast
{
{
auto
matcher
()
const
{
return
pointwise
(
match
::
all_of
[
match
::
inputs
()](
match
::
broadcast
()));
}
auto
matcher
()
const
{
return
pointwise
(
match
::
all_of
[
match
::
inputs
()](
match
::
broadcast
()));
}
static
auto
non_scalar_op
(
const
std
::
string
&
name
)
{
return
[
=
](
instruction_ref
ins
)
{
if
(
ins
->
get_shape
().
scalar
())
return
false
;
return
ins
->
name
()
==
name
;
};
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
auto
broadcasts
=
ins
->
inputs
();
auto
broadcasts
=
ins
->
inputs
();
if
(
broadcasts
.
empty
())
if
(
broadcasts
.
empty
())
return
;
return
;
// Skip if different data types are used
if
(
any_of
(
broadcasts
,
[
&
](
auto
i
)
{
return
i
->
get_shape
().
type
()
!=
broadcasts
.
front
()
->
get_shape
().
type
();
}))
return
;
bool
mixed_broadcasts
=
any_of
(
broadcasts
,
non_scalar_op
(
"broadcast"
))
and
any_of
(
broadcasts
,
non_scalar_op
(
"multibroadcast"
));
// If the broadcast is not a single dimension, then dont perform inner_broadcast
if
(
mixed_broadcasts
and
any_of
(
broadcasts
,
[
&
](
instruction_ref
i
)
{
if
(
i
->
get_shape
().
scalar
())
return
false
;
if
(
i
->
name
()
==
"multibroadcast"
)
return
false
;
auto
input
=
i
->
inputs
().
at
(
0
);
const
auto
&
lens
=
input
->
get_shape
().
lens
();
return
std
::
count_if
(
lens
.
begin
(),
lens
.
end
(),
[
&
](
std
::
size_t
d
)
{
return
d
==
1
;
})
<
(
lens
.
size
()
-
1
);
}))
return
;
std
::
vector
<
instruction_ref
>
inputs
;
std
::
vector
<
instruction_ref
>
inputs
;
std
::
transform
(
broadcasts
.
begin
(),
std
::
transform
(
broadcasts
.
begin
(),
broadcasts
.
end
(),
broadcasts
.
end
(),
std
::
back_inserter
(
inputs
),
std
::
back_inserter
(
inputs
),
[](
auto
i
)
{
return
i
->
inputs
().
front
();
});
[
&
](
instruction_ref
i
)
{
if
(
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
i
)
{
auto
input
=
i
->
inputs
().
front
();
return
i
->
get_shape
()
!=
inputs
.
front
()
->
get_shape
()
and
if
(
mixed_broadcasts
and
not
i
->
get_shape
().
scalar
()
and
i
->
get_shape
().
elements
()
!=
1
;
i
->
get_shape
().
lens
().
size
()
>
1
)
}))
return
m
.
insert_instruction
(
i
,
make_op
(
"squeeze"
),
input
);
return
;
return
input
;
});
auto
b_it
=
std
::
find_if
(
broadcasts
.
begin
(),
broadcasts
.
end
(),
[
&
](
auto
i
)
{
return
not
i
->
get_shape
().
scalar
();
std
::
sort
(
broadcasts
.
begin
(),
broadcasts
.
end
(),
by
(
std
::
less
<>
{},
[](
instruction_ref
i
)
{
});
if
(
i
->
get_shape
().
scalar
())
if
(
b_it
==
broadcasts
.
end
())
return
2
;
b_it
=
broadcasts
.
begin
();
else
if
(
i
->
name
()
==
"broadcast"
)
return
0
;
if
(
i
->
name
()
==
"multibroadcast"
)
return
1
;
return
3
;
}));
auto
op
=
insert_common_op
(
m
,
ins
,
ins
->
get_operator
(),
inputs
);
auto
op
=
insert_common_op
(
m
,
ins
,
ins
->
get_operator
(),
inputs
);
m
.
replace_instruction
(
ins
,
(
*
b_it
)
->
get_operator
(),
op
);
m
.
replace_instruction
(
ins
,
broadcasts
.
front
()
->
get_operator
(),
op
);
}
};
struct
find_dot_broadcast
{
auto
matcher
()
const
{
return
match
::
name
(
"dot"
)(
match
::
all_of
[
match
::
inputs
()](
match
::
broadcast
()));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
a
=
ins
->
inputs
()[
0
];
auto
b
=
ins
->
inputs
()[
1
];
if
(
a
->
get_operator
().
name
()
!=
b
->
get_operator
().
name
())
return
;
if
(
ins
->
get_shape
().
lens
().
size
()
<
3
)
return
;
auto
nbatch_axes
=
ins
->
get_shape
().
lens
().
size
()
-
2
;
const
auto
&
a_strides
=
a
->
get_shape
().
strides
();
const
auto
&
b_strides
=
b
->
get_shape
().
strides
();
// Find leading batch axes that are broadcasted
auto
p
=
std
::
mismatch
(
a_strides
.
begin
(),
a_strides
.
begin
()
+
nbatch_axes
,
b_strides
.
begin
(),
b_strides
.
begin
()
+
nbatch_axes
,
[](
auto
astride
,
auto
bstride
)
{
return
astride
==
0
and
bstride
==
0
;
});
auto
naxes
=
p
.
first
-
a_strides
.
begin
();
assert
(
naxes
<=
nbatch_axes
);
std
::
vector
<
std
::
size_t
>
axes
(
naxes
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
auto
insert_broadcast
=
[
&
](
instruction_ref
b_ins
)
->
instruction_ref
{
auto
input
=
b_ins
->
inputs
()[
0
];
std
::
vector
<
std
::
size_t
>
lens
(
b_ins
->
get_shape
().
lens
().
begin
()
+
naxes
,
b_ins
->
get_shape
().
lens
().
end
());
if
(
b_ins
->
name
()
==
"multibroadcast"
)
{
return
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
input
);
}
else
if
(
b_ins
->
name
()
==
"broadcast"
)
{
auto
v
=
b_ins
->
get_operator
().
to_value
();
auto
axis
=
v
.
at
(
"axis"
).
to
<
std
::
size_t
>
()
-
naxes
;
return
m
.
insert_instruction
(
ins
,
make_op
(
"broadcast"
,
{{
"axis"
,
axis
},
{
"out_lens"
,
lens
}}),
input
);
}
assert
(
false
);
return
m
.
end
();
};
auto
a1
=
insert_broadcast
(
a
);
auto
b1
=
insert_broadcast
(
b
);
auto
dot
=
m
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
a1
,
b1
);
auto
broadcast
=
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
ins
->
get_shape
().
lens
()}}),
dot
);
m
.
replace_instruction
(
ins
,
broadcast
);
}
}
};
};
...
@@ -361,7 +613,8 @@ struct find_concat_op
...
@@ -361,7 +613,8 @@ struct find_concat_op
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
name
(
"concat"
)(
match
::
any_of
[
match
::
inputs
()](
return
match
::
name
(
"concat"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
any_of
(
match
::
pointwise
(),
match
::
name
(
"broadcast"
)),
match
::
used_once
()));
match
::
any_of
(
match
::
pointwise
(),
match
::
name
(
"broadcast"
,
"multibroadcast"
)),
match
::
used_once
()));
}
}
template
<
class
Iterator
>
template
<
class
Iterator
>
...
@@ -380,7 +633,8 @@ struct find_concat_op
...
@@ -380,7 +633,8 @@ struct find_concat_op
static
bool
is_valid_op
(
const
operation
&
op
)
static
bool
is_valid_op
(
const
operation
&
op
)
{
{
return
op
.
name
()
==
"broadcast"
or
op
.
attributes
().
contains
(
"pointwise"
);
return
contains
({
"broadcast"
,
"multibroadcast"
},
op
.
name
())
or
op
.
attributes
().
contains
(
"pointwise"
);
}
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
...
@@ -408,6 +662,16 @@ struct find_concat_op
...
@@ -408,6 +662,16 @@ struct find_concat_op
op
=
b
;
op
=
b
;
iaxis
=
0
;
iaxis
=
0
;
}
}
else
if
(
op
.
name
()
==
"multibroadcast"
)
{
shape
bshape
=
(
*
start
)
->
get_shape
();
auto
input
=
(
*
start
)
->
inputs
()[
0
];
if
(
iaxis
>=
bshape
.
strides
().
size
()
or
bshape
.
strides
()[
iaxis
]
==
0
)
return
{
start
,
last
};
op
.
from_value
({{
"out_lens"
,
get_output_lens
(
start
,
last
,
iaxis
)}});
auto
delta
=
bshape
.
lens
().
size
()
-
input
->
get_shape
().
lens
().
size
();
iaxis
-=
delta
;
}
std
::
vector
<
instruction_ref
>
concats
;
std
::
vector
<
instruction_ref
>
concats
;
for
(
std
::
size_t
i
=
0
;
i
<
x
->
inputs
().
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
x
->
inputs
().
size
();
i
++
)
...
@@ -1223,22 +1487,29 @@ struct find_split_transpose
...
@@ -1223,22 +1487,29 @@ struct find_split_transpose
void
simplify_algebra
::
apply
(
module
&
m
)
const
void
simplify_algebra
::
apply
(
module
&
m
)
const
{
{
size_t
trace
=
value_of
(
MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES
{});
// Run simplifications multiple times
// Run simplifications multiple times
for
(
int
i
=
0
;
i
<
8
;
i
++
)
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
{
match
::
find_matches
(
m
,
match
::
find_matches
(
trace
,
m
,
find_inner_broadcast
{},
find_inner_broadcast
{},
find_dot_broadcast
{},
find_double_add_lit_broadcast
{},
find_double_add_lit_broadcast
{},
find_add_lit_broadcast
{},
find_add_lit_broadcast
{},
find_add_convs
{},
find_add_convs
{},
find_conv_dot_horiz_fusion
{},
find_conv_dot_horiz_fusion
{},
find_mul_conv
{},
find_mul_conv
{},
find_mul_slice_conv
{},
find_mul_slice_conv
{},
find_mul_dot
{},
find_dot_mul
{},
find_mul_add
{},
find_mul_add
{},
find_unit_ops
{},
find_unit_ops
{},
find_neg_unit_ops
{},
find_neg_unit_ops
{},
find_zero_ops
{},
find_zero_ops
{},
find_dot_add
{},
find_dot_add
{},
find_conv_add
{},
find_div_const
{},
find_div_const
{},
find_sub_const
{},
find_sub_const
{},
find_rsqrt
{},
find_rsqrt
{},
...
...
src/simplify_reshapes.cpp
View file @
cd4ab535
...
@@ -762,7 +762,7 @@ struct find_transpose_slice
...
@@ -762,7 +762,7 @@ struct find_transpose_slice
return
;
return
;
// Compute axis before transpose to use for unsqueeze
// Compute axis before transpose to use for unsqueeze
auto
perm
=
ins
->
get_operator
().
to_value
()[
"permutation"
].
to_vector
<
int64_t
>
();
auto
perm
=
ins
->
get_operator
().
to_value
()[
"permutation"
].
to_vector
<
int64_t
>
();
auto
preaxis
=
std
::
find
(
perm
.
begin
(),
perm
.
end
(),
axis
)
-
perm
.
begin
()
;
auto
preaxis
=
perm
[
axis
]
;
// Make unsqueeze
// Make unsqueeze
std
::
vector
<
int64_t
>
steps
(
sdistance
.
size
());
std
::
vector
<
int64_t
>
steps
(
sdistance
.
size
());
std
::
transform
(
std
::
transform
(
...
...
src/split_single_dyn_dim.cpp
0 → 100644
View file @
cd4ab535
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
dynamic_dimensions_check
{
std
::
string
dyn_param_str
;
size_t
dyn_index
;
size_t
min_dim
;
size_t
max_dim
;
};
optional
<
dynamic_dimensions_check
>
has_one_dyn_dim
(
const
std
::
unordered_map
<
std
::
string
,
shape
>&
param_shapes
)
{
// True if parameters contain exactly one dynamic shape with exactly one non-fixed
// dynamic_dimension.
auto
is_dynamic
=
[](
const
auto
&
p
)
{
return
p
.
second
.
dynamic
();
};
auto
ps_it
=
std
::
find_if
(
param_shapes
.
begin
(),
param_shapes
.
end
(),
is_dynamic
);
if
(
ps_it
==
param_shapes
.
end
())
return
std
::
nullopt
;
// Check if there is a second dynamic parameter
if
(
std
::
any_of
(
std
::
next
(
ps_it
),
param_shapes
.
end
(),
is_dynamic
))
return
std
::
nullopt
;
const
auto
&
dds
=
ps_it
->
second
.
dyn_dims
();
auto
is_non_fixed
=
[](
const
auto
&
dd
)
{
return
not
dd
.
is_fixed
();
};
auto
dds_it
=
std
::
find_if
(
dds
.
begin
(),
dds
.
end
(),
is_non_fixed
);
if
(
dds_it
==
dds
.
end
())
return
std
::
nullopt
;
// Check if there is a second non-fixed dynamic_dimension
if
(
std
::
any_of
(
std
::
next
(
dds_it
),
dds
.
end
(),
is_non_fixed
))
return
std
::
nullopt
;
return
dynamic_dimensions_check
{
ps_it
->
first
,
static_cast
<
std
::
size_t
>
(
std
::
distance
(
dds
.
begin
(),
dds_it
)),
dds_it
->
min
,
dds_it
->
max
};
}
namespace
{
struct
find_static_2in_broadcasts
{
// Convert 2 input static shape broadcast/multibroadcast into 1 input version.
// Some compiler passes (ex. simplify_algebra) only support the 1 input versions
// of the broadcasting operators.
auto
matcher
()
const
{
return
match
::
broadcast
(
match
::
nargs
(
2
),
match
::
arg
(
0
)(
match
::
static_shape
()),
match
::
arg
(
1
)(
match
::
static_shape
()));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
auto
out_lens
=
ins
->
get_shape
().
lens
();
auto
broadcast_op
=
ins
->
get_operator
();
if
(
broadcast_op
.
name
()
==
"broadcast"
)
{
broadcast_op
.
from_value
({{
"out_lens"
,
out_lens
}});
}
else
{
broadcast_op
.
from_value
({{
"out_lens"
,
out_lens
},
{
"out_dyn_dims"
,
{}}});
}
m
.
replace_instruction
(
ins
,
broadcast_op
,
ins
->
inputs
().
at
(
0
));
}
};
}
// namespace
/**
* Makes all the shapes in the dynamic_dimension range. Probably won't work for `if`
* and `loop` instructions, depending on how the submodules for those
* work. Inserts select_module instruction to the top. Replaces return, bypassing other
* instructions. Skips if the dynamic parameter outputs to a select_module operator.
*/
void
split_single_dyn_dim
::
apply
(
module_pass_manager
&
mpm
)
const
{
module_ref
mm
=
&
mpm
.
get_module
();
auto
param_names
=
mm
->
get_parameter_names
();
auto
param_shapes
=
mm
->
get_parameter_shapes
();
optional
<
dynamic_dimensions_check
>
dd_check
=
has_one_dyn_dim
(
param_shapes
);
auto
any_sm_next
=
[
&
](
auto
ddc
)
{
auto
p_outputs
=
mm
->
get_parameter
(
ddc
->
dyn_param_str
)
->
outputs
();
return
std
::
any_of
(
p_outputs
.
cbegin
(),
p_outputs
.
cend
(),
[](
auto
ins
)
{
return
ins
->
name
()
==
"select_module"
;
});
};
if
(
dd_check
.
has_value
()
and
not
any_sm_next
(
dd_check
))
{
const
auto
&
dyn_param
=
mm
->
get_parameter
(
dd_check
->
dyn_param_str
);
auto
dyn_param_shape
=
mm
->
get_parameter_shape
(
dd_check
->
dyn_param_str
);
std
::
vector
<
module_ref
>
submodules
;
// create submodules for each dimension size
for
(
size_t
dim_size
:
migraphx
::
range
(
dd_check
->
min_dim
,
dd_check
->
max_dim
+
1
))
{
auto
*
submod
=
mpm
.
create_module
(
"dim_"
+
std
::
to_string
(
dim_size
));
// instruction map for new static shaped submodule parameters
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
;
// create static shape using dim_size
auto
static_lens
=
dyn_param_shape
.
max_lens
();
static_lens
.
at
(
dd_check
->
dyn_index
)
=
dim_size
;
map_ins
[
dyn_param
]
=
submod
->
add_parameter
(
dd_check
->
dyn_param_str
,
migraphx
::
shape
{
dyn_param_shape
.
type
(),
static_lens
});
auto
outputs
=
submod
->
add_instructions
(
mm
,
map_ins
);
submod
->
add_return
({
outputs
});
match
::
find_matches
(
*
submod
,
find_static_2in_broadcasts
{});
submodules
.
push_back
(
submod
);
}
// redirect to select_module operator and return
std
::
vector
<
instruction_ref
>
sm_inputs
;
std
::
transform
(
param_names
.
cbegin
(),
param_names
.
cend
(),
std
::
back_inserter
(
sm_inputs
),
[
&
](
auto
pn
)
{
return
mm
->
get_parameter
(
pn
);
});
auto
output_shapes
=
mm
->
get_output_shapes
();
migraphx
::
shape
out_attr
=
migraphx
::
shape
{
output_shapes
};
auto
sm_ins
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"select_module"
,
{{
"output_dyn_shapes"
,
migraphx
::
to_value
(
out_attr
)}}),
sm_inputs
,
submodules
);
std
::
vector
<
instruction_ref
>
outputs
(
output_shapes
.
size
());
for
(
size_t
i
=
0
;
i
<
output_shapes
.
size
();
++
i
)
{
outputs
.
at
(
i
)
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"get_tuple_elem"
,
{{
"index"
,
i
}}),
sm_ins
);
}
mm
->
replace_return
(
outputs
);
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/cpu/include/migraphx/cpu/parallel.hpp
View file @
cd4ab535
...
@@ -25,7 +25,7 @@
...
@@ -25,7 +25,7 @@
#define MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_PARALLEL_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_PARALLEL_HPP
// #define MIGRAPHX_DISABLE_OMP
// #define MIGRAPHX_DISABLE_OMP
#include <cmath>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#ifdef MIGRAPHX_DISABLE_OMP
#ifdef MIGRAPHX_DISABLE_OMP
#include <migraphx/par_for.hpp>
#include <migraphx/par_for.hpp>
...
...
src/targets/cpu/target.cpp
View file @
cd4ab535
...
@@ -82,7 +82,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -82,7 +82,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination
{},
dead_code_elimination
{},
simplify_algebra
{},
simplify_algebra
{},
simplify_reshapes
{},
simplify_reshapes
{},
layout_nhwc
{},
dead_code_elimination
{},
dead_code_elimination
{},
simplify_reshapes
{},
simplify_reshapes
{},
simplify_algebra
{},
simplify_algebra
{},
...
...
src/targets/fpga/include/migraphx/fpga/vitis_ai_adapter.hpp
View file @
cd4ab535
...
@@ -41,7 +41,7 @@ class x_model
...
@@ -41,7 +41,7 @@ class x_model
void
set_shape
(
migraphx
::
shape
);
void
set_shape
(
migraphx
::
shape
);
};
};
x_model
create_xmodel
(
migraphx
::
module_ref
mod
);
x_model
create_xmodel
(
migraphx
::
const_
module_ref
mod
);
migraphx
::
argument
execute
(
const
x_model
&
xmodel
,
migraphx
::
argument
execute
(
const
x_model
&
xmodel
,
const
migraphx
::
shape
&
output_shape
,
const
migraphx
::
shape
&
output_shape
,
...
...
src/targets/fpga/subgraph.cpp
View file @
cd4ab535
...
@@ -113,8 +113,7 @@ void subgraph::apply(module_pass_manager& mpm) const
...
@@ -113,8 +113,7 @@ void subgraph::apply(module_pass_manager& mpm) const
// TODO(varunsh): this code may be replaceable by code in the fuse_pointwise pass
// TODO(varunsh): this code may be replaceable by code in the fuse_pointwise pass
// assuming all FPGA instructions are in one contiguous range
// assuming all FPGA instructions are in one contiguous range
pm
->
insert_instructions
(
pm
->
end
(),
first
,
last
,
{});
pm
->
insert_instructions
(
pm
->
end
(),
first
,
std
::
next
(
last
),
{});
migraphx
::
instruction_ref
placeholder_ins
;
migraphx
::
instruction_ref
placeholder_ins
;
for
(
auto
it
:
iterator_for
(
mod
))
for
(
auto
it
:
iterator_for
(
mod
))
{
{
...
...
src/targets/fpga/vitis_ai_adapter.cpp
View file @
cd4ab535
...
@@ -33,7 +33,7 @@ migraphx::shape x_model::get_shape() const { return shape; };
...
@@ -33,7 +33,7 @@ migraphx::shape x_model::get_shape() const { return shape; };
void
x_model
::
set_shape
(
migraphx
::
shape
s
)
{
shape
=
s
;
}
void
x_model
::
set_shape
(
migraphx
::
shape
s
)
{
shape
=
s
;
}
x_model
create_xmodel
(
const
migraphx
::
module_ref
mod
)
x_model
create_xmodel
(
migraphx
::
const_
module_ref
mod
)
{
{
std
::
cout
<<
"Calling an external function: create_xmodel!
\n
"
;
std
::
cout
<<
"Calling an external function: create_xmodel!
\n
"
;
x_model
xmodel
;
x_model
xmodel
;
...
...
src/targets/gpu/CMakeLists.txt
View file @
cd4ab535
...
@@ -22,7 +22,7 @@
...
@@ -22,7 +22,7 @@
# THE SOFTWARE.
# THE SOFTWARE.
# ####################################################################################
# ####################################################################################
list
(
APPEND CMAKE_PREFIX_PATH /opt/rocm
/opt/rocm/hip
)
list
(
APPEND CMAKE_PREFIX_PATH /opt/rocm
)
find_package
(
miopen
)
find_package
(
miopen
)
# rocblas
# rocblas
...
@@ -33,7 +33,13 @@ if(NOT TARGET MIOpen)
...
@@ -33,7 +33,13 @@ if(NOT TARGET MIOpen)
message
(
SEND_ERROR
"Cant find miopen"
)
message
(
SEND_ERROR
"Cant find miopen"
)
endif
()
endif
()
set
(
MIGRAPHX_USE_HIPRTC OFF CACHE BOOL
"Use hipRTC APIs"
)
find_package
(
composable_kernel 1.0.0 COMPONENTS jit_library REQUIRED
)
if
(
BUILD_DEV
)
set
(
MIGRAPHX_USE_HIPRTC OFF CACHE BOOL
"Use hipRTC APIs"
)
else
()
set
(
MIGRAPHX_USE_HIPRTC ON CACHE BOOL
"Use hipRTC APIs"
)
endif
()
include
(
Embed
)
include
(
Embed
)
file
(
GLOB KERNEL_FILES
${
CONFIGURE_DEPENDS
}
file
(
GLOB KERNEL_FILES
${
CONFIGURE_DEPENDS
}
...
@@ -91,6 +97,7 @@ add_library(migraphx_gpu
...
@@ -91,6 +97,7 @@ add_library(migraphx_gpu
compile_miopen.cpp
compile_miopen.cpp
compiler.cpp
compiler.cpp
device_name.cpp
device_name.cpp
fuse_ck.cpp
fuse_mlir.cpp
fuse_mlir.cpp
fuse_ops.cpp
fuse_ops.cpp
gather.cpp
gather.cpp
...
@@ -119,6 +126,7 @@ add_library(migraphx_gpu
...
@@ -119,6 +126,7 @@ add_library(migraphx_gpu
schedule_model.cpp
schedule_model.cpp
sync_device.cpp
sync_device.cpp
target.cpp
target.cpp
time_op.cpp
topk.cpp
topk.cpp
write_literals.cpp
write_literals.cpp
${
JIT_GPU_SRCS
}
${
JIT_GPU_SRCS
}
...
@@ -237,9 +245,10 @@ else()
...
@@ -237,9 +245,10 @@ else()
endif
()
endif
()
target_link_libraries
(
migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas
)
target_link_libraries
(
migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas
)
target_link_libraries
(
migraphx_gpu PRIVATE migraphx_device migraphx_kernels
)
target_link_libraries
(
migraphx_gpu PRIVATE migraphx_device migraphx_kernels
composable_kernel::jit_library
)
add_subdirectory
(
driver
)
add_subdirectory
(
driver
)
add_subdirectory
(
hiprtc
)
rocm_install_targets
(
rocm_install_targets
(
TARGETS migraphx_gpu migraphx_device compile_for_gpu
TARGETS migraphx_gpu migraphx_device compile_for_gpu
...
...
src/targets/gpu/compile_gen.cpp
View file @
cd4ab535
...
@@ -29,6 +29,7 @@
...
@@ -29,6 +29,7 @@
#include <migraphx/module.hpp>
#include <migraphx/module.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
...
@@ -168,10 +169,11 @@ std::string make_transformer_args(std::vector<std::string> transformers)
...
@@ -168,10 +169,11 @@ std::string make_transformer_args(std::vector<std::string> transformers)
return
join_strings
(
std
::
move
(
transformers
),
", "
);
return
join_strings
(
std
::
move
(
transformers
),
", "
);
}
}
std
::
string
generate_pointwise
(
const
module
&
pm
,
const
std
::
string
&
name
)
void
generate_pointwise
(
cpp_generator
&
gg
,
const
module
&
pm
,
const
std
::
string
&
name
)
{
{
module
m
=
pm
;
module
m
=
pm
;
run_passes
(
m
,
{
eliminate_common_subexpression
{},
dead_code_elimination
{}});
run_passes
(
m
,
{
rewrite_quantization
{},
eliminate_common_subexpression
{},
dead_code_elimination
{}});
cpp_generator
g
;
cpp_generator
g
;
g
.
fmap
([](
const
std
::
string
&
fname
)
{
return
"migraphx::"
+
fname
;
});
g
.
fmap
([](
const
std
::
string
&
fname
)
{
return
"migraphx::"
+
fname
;
});
g
.
add_point_op
(
"where"
,
"${function:where}(${0}, ${1}, ${2})"
);
g
.
add_point_op
(
"where"
,
"${function:where}(${0}, ${1}, ${2})"
);
...
@@ -184,8 +186,141 @@ std::string generate_pointwise(const module& pm, const std::string& name)
...
@@ -184,8 +186,141 @@ std::string generate_pointwise(const module& pm, const std::string& name)
// Add explict conversions
// Add explict conversions
g
.
fresult
(
g
.
fresult
(
[](
const
shape
&
s
)
{
return
"migraphx::convert<"
+
shape
::
cpp_type
(
s
.
type
())
+
">"
;
});
[](
const
shape
&
s
)
{
return
"migraphx::convert<"
+
shape
::
cpp_type
(
s
.
type
())
+
">"
;
});
g
.
create_function
(
gg
.
create_function
(
g
.
generate_module
(
m
)
g
.
generate_module
(
m
).
set_attributes
({
"__device__"
}).
set_generic_types
(
m
).
set_name
(
name
));
.
set_attributes
({
"__device__"
,
"__attribute__((const))"
})
.
set_generic_types
(
m
)
.
set_name
(
name
));
}
std
::
string
generate_pointwise
(
const
module
&
pm
,
const
std
::
string
&
name
)
{
cpp_generator
g
;
generate_pointwise
(
g
,
pm
,
name
);
return
g
.
str
();
}
std
::
string
reduce_op
::
str
()
const
{
return
write
+
"(r.reduce("
+
reduction
+
", "
+
init
+
", "
+
read
+
")("
+
input
+
"))"
;
}
void
reduce_op
::
set
(
instruction_ref
ins
,
const
operation
&
op
)
{
if
(
op
.
name
()
==
"reduce_sum"
)
{
reduction
=
"op::sum{}"
;
}
else
if
(
op
.
name
()
==
"reduce_mean"
)
{
auto
s
=
ins
->
inputs
().
front
()
->
get_shape
();
auto
reduce_elements
=
s
.
elements
()
/
ins
->
get_shape
().
elements
();
auto
reduce_type
=
s
.
type
();
reduction
=
"op::sum{}"
;
std
::
string
mean
=
"op::mean<"
+
std
::
to_string
(
reduce_elements
)
+
">{}"
;
// Use float accumulator when reduction size is too large for half
if
(
reduce_type
==
shape
::
half_type
and
reduce_elements
>
16384
)
read
=
"compose("
+
mean
+
", op::convert_to<float>{})"
;
else
if
(
contains
({
shape
::
float_type
,
shape
::
half_type
,
shape
::
double_type
},
reduce_type
))
read
=
mean
;
else
write
=
mean
;
}
else
if
(
op
.
name
()
==
"reduce_max"
)
{
reduction
=
"op::max{}"
;
init
=
"lowest{}"
;
}
else
if
(
op
.
name
()
==
"reduce_min"
)
{
reduction
=
"op::min{}"
;
init
=
"highest{}"
;
}
else
if
(
op
.
name
()
==
"reduce_prod"
)
{
reduction
=
"op::product{}"
;
init
=
"1"
;
}
else
{
MIGRAPHX_THROW
(
"Unsupported reduce"
);
}
}
std
::
string
reduce_op
::
generate
(
instruction_ref
ins
,
const
std
::
string
&
x
)
{
reduce_op
r
{
x
};
r
.
set
(
ins
,
ins
->
get_operator
());
return
r
.
str
();
}
static
bool
use_lazy_inner
(
instruction_ref
ins
)
{
if
(
ins
->
outputs
().
size
()
!=
1
)
return
false
;
auto
output
=
ins
->
outputs
().
front
();
return
contains
(
output
->
name
(),
"reduce"
)
or
output
->
name
()
==
"@return"
;
}
std
::
string
generate_reduce
(
const
module
&
m
,
const
std
::
string
&
name
)
{
cpp_generator
g
;
auto
ilens
=
m
.
get_parameter_shapes
().
begin
()
->
second
.
lens
();
std
::
size_t
i
=
0
;
auto
f
=
g
.
generate_module
(
m
,
[
&
](
instruction_ref
ins
,
const
auto
&
names
)
{
if
(
contains
(
ins
->
name
(),
"reduce"
))
{
return
reduce_op
::
generate
(
ins
,
names
.
at
(
ins
->
inputs
().
front
()));
}
else
if
(
ins
->
name
()
==
"pointwise"
)
{
auto
pointwise_name
=
"pointwise"
+
std
::
to_string
(
i
);
i
++
;
generate_pointwise
(
g
,
*
ins
->
module_inputs
().
front
(),
pointwise_name
);
std
::
vector
<
instruction_ref
>
tensors
;
std
::
copy_if
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
tensors
),
[
&
](
auto
input
)
{
return
input
->
get_shape
().
lens
()
==
ilens
and
not
input
->
get_shape
().
broadcasted
();
});
auto
inner_names
=
names
;
for
(
auto
input
:
ins
->
inputs
())
{
if
(
input
->
name
()
!=
"@param"
)
continue
;
if
(
contains
(
tensors
,
input
))
continue
;
inner_names
[
input
]
+=
"[out_idx]"
;
}
for
(
auto
input
:
tensors
)
inner_names
[
input
]
+=
"_lambda_param"
;
auto
call_function
=
pointwise_name
+
"("
+
join_strings
(
cpp_generator
::
to_args
(
ins
->
inputs
(),
inner_names
),
", "
)
+
")"
;
if
(
tensors
.
empty
())
return
call_function
;
const
std
::
string
inner_template
=
"r.${inner}([=](${params}) { return ${call}; })(${args})"
;
std
::
string
inner_name
=
use_lazy_inner
(
ins
)
?
"lazy_inner"
:
"inner"
;
auto
args
=
cpp_generator
::
to_args
(
tensors
,
names
);
auto
params
=
cpp_generator
::
to_args
(
tensors
,
inner_names
);
std
::
transform
(
params
.
begin
(),
params
.
end
(),
params
.
begin
(),
[](
auto
s
)
{
return
"auto "
+
s
;
});
return
interpolate_string
(
inner_template
,
{{
"inner"
,
inner_name
},
{
"params"
,
join_strings
(
params
,
", "
)},
{
"args"
,
join_strings
(
args
,
", "
)},
{
"call"
,
call_function
}});
}
else
if
(
ins
->
name
()
==
"multibroadcast"
)
{
return
names
.
at
(
ins
->
inputs
().
front
());
}
MIGRAPHX_THROW
(
"Unknown operator: "
+
ins
->
name
());
});
f
.
set_attributes
({
"__device__"
,
"__attribute__((const))"
}).
set_generic_types
(
m
).
set_name
(
name
);
f
.
add_generic_param
(
"r"
);
f
.
add_generic_param
(
"out_idx"
);
f
.
unused_param
(
"out_idx"
);
g
.
create_function
(
f
);
return
g
.
str
();
return
g
.
str
();
}
}
...
@@ -196,7 +331,17 @@ static std::vector<std::string> get_op_names(const module& m)
...
@@ -196,7 +331,17 @@ static std::vector<std::string> get_op_names(const module& m)
{
{
if
(
starts_with
(
ins
.
name
(),
"@"
))
if
(
starts_with
(
ins
.
name
(),
"@"
))
continue
;
continue
;
result
.
push_back
(
ins
.
name
());
if
(
ins
.
name
()
==
"multibroadcast"
)
continue
;
if
(
ins
.
name
()
==
"pointwise"
)
{
auto
names
=
get_op_names
(
*
ins
.
module_inputs
().
front
());
result
.
insert
(
result
.
end
(),
names
.
begin
(),
names
.
end
());
}
else
{
result
.
push_back
(
ins
.
name
());
}
}
}
return
result
;
return
result
;
}
}
...
...
src/targets/gpu/compile_hip.cpp
View file @
cd4ab535
...
@@ -32,6 +32,13 @@
...
@@ -32,6 +32,13 @@
#ifdef MIGRAPHX_USE_HIPRTC
#ifdef MIGRAPHX_USE_HIPRTC
#include <hip/hiprtc.h>
#include <hip/hiprtc.h>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/value.hpp>
#include <migraphx/tmp_dir.hpp>
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/process.hpp>
#include <migraphx/msgpack.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/file_buffer.hpp>
#else
#else
#include <migraphx/compile_src.hpp>
#include <migraphx/compile_src.hpp>
#include <migraphx/process.hpp>
#include <migraphx/process.hpp>
...
@@ -49,9 +56,6 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC);
...
@@ -49,9 +56,6 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC);
#ifdef MIGRAPHX_USE_HIPRTC
#ifdef MIGRAPHX_USE_HIPRTC
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_HIPRTC
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS
);
std
::
string
hiprtc_error
(
hiprtcResult
err
,
const
std
::
string
&
msg
)
std
::
string
hiprtc_error
(
hiprtcResult
err
,
const
std
::
string
&
msg
)
{
{
return
"hiprtc: "
+
(
hiprtcGetErrorString
(
err
)
+
(
": "
+
msg
));
return
"hiprtc: "
+
(
hiprtcGetErrorString
(
err
)
+
(
": "
+
msg
));
...
@@ -63,6 +67,7 @@ void hiprtc_check_error(hiprtcResult err, const std::string& msg, const std::str
...
@@ -63,6 +67,7 @@ void hiprtc_check_error(hiprtcResult err, const std::string& msg, const std::str
throw
make_exception
(
ctx
,
hiprtc_error
(
err
,
msg
));
throw
make_exception
(
ctx
,
hiprtc_error
(
err
,
msg
));
}
}
// NOLINTNEXTLINE
#define MIGRAPHX_HIPRTC(...) \
#define MIGRAPHX_HIPRTC(...) \
hiprtc_check_error(__VA_ARGS__, #__VA_ARGS__, MIGRAPHX_MAKE_SOURCE_CTX())
hiprtc_check_error(__VA_ARGS__, #__VA_ARGS__, MIGRAPHX_MAKE_SOURCE_CTX())
...
@@ -110,21 +115,19 @@ struct hiprtc_program
...
@@ -110,21 +115,19 @@ struct hiprtc_program
std
::
string
cpp_src
=
""
;
std
::
string
cpp_src
=
""
;
std
::
string
cpp_name
=
""
;
std
::
string
cpp_name
=
""
;
hiprtc_program
(
const
std
::
vector
<
src_file
>
&
srcs
)
hiprtc_program
(
std
::
vector
<
hiprtc_
src_file
>
srcs
)
{
{
for
(
auto
&&
src
:
srcs
)
for
(
auto
&&
src
:
srcs
)
{
{
std
::
string
content
{
src
.
content
.
first
,
src
.
content
.
second
};
if
(
ends_with
(
src
.
path
,
".cpp"
))
std
::
string
path
=
src
.
path
.
string
();
if
(
src
.
path
.
extension
().
string
()
==
".cpp"
)
{
{
cpp_src
=
std
::
move
(
content
);
cpp_src
=
std
::
move
(
src
.
content
);
cpp_name
=
std
::
move
(
path
);
cpp_name
=
std
::
move
(
src
.
path
);
}
}
else
else
{
{
headers
.
push_back
(
std
::
move
(
content
));
headers
.
push_back
(
std
::
move
(
src
.
content
));
include_names
.
push_back
(
std
::
move
(
path
));
include_names
.
push_back
(
std
::
move
(
src
.
path
));
}
}
}
}
prog
=
hiprtc_program_create
(
cpp_src
.
c_str
(),
prog
=
hiprtc_program_create
(
cpp_src
.
c_str
(),
...
@@ -134,7 +137,7 @@ struct hiprtc_program
...
@@ -134,7 +137,7 @@ struct hiprtc_program
include_names
.
data
());
include_names
.
data
());
}
}
void
compile
(
const
std
::
vector
<
std
::
string
>&
options
)
void
compile
(
const
std
::
vector
<
std
::
string
>&
options
)
const
{
{
if
(
enabled
(
MIGRAPHX_TRACE_HIPRTC
{}))
if
(
enabled
(
MIGRAPHX_TRACE_HIPRTC
{}))
std
::
cout
<<
"hiprtc "
<<
join_strings
(
options
,
" "
)
<<
" "
<<
cpp_name
<<
std
::
endl
;
std
::
cout
<<
"hiprtc "
<<
join_strings
(
options
,
" "
)
<<
" "
<<
cpp_name
<<
std
::
endl
;
...
@@ -175,10 +178,11 @@ struct hiprtc_program
...
@@ -175,10 +178,11 @@ struct hiprtc_program
}
}
};
};
std
::
vector
<
std
::
vector
<
char
>>
std
::
vector
<
std
::
vector
<
char
>>
compile_hip_src_with_hiprtc
(
std
::
vector
<
hiprtc_src_file
>
srcs
,
compile_hip_src
(
const
std
::
vector
<
src_file
>&
srcs
,
std
::
string
params
,
const
std
::
string
&
arch
)
std
::
string
params
,
const
std
::
string
&
arch
)
{
{
hiprtc_program
prog
(
srcs
);
hiprtc_program
prog
(
std
::
move
(
srcs
)
)
;
auto
options
=
split_string
(
params
,
' '
);
auto
options
=
split_string
(
params
,
' '
);
options
.
push_back
(
"-DMIGRAPHX_USE_HIPRTC=1"
);
options
.
push_back
(
"-DMIGRAPHX_USE_HIPRTC=1"
);
// remove following three compilation flags for HIPRTC once fixes from hipRTC are available in
// remove following three compilation flags for HIPRTC once fixes from hipRTC are available in
...
@@ -187,6 +191,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
...
@@ -187,6 +191,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
options
.
push_back
(
"-DMIGRAPHX_HAS_DPP=0"
);
options
.
push_back
(
"-DMIGRAPHX_HAS_DPP=0"
);
options
.
push_back
(
"-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1"
);
options
.
push_back
(
"-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1"
);
options
.
push_back
(
"-Wno-reserved-identifier"
);
options
.
push_back
(
"-Wno-reserved-identifier"
);
options
.
push_back
(
"-Wno-unused-parameter"
);
options
.
push_back
(
"-Wno-gnu-line-marker"
);
options
.
push_back
(
"-Wno-gnu-line-marker"
);
options
.
push_back
(
"-Wno-old-style-cast"
);
options
.
push_back
(
"-Wno-old-style-cast"
);
}
}
...
@@ -205,8 +210,50 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
...
@@ -205,8 +210,50 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
return
{
prog
.
get_code_obj
()};
return
{
prog
.
get_code_obj
()};
}
}
std
::
vector
<
std
::
vector
<
char
>>
compile_hip_src
(
const
std
::
vector
<
src_file
>&
srcs
,
std
::
string
params
,
const
std
::
string
&
arch
)
{
std
::
vector
<
hiprtc_src_file
>
hsrcs
{
srcs
.
begin
(),
srcs
.
end
()};
if
(
enabled
(
MIGRAPHX_GPU_DUMP_SRC
{}))
{
for
(
const
auto
&
src
:
srcs
)
{
if
(
src
.
path
.
extension
()
!=
".cpp"
)
continue
;
std
::
cout
<<
std
::
string
(
src
.
content
.
first
,
src
.
len
())
<<
std
::
endl
;
}
}
auto
p
=
dynamic_loader
::
path
(
&
compile_hip_src_with_hiprtc
);
auto
driver
=
p
.
parent_path
().
parent_path
()
/
"bin"
/
"migraphx-hiprtc-driver"
;
if
(
fs
::
exists
(
driver
))
{
value
v
;
v
[
"srcs"
]
=
to_value
(
hsrcs
);
v
[
"params"
]
=
to_value
(
params
);
v
[
"arch"
]
=
to_value
(
arch
);
tmp_dir
td
{};
auto
out
=
td
.
path
/
"output"
;
process
(
driver
.
string
()
+
" "
+
out
.
string
()).
write
([
&
](
auto
writer
)
{
to_msgpack
(
v
,
writer
);
});
if
(
fs
::
exists
(
out
))
return
{
read_buffer
(
out
.
string
())};
}
return
compile_hip_src_with_hiprtc
(
std
::
move
(
hsrcs
),
std
::
move
(
params
),
arch
);
}
#else // MIGRAPHX_USE_HIPRTC
#else // MIGRAPHX_USE_HIPRTC
std
::
vector
<
std
::
vector
<
char
>>
compile_hip_src_with_hiprtc
(
std
::
vector
<
hiprtc_src_file
>
,
// NOLINT
std
::
string
,
// NOLINT
const
std
::
string
&
)
{
MIGRAPHX_THROW
(
"Not using hiprtc"
);
}
bool
is_hip_clang_compiler
()
bool
is_hip_clang_compiler
()
{
{
static
const
auto
result
=
ends_with
(
MIGRAPHX_STRINGIZE
(
MIGRAPHX_HIP_COMPILER
),
"clang++"
);
static
const
auto
result
=
ends_with
(
MIGRAPHX_STRINGIZE
(
MIGRAPHX_HIP_COMPILER
),
"clang++"
);
...
...
src/targets/gpu/compile_hip_code_object.cpp
View file @
cd4ab535
...
@@ -135,10 +135,15 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
...
@@ -135,10 +135,15 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
std
::
size_t
max_global
=
ctx
.
get_current_device
().
get_cu_count
()
*
std
::
size_t
max_global
=
ctx
.
get_current_device
().
get_cu_count
()
*
ctx
.
get_current_device
().
get_max_workitems_per_cu
();
ctx
.
get_current_device
().
get_max_workitems_per_cu
();
return
[
n
,
over
,
max_global
](
std
::
size_t
local
)
{
return
[
n
,
over
,
max_global
](
std
::
size_t
local
)
{
std
::
size_t
groups
=
(
n
+
local
-
1
)
/
local
;
std
::
size_t
num_elements
=
n
;
std
::
size_t
max_blocks
=
max_global
/
local
;
std
::
size_t
groups
=
(
num_elements
+
local
-
1
)
/
local
;
std
::
size_t
nglobal
=
std
::
min
(
max_blocks
*
over
,
groups
)
*
local
;
std
::
size_t
max_blocks
=
max_global
/
local
;
return
std
::
min
(
nglobal
,
n
);
std
::
size_t
nglobal
=
std
::
min
(
max_blocks
*
over
,
groups
)
*
local
;
#ifdef MIGRAPHX_USE_HIPRTC
if
(
enabled
(
MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS
{}))
num_elements
=
((
num_elements
+
local
-
1
)
/
local
)
*
local
;
#endif
return
std
::
min
(
nglobal
,
num_elements
);
};
};
}
}
...
@@ -156,7 +161,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
...
@@ -156,7 +161,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
assert
(
not
options
.
inputs
.
empty
());
assert
(
not
options
.
inputs
.
empty
());
assert
(
options
.
inputs
.
size
()
==
options
.
virtual_inputs
.
size
()
or
assert
(
options
.
inputs
.
size
()
==
options
.
virtual_inputs
.
size
()
or
options
.
virtual_inputs
.
empty
());
options
.
virtual_inputs
.
empty
());
std
::
vector
<
src_file
>
srcs
;
std
::
vector
<
src_file
>
srcs
=
options
.
additional_src_files
;
std
::
transform
(
migraphx_kernels
().
begin
(),
std
::
transform
(
migraphx_kernels
().
begin
(),
migraphx_kernels
().
end
(),
migraphx_kernels
().
end
(),
std
::
back_inserter
(
srcs
),
std
::
back_inserter
(
srcs
),
...
...
src/targets/gpu/compile_ops.cpp
View file @
cd4ab535
...
@@ -30,6 +30,7 @@
...
@@ -30,6 +30,7 @@
#include <migraphx/register_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/time_op.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -76,6 +77,109 @@ struct compiled_result
...
@@ -76,6 +77,109 @@ struct compiled_result
instruction_ref
ins
;
instruction_ref
ins
;
};
};
struct
problem_cache
{
bool
has
(
const
std
::
string
&
name
,
const
value
&
problem
)
const
{
return
contains
(
cache
,
create_key
(
name
,
problem
));
}
void
insert
(
const
std
::
string
&
name
,
const
value
&
problem
,
const
value
&
solution
)
{
assert
(
not
solution
.
is_null
());
cache
[
create_key
(
name
,
problem
)]
=
solution
;
}
void
mark
(
const
std
::
string
&
name
,
const
value
&
problem
)
{
cache
.
insert
(
std
::
make_pair
(
create_key
(
name
,
problem
),
value
{}));
}
optional
<
value
>
get
(
const
std
::
string
&
name
,
const
value
&
problem
)
const
{
auto
it
=
cache
.
find
(
create_key
(
name
,
problem
));
if
(
it
==
cache
.
end
())
return
nullopt
;
return
it
->
second
;
}
static
value
create_key
(
const
std
::
string
&
name
,
const
value
&
problem
)
{
return
{{
"name"
,
name
},
{
"problem"
,
problem
}};
}
std
::
unordered_map
<
value
,
value
>
cache
;
};
struct
compile_plan
{
context
*
ctx
;
operation
preop
;
instruction_ref
ins
;
optional
<
tuning_config
>
config
=
nullopt
;
std
::
vector
<
compiled_result
>
results
=
{};
void
update_config
()
{
config
=
get_tuning_config
(
*
ctx
,
ins
,
preop
);
}
template
<
class
Vector
>
void
add_compiles
(
Vector
&
compiles
,
problem_cache
&
pc
)
{
if
(
config
.
has_value
())
{
const
auto
&
problem
=
config
->
problem
;
if
(
auto
sol
=
pc
.
get
(
preop
.
name
(),
problem
))
{
auto
solution
=
sol
.
value
();
// No solution yet until benchmarked so skip for now
if
(
solution
.
is_null
())
return
;
results
.
resize
(
1
);
compiles
.
emplace_back
([
=
]
{
results
[
0
]
=
compiled_result
{
compile
(
*
ctx
,
ins
,
preop
,
solution
),
ins
};
});
}
else
{
pc
.
mark
(
preop
.
name
(),
problem
);
const
auto
&
solutions
=
config
->
solutions
;
results
.
resize
(
solutions
.
size
());
for
(
auto
i
:
range
(
solutions
.
size
()))
{
auto
solution
=
solutions
[
i
];
compiles
.
emplace_back
([
=
]
{
results
[
i
]
=
compiled_result
{
compile
(
*
ctx
,
ins
,
preop
,
solution
),
ins
};
});
}
}
}
else
{
results
.
resize
(
1
);
compiles
.
emplace_back
([
=
]
{
results
[
0
]
=
compiled_result
{
compile
(
*
ctx
,
ins
,
preop
,
value
{}),
ins
};
});
}
}
const
compiled_result
&
benchmark
(
problem_cache
&
pc
)
const
{
if
(
results
.
empty
())
MIGRAPHX_THROW
(
"No configs to tune"
);
if
(
results
.
size
()
==
1
)
return
results
.
front
();
if
(
not
config
)
MIGRAPHX_THROW
(
"Multiple kernels without config"
);
std
::
cout
<<
"Benchmarking "
<<
preop
.
name
()
<<
": "
<<
results
.
size
()
<<
" configs"
<<
std
::
endl
;
std
::
vector
<
double
>
times
;
times
.
reserve
(
results
.
size
());
std
::
transform
(
results
.
begin
(),
results
.
end
(),
std
::
back_inserter
(
times
),
[
&
](
const
auto
&
cr
)
{
return
time_op
(
*
ctx
,
cr
.
replace
.
code_object
,
to_shapes
(
cr
.
ins
->
inputs
()),
20
).
first
;
});
auto
i
=
std
::
distance
(
times
.
begin
(),
std
::
min_element
(
times
.
begin
(),
times
.
end
()));
pc
.
insert
(
preop
.
name
(),
config
->
problem
,
config
->
solutions
.
at
(
i
));
return
results
[
i
];
}
void
replace
(
module
&
m
,
problem_cache
&
pc
)
const
{
const
auto
&
cr
=
benchmark
(
pc
);
cr
.
replace
.
replace
(
m
,
cr
.
ins
);
}
};
template
<
class
F
>
template
<
class
F
>
void
par_compile
(
std
::
size_t
n
,
F
f
)
void
par_compile
(
std
::
size_t
n
,
F
f
)
{
{
...
@@ -84,25 +188,67 @@ void par_compile(std::size_t n, F f)
...
@@ -84,25 +188,67 @@ void par_compile(std::size_t n, F f)
par_for
(
n
,
n
/
value_of
(
MIGRAPHX_GPU_COMPILE_PARALLEL
{},
n
),
f
);
par_for
(
n
,
n
/
value_of
(
MIGRAPHX_GPU_COMPILE_PARALLEL
{},
n
),
f
);
}
}
void
compile_
ops
::
apply
(
module
&
m
)
const
struct
compile_
manager
{
{
std
::
vector
<
std
::
function
<
compiled_result
()
>>
compiles
;
problem_cache
pc
;
std
::
vector
<
compile_plan
>
cps
;
bool
exhaustive
=
false
;
template
<
class
...
Ts
>
void
add_plan
(
Ts
&&
...
xs
)
{
cps
.
push_back
({
std
::
forward
<
Ts
>
(
xs
)...});
}
void
update_configs
()
{
if
(
not
exhaustive
)
return
;
par_compile
(
cps
.
size
(),
[
&
](
auto
i
)
{
cps
[
i
].
update_config
();
});
}
void
compile
(
module
&
m
)
{
std
::
vector
<
std
::
function
<
void
()
>>
compiles
;
for
(
auto
&
cp
:
cps
)
{
cp
.
add_compiles
(
compiles
,
pc
);
}
par_compile
(
compiles
.
size
(),
[
&
](
auto
i
)
{
compiles
[
i
]();
});
// Replace and/or benchmark
for
(
const
auto
&
cp
:
cps
)
{
if
(
cp
.
results
.
empty
())
continue
;
cp
.
replace
(
m
,
pc
);
}
// Remove compile_plan already executed
cps
.
erase
(
std
::
remove_if
(
cps
.
begin
(),
cps
.
end
(),
[](
const
auto
&
cp
)
{
return
not
cp
.
results
.
empty
();
}),
cps
.
end
());
}
};
void
compile_ops
::
apply
(
module
&
m
)
const
{
compile_manager
cm
;
cm
.
exhaustive
=
exhaustive_tune
;
// Find all precompile opes
for
(
auto
ins
:
iterator_for
(
m
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
if
(
ins
->
name
()
!=
"gpu::precompile_op"
)
if
(
ins
->
name
()
!=
"gpu::precompile_op"
)
continue
;
continue
;
operation
preop
=
any_cast
<
precompile_op
>
(
ins
->
get_operator
()).
op
;
operation
preop
=
any_cast
<
precompile_op
>
(
ins
->
get_operator
()).
op
;
compiles
.
emplace_back
([
=
]()
->
compiled_result
{
cm
.
add_plan
(
ctx
,
preop
,
ins
);
return
{
compile
(
*
ctx
,
ins
,
preop
),
ins
};
});
}
std
::
vector
<
compiled_result
>
results
(
compiles
.
size
());
par_compile
(
compiles
.
size
(),
[
&
](
auto
i
)
{
results
[
i
]
=
compiles
[
i
]();
});
for
(
const
auto
&
cr
:
results
)
{
cr
.
replace
(
m
,
cr
.
ins
);
}
}
cm
.
update_configs
();
cm
.
compile
(
m
);
// Compile already tuned configs
cm
.
compile
(
m
);
assert
(
cm
.
cps
.
empty
());
}
}
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/compiler.cpp
View file @
cd4ab535
...
@@ -28,33 +28,44 @@ namespace migraphx {
...
@@ -28,33 +28,44 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
auto
&
compiler_map
()
namespace
{
struct
compiler_handle
{
{
static
std
::
unordered_map
<
std
::
string
,
compiler_compile
>
m
;
// NOLINT
compiler_compile
compile
;
return
m
;
compiler_compile_op
compile_op
;
}
compiler_tuning_config
get_tuning_config
;
};
}
// namespace
auto
&
compiler_
op_
map
()
auto
&
compiler_map
()
{
{
static
std
::
unordered_map
<
std
::
string
,
compiler_
compile_op
>
m
;
// NOLINT
static
std
::
unordered_map
<
std
::
string
,
compiler_
handle
>
m
;
// NOLINT
return
m
;
return
m
;
}
}
void
register_compiler
(
const
std
::
string
&
name
,
compiler_compile
c
,
compiler_compile_op
cop
)
void
register_compiler
(
const
std
::
string
&
name
,
compiler_compile
c
,
compiler_compile_op
cop
,
compiler_tuning_config
ctg
)
{
{
compiler_map
()[
name
]
=
std
::
move
(
c
);
compiler_map
()[
name
]
=
{
std
::
move
(
c
),
std
::
move
(
cop
),
std
::
move
(
ctg
)};
compiler_op_map
()[
name
]
=
std
::
move
(
cop
);
}
}
bool
has_compiler_for
(
const
std
::
string
&
name
)
{
return
compiler_map
().
count
(
name
)
>
0
;
}
bool
has_compiler_for
(
const
std
::
string
&
name
)
{
return
compiler_map
().
count
(
name
)
>
0
;
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
,
const
value
&
solution
)
{
{
return
compiler_map
().
at
(
op
.
name
())(
ctx
,
ins
,
op
);
return
compiler_map
().
at
(
op
.
name
())
.
compile
(
ctx
,
ins
,
op
,
solution
);
}
}
operation
operation
compile_op
(
const
std
::
string
&
name
,
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
compile_op
(
const
std
::
string
&
name
,
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
{
{
return
compiler_op_map
().
at
(
name
)(
ctx
,
inputs
,
v
);
return
compiler_map
().
at
(
name
).
compile_op
(
ctx
,
inputs
,
v
);
}
optional
<
tuning_config
>
get_tuning_config
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
{
return
compiler_map
().
at
(
op
.
name
()).
get_tuning_config
(
ctx
,
ins
,
op
);
}
}
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
View file @
cd4ab535
...
@@ -94,6 +94,10 @@ template <>
...
@@ -94,6 +94,10 @@ template <>
struct
is_hip_type
<
std
::
uint8_t
>
:
std
::
true_type
struct
is_hip_type
<
std
::
uint8_t
>
:
std
::
true_type
{
{
};
};
template
<
>
struct
is_hip_type
<
std
::
int32_t
>
:
std
::
true_type
{
};
template
<
class
T
,
class
V
,
MIGRAPHX_REQUIRES
(
is_hip_type
<
typename
T
::
type
>{})
>
template
<
class
T
,
class
V
,
MIGRAPHX_REQUIRES
(
is_hip_type
<
typename
T
::
type
>{})
>
void
hip_visitor_invoke
(
T
as
,
V
&&
v
)
void
hip_visitor_invoke
(
T
as
,
V
&&
v
)
...
@@ -120,12 +124,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
...
@@ -120,12 +124,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
if
(
not
std
::
all_of
(
if
(
not
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
migraphx
::
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
types
.
begin
(),
types
.
end
(),
[
&
](
migraphx
::
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
MIGRAPHX_THROW
(
"Types must be the same"
);
MIGRAPHX_THROW
(
"Types must be the same"
);
std
::
initializer_list
<
index_int
>
ranks
=
{
std
::
initializer_list
<
index_int
>
ranks
=
{
static_cast
<
index_int
>
(
get_shape
(
xs
).
ndim
())...};
static_cast
<
index_int
>
(
get_shape
(
xs
).
lens
().
size
())...};
if
(
not
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
index_int
r
)
{
return
r
==
s
.
ndim
();
}))
if
(
not
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
index_int
r
)
{
return
r
==
s
.
lens
().
size
();
}))
MIGRAPHX_THROW
(
"Ranks must be the same"
);
MIGRAPHX_THROW
(
"Ranks must be the same"
);
visit_tensor_size
(
s
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
visit_tensor_size
(
s
.
ndim
(),
[
&
](
auto
ndim
)
{
s
.
visit_type
(
hip_visitor
([
&
](
auto
as
)
{
v
(
f
(
xs
,
ndim
,
as
)...);
}));
s
.
visit_type
(
hip_visitor
([
&
](
auto
as
)
{
v
(
f
(
xs
,
ndim
,
as
)...);
}));
});
});
}
}
...
@@ -133,12 +135,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
...
@@ -133,12 +135,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
template
<
class
V
,
class
F
,
class
...
Ts
>
template
<
class
V
,
class
F
,
class
...
Ts
>
void
hip_visit_views_impl
(
const
shape
&
s
,
F
f
,
V
&&
v
,
Ts
&&
...
xs
)
void
hip_visit_views_impl
(
const
shape
&
s
,
F
f
,
V
&&
v
,
Ts
&&
...
xs
)
{
{
std
::
initializer_list
<
index_int
>
ranks
=
{
std
::
initializer_list
<
index_int
>
ranks
=
{
static_cast
<
index_int
>
(
get_shape
(
xs
).
ndim
())...};
static_cast
<
index_int
>
(
get_shape
(
xs
).
lens
().
size
())...};
if
(
not
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
index_int
r
)
{
return
r
==
s
.
ndim
();
}))
if
(
not
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
index_int
r
)
{
return
r
==
s
.
lens
().
size
();
}))
MIGRAPHX_THROW
(
"Ranks must be the same"
);
MIGRAPHX_THROW
(
"Ranks must be the same"
);
visit_tensor_size
(
s
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
v
(
f
(
xs
,
ndim
)...);
});
visit_tensor_size
(
s
.
ndim
(),
[
&
](
auto
ndim
)
{
v
(
f
(
xs
,
ndim
)...);
});
}
}
template
<
class
F
>
template
<
class
F
>
...
...
src/targets/gpu/device/multinomial.cpp
View file @
cd4ab535
...
@@ -67,18 +67,19 @@ void multinomial(hipStream_t stream,
...
@@ -67,18 +67,19 @@ void multinomial(hipStream_t stream,
size_t
class_size
=
arg0
.
get_shape
().
lens
().
back
();
size_t
class_size
=
arg0
.
get_shape
().
lens
().
back
();
size_t
sample_size
=
result
.
get_shape
().
lens
().
back
();
size_t
sample_size
=
result
.
get_shape
().
lens
().
back
();
hip_visit_all
(
arg0
,
arg1
)([
&
](
auto
cdf
,
auto
dist
)
{
visit_all
(
arg0
,
arg1
)([
&
](
auto
cdf_host
,
auto
dist_host
)
{
result
.
visit
([
&
](
auto
out
)
{
result
.
visit
([
&
](
auto
output_host
)
{
hip_visit_views
(
out
)([
&
](
auto
output
)
{
hip_visit_views
(
cdf_host
,
dist_host
,
output_host
)(
gs_launch
(
stream
,
batch_size
*
sample_size
)([
=
](
auto
i
)
__device__
{
[
&
](
auto
cdf
,
auto
dist
,
auto
output
)
{
auto
idx
=
output
.
get_shape
().
multi
(
i
);
gs_launch
(
stream
,
batch_size
*
sample_size
)([
=
](
auto
i
)
__device__
{
auto
cdf_begin
=
cdf
.
begin
()
+
(
idx
.
front
()
*
class_size
);
auto
idx
=
output
.
get_shape
().
multi
(
i
);
auto
cdf_end
=
cdf_begin
+
class_size
;
auto
cdf_begin
=
cdf
.
begin
()
+
(
idx
.
front
()
*
class_size
);
auto
sample_iter
=
auto
cdf_end
=
cdf_begin
+
class_size
;
upper_bound
(
cdf_begin
,
cdf_end
,
dist
[
i
]
*
*
(
std
::
prev
(
cdf_end
)));
auto
*
sample_iter
=
output
[
i
]
=
std
::
distance
(
cdf_begin
,
sample_iter
);
upper_bound
(
cdf_begin
,
cdf_end
,
dist
[
i
]
*
*
(
std
::
prev
(
cdf_end
)));
output
[
i
]
=
std
::
distance
(
cdf_begin
,
sample_iter
);
});
});
});
});
});
});
});
});
}
}
...
...
src/targets/gpu/device/scatter.cpp
View file @
cd4ab535
...
@@ -37,22 +37,26 @@ argument scatter(
...
@@ -37,22 +37,26 @@ argument scatter(
hipStream_t
stream
,
argument
result
,
argument
arg0
,
argument
arg1
,
argument
arg2
,
int64_t
axis
)
hipStream_t
stream
,
argument
result
,
argument
arg0
,
argument
arg1
,
argument
arg2
,
int64_t
axis
)
{
{
auto
ds
=
arg0
.
get_shape
();
auto
ds
=
arg0
.
get_shape
();
auto
inds
=
arg1
.
get_shape
();
auto
s1
=
arg1
.
get_shape
();
auto
axis_dim_size
=
ds
.
lens
()[
axis
];
auto
axis_dim_size
=
ds
.
lens
()[
axis
];
hip_visit_all
(
result
,
arg0
,
inds
)([
&
](
auto
output
,
auto
data
,
auto
s1
)
{
hip_visit_all
(
result
,
arg0
,
arg2
)([
&
](
auto
output
,
auto
data
,
auto
update
)
{
auto
*
output_ptr
=
device_cast
(
output
.
data
());
auto
*
output_ptr
=
device_cast
(
output
.
data
());
const
auto
*
data_ptr
=
device_cast
(
data
.
data
());
const
auto
*
data_ptr
=
device_cast
(
data
.
data
());
gs_launch
(
stream
,
ds
.
elements
())([
=
](
auto
i
)
__device__
{
output_ptr
[
i
]
=
data_ptr
[
i
];
});
gs_launch
(
stream
,
ds
.
elements
())([
=
](
auto
i
)
__device__
{
output_ptr
[
i
]
=
data_ptr
[
i
];
});
hip_visit_all
(
arg1
,
arg2
)([
&
](
auto
indices
,
auto
update
)
{
const
auto
*
upd_ptr
=
device_cast
(
update
.
data
());
hip_visit_all
(
arg1
)([
&
](
auto
indices
)
{
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
if
constexpr
(
indices
.
get_shape
().
lens
.
size
()
==
output
.
get_shape
().
lens
.
size
())
gs_launch
(
stream
,
inds
.
elements
())([
=
](
auto
i
)
__device__
{
{
auto
out_idx
=
s1
.
multi
(
i
);
const
auto
*
upd_ptr
=
device_cast
(
update
.
data
());
auto
index
=
indices_ptr
[
i
];
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
index
=
index
<
0
?
index
+
axis_dim_size
:
index
;
gs_launch
(
stream
,
s1
.
elements
())([
=
](
auto
i
)
__device__
{
out_idx
[
axis
]
=
index
;
auto
out_idx
=
indices
.
get_shape
().
multi
(
i
);
output
[
out_idx
]
=
upd_ptr
[
i
];
auto
index
=
indices_ptr
[
i
];
});
index
=
index
<
0
?
index
+
axis_dim_size
:
index
;
out_idx
[
axis
]
=
index
;
output
[
out_idx
]
=
upd_ptr
[
i
];
});
}
});
});
});
});
...
...
src/targets/gpu/device_name.cpp
View file @
cd4ab535
...
@@ -43,6 +43,8 @@ auto get_arch_name(rank<1>, const HipDeviceProp& props) -> decltype(std::string(
...
@@ -43,6 +43,8 @@ auto get_arch_name(rank<1>, const HipDeviceProp& props) -> decltype(std::string(
return
std
::
string
(
props
.
gcnArchName
);
return
std
::
string
(
props
.
gcnArchName
);
}
}
std
::
string
get_arch_name
(
const
hipDeviceProp_t
&
props
)
{
return
get_arch_name
(
rank
<
1
>
{},
props
);
}
int
get_device_id
()
int
get_device_id
()
{
{
int
device
;
int
device
;
...
@@ -58,7 +60,7 @@ std::string get_device_name()
...
@@ -58,7 +60,7 @@ std::string get_device_name()
auto
status
=
hipGetDeviceProperties
(
&
props
,
get_device_id
());
auto
status
=
hipGetDeviceProperties
(
&
props
,
get_device_id
());
if
(
status
!=
hipSuccess
)
if
(
status
!=
hipSuccess
)
MIGRAPHX_THROW
(
"Failed to get device properties"
);
MIGRAPHX_THROW
(
"Failed to get device properties"
);
return
get_arch_name
(
rank
<
1
>
{},
props
);
return
get_arch_name
(
props
);
}
}
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/driver/CMakeLists.txt
View file @
cd4ab535
...
@@ -26,5 +26,6 @@ file(GLOB GPU_DRIVER_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp
...
@@ -26,5 +26,6 @@ file(GLOB GPU_DRIVER_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp
add_executable
(
gpu-driver
add_executable
(
gpu-driver
${
GPU_DRIVER_SRCS
}
${
GPU_DRIVER_SRCS
}
)
)
rocm_clang_tidy_check
(
gpu-driver
)
target_include_directories
(
gpu-driver PRIVATE include
)
target_include_directories
(
gpu-driver PRIVATE include
)
target_link_libraries
(
gpu-driver PRIVATE migraphx_gpu
)
target_link_libraries
(
gpu-driver PRIVATE migraphx_gpu
)
Prev
1
…
4
5
6
7
8
9
10
11
12
…
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment