Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
25e8cf0b
Unverified
Commit
25e8cf0b
authored
Jan 27, 2023
by
Ted Themistokleous
Committed by
GitHub
Jan 27, 2023
Browse files
Merge branch 'develop' into test_onnx_zoo
parents
a313a68e
635502be
Changes
138
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
370 additions
and
85 deletions
+370
-85
src/program.cpp
src/program.cpp
+19
-0
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+19
-19
src/shape.cpp
src/shape.cpp
+50
-0
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+14
-2
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+11
-4
src/targets/gpu/compile_hip.cpp
src/targets/gpu/compile_hip.cpp
+2
-2
src/targets/gpu/hip.cpp
src/targets/gpu/hip.cpp
+15
-6
src/targets/gpu/include/migraphx/gpu/hip.hpp
src/targets/gpu/include/migraphx/gpu/hip.hpp
+3
-3
src/targets/gpu/jit/gather.cpp
src/targets/gpu/jit/gather.cpp
+89
-0
src/targets/gpu/jit/mlir.cpp
src/targets/gpu/jit/mlir.cpp
+0
-1
src/targets/gpu/jit/reduce.cpp
src/targets/gpu/jit/reduce.cpp
+13
-4
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
...rgets/gpu/kernels/include/migraphx/kernels/functional.hpp
+8
-0
src/targets/gpu/kernels/include/migraphx/kernels/gather.hpp
src/targets/gpu/kernels/include/migraphx/kernels/gather.hpp
+64
-0
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
...argets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
+1
-0
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
+19
-11
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp
+10
-0
src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
...argets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
+0
-32
src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp
src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp
+1
-0
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
+32
-0
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+0
-1
No files found.
src/program.cpp
View file @
25e8cf0b
...
@@ -854,6 +854,25 @@ void program::print_graph(std::ostream& os, bool brief) const
...
@@ -854,6 +854,25 @@ void program::print_graph(std::ostream& os, bool brief) const
mm
->
print_graph
(
os
,
brief
);
mm
->
print_graph
(
os
,
brief
);
}
}
void
program
::
print_py
(
std
::
ostream
&
os
)
const
{
auto
vec_modules
=
this
->
get_modules
();
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
;
os
<<
"p = migraphx.program()
\n
"
;
for
(
auto
&
mod
:
vec_modules
)
{
std
::
string
var_name
=
"m"
+
mod
->
name
();
os
<<
var_name
<<
" = "
;
if
(
mod
->
name
()
==
"main"
)
os
<<
"p.get_main_module()"
;
else
os
<<
"p.create_module(
\"
"
<<
mod
->
name
()
<<
"
\"
);"
;
os
<<
std
::
endl
;
names
=
mod
->
print_py
(
os
,
var_name
,
names
);
os
<<
std
::
endl
;
}
}
void
program
::
print_cpp
(
std
::
ostream
&
os
)
const
void
program
::
print_cpp
(
std
::
ostream
&
os
)
const
{
{
auto
vec_modules
=
this
->
get_modules
();
auto
vec_modules
=
this
->
get_modules
();
...
...
src/rewrite_rnn.cpp
View file @
25e8cf0b
...
@@ -92,7 +92,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
...
@@ -92,7 +92,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
// process sequence length
// process sequence length
instruction_ref
seq_lens
=
m
.
end
();
instruction_ref
seq_lens
=
m
.
end
();
if
((
args
.
size
()
>=
5
)
&&
args
[
4
]
->
name
()
!=
"
undefined
"
)
if
((
args
.
size
()
>=
5
)
and
not
args
[
4
]
->
is_
undefined
()
)
{
{
seq_lens
=
args
[
4
];
seq_lens
=
args
[
4
];
}
}
...
@@ -117,7 +117,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
...
@@ -117,7 +117,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
// process bias
// process bias
instruction_ref
bias_forward
=
m
.
end
();
instruction_ref
bias_forward
=
m
.
end
();
instruction_ref
bias_reverse
=
m
.
end
();
instruction_ref
bias_reverse
=
m
.
end
();
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
>=
4
and
not
args
[
3
]
->
is_
undefined
()
)
{
{
bias_forward
=
m
.
insert_instruction
(
bias_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
3
]);
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
3
]);
...
@@ -129,7 +129,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
...
@@ -129,7 +129,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
// or the 5th one (if the sequence len argument is ignored)
// or the 5th one (if the sequence len argument is ignored)
instruction_ref
ih_forward
{};
instruction_ref
ih_forward
{};
instruction_ref
ih_reverse
{};
instruction_ref
ih_reverse
{};
if
(
args
.
size
()
==
6
&&
args
[
5
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
==
6
and
not
args
[
5
]
->
is_
undefined
()
)
{
{
ih_forward
=
m
.
insert_instruction
(
ih_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
5
]);
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
5
]);
...
@@ -195,14 +195,14 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
...
@@ -195,14 +195,14 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
// process bias and initial hidden state
// process bias and initial hidden state
instruction_ref
bias
=
m
.
end
();
instruction_ref
bias
=
m
.
end
();
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
>=
4
and
not
args
[
3
]
->
is_
undefined
()
)
{
{
bias
=
args
[
3
];
bias
=
args
[
3
];
}
}
// process intial hidden state
// process intial hidden state
instruction_ref
ih
;
instruction_ref
ih
;
if
(
args
.
size
()
==
6
&&
args
[
5
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
==
6
and
not
args
[
5
]
->
is_
undefined
()
)
{
{
ih
=
args
[
5
];
ih
=
args
[
5
];
}
}
...
@@ -398,7 +398,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
...
@@ -398,7 +398,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
// process sequence length
// process sequence length
instruction_ref
seq_lens
=
m
.
end
();
instruction_ref
seq_lens
=
m
.
end
();
if
((
args
.
size
()
>=
5
)
&&
args
[
4
]
->
name
()
!=
"
undefined
"
)
if
((
args
.
size
()
>=
5
)
and
not
args
[
4
]
->
is_
undefined
()
)
{
{
seq_lens
=
args
[
4
];
seq_lens
=
args
[
4
];
}
}
...
@@ -423,7 +423,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
...
@@ -423,7 +423,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
// bias
// bias
instruction_ref
bias_forward
=
m
.
end
();
instruction_ref
bias_forward
=
m
.
end
();
instruction_ref
bias_reverse
=
m
.
end
();
instruction_ref
bias_reverse
=
m
.
end
();
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
>=
4
and
not
args
[
3
]
->
is_
undefined
()
)
{
{
bias_forward
=
m
.
insert_instruction
(
bias_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
3
]);
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
3
]);
...
@@ -434,7 +434,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
...
@@ -434,7 +434,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
// intial hidden state
// intial hidden state
instruction_ref
ih_forward
{};
instruction_ref
ih_forward
{};
instruction_ref
ih_reverse
{};
instruction_ref
ih_reverse
{};
if
(
args
.
size
()
==
6
&&
args
[
5
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
==
6
and
not
args
[
5
]
->
is_
undefined
()
)
{
{
ih_forward
=
m
.
insert_instruction
(
ih_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
5
]);
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
5
]);
...
@@ -501,14 +501,14 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
...
@@ -501,14 +501,14 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
// bias
// bias
instruction_ref
bias
=
m
.
end
();
instruction_ref
bias
=
m
.
end
();
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
>=
4
and
not
args
[
3
]
->
is_
undefined
()
)
{
{
bias
=
args
[
3
];
bias
=
args
[
3
];
}
}
// intial hidden state
// intial hidden state
instruction_ref
ih
{};
instruction_ref
ih
{};
if
(
args
.
size
()
==
6
&&
args
[
5
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
==
6
and
not
args
[
5
]
->
is_
undefined
()
)
{
{
ih
=
args
[
5
];
ih
=
args
[
5
];
}
}
...
@@ -784,7 +784,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
...
@@ -784,7 +784,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// process sequence length
// process sequence length
instruction_ref
seq_lens
=
m
.
end
();
instruction_ref
seq_lens
=
m
.
end
();
if
((
args
.
size
()
>=
5
)
&&
args
[
4
]
->
name
()
!=
"
undefined
"
)
if
((
args
.
size
()
>=
5
)
and
not
args
[
4
]
->
is_
undefined
()
)
{
{
seq_lens
=
args
[
4
];
seq_lens
=
args
[
4
];
}
}
...
@@ -813,7 +813,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
...
@@ -813,7 +813,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// process bias
// process bias
instruction_ref
bias_forward
=
m
.
end
();
instruction_ref
bias_forward
=
m
.
end
();
instruction_ref
bias_reverse
=
m
.
end
();
instruction_ref
bias_reverse
=
m
.
end
();
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
>=
4
and
not
args
[
3
]
->
is_
undefined
()
)
{
{
bias_forward
=
m
.
insert_instruction
(
bias_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
3
]);
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
3
]);
...
@@ -824,7 +824,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
...
@@ -824,7 +824,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// process intial hidden state, it is the 6th argument
// process intial hidden state, it is the 6th argument
instruction_ref
ih_forward
{};
instruction_ref
ih_forward
{};
instruction_ref
ih_reverse
{};
instruction_ref
ih_reverse
{};
if
(
args
.
size
()
>=
6
&&
args
[
5
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
>=
6
and
not
args
[
5
]
->
is_
undefined
()
)
{
{
ih_forward
=
m
.
insert_instruction
(
ih_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
5
]);
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
5
]);
...
@@ -840,7 +840,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
...
@@ -840,7 +840,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// process initial cell value
// process initial cell value
instruction_ref
ic_forward
{};
instruction_ref
ic_forward
{};
instruction_ref
ic_reverse
{};
instruction_ref
ic_reverse
{};
if
(
args
.
size
()
>=
7
&&
args
[
6
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
>=
7
and
not
args
[
6
]
->
is_
undefined
()
)
{
{
ic_forward
=
m
.
insert_instruction
(
ic_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
6
]);
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
6
]);
...
@@ -856,7 +856,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
...
@@ -856,7 +856,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// process weight of the peephole
// process weight of the peephole
instruction_ref
pph_forward
=
m
.
end
();
instruction_ref
pph_forward
=
m
.
end
();
instruction_ref
pph_reverse
=
m
.
end
();
instruction_ref
pph_reverse
=
m
.
end
();
if
(
args
.
size
()
==
8
&&
args
[
7
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
==
8
and
not
args
[
7
]
->
is_
undefined
()
)
{
{
pph_forward
=
m
.
insert_instruction
(
pph_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
7
]);
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
7
]);
...
@@ -940,14 +940,14 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
...
@@ -940,14 +940,14 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// bias
// bias
instruction_ref
bias
=
m
.
end
();
instruction_ref
bias
=
m
.
end
();
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
>=
4
and
not
args
[
3
]
->
is_
undefined
()
)
{
{
bias
=
args
[
3
];
bias
=
args
[
3
];
}
}
// initial hidden state
// initial hidden state
instruction_ref
ih
{};
instruction_ref
ih
{};
if
(
args
.
size
()
>=
6
&&
args
[
5
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
>=
6
and
not
args
[
5
]
->
is_
undefined
()
)
{
{
ih
=
args
[
5
];
ih
=
args
[
5
];
}
}
...
@@ -958,7 +958,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
...
@@ -958,7 +958,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// initial cell value
// initial cell value
instruction_ref
ic
{};
instruction_ref
ic
{};
if
(
args
.
size
()
>=
7
&&
args
[
6
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
>=
7
and
not
args
[
6
]
->
is_
undefined
()
)
{
{
ic
=
args
[
6
];
ic
=
args
[
6
];
}
}
...
@@ -969,7 +969,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
...
@@ -969,7 +969,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// process weight of the peephole
// process weight of the peephole
instruction_ref
pph
=
m
.
end
();
instruction_ref
pph
=
m
.
end
();
if
(
args
.
size
()
==
8
&&
args
[
7
]
->
name
()
!=
"
undefined
"
)
if
(
args
.
size
()
==
8
and
not
args
[
7
]
->
is_
undefined
()
)
{
{
pph
=
args
[
7
];
pph
=
args
[
7
];
}
}
...
...
src/shape.cpp
View file @
25e8cf0b
...
@@ -504,6 +504,31 @@ bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max;
...
@@ -504,6 +504,31 @@ bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max;
bool
shape
::
dynamic_dimension
::
has_optimal
()
const
{
return
opt
!=
0
;
}
bool
shape
::
dynamic_dimension
::
has_optimal
()
const
{
return
opt
!=
0
;
}
shape
::
dynamic_dimension
&
shape
::
dynamic_dimension
::
operator
+=
(
const
std
::
size_t
&
x
)
{
this
->
min
+=
x
;
this
->
max
+=
x
;
if
(
this
->
opt
!=
0
)
{
this
->
opt
+=
x
;
};
return
*
this
;
}
shape
::
dynamic_dimension
&
shape
::
dynamic_dimension
::
operator
-=
(
const
std
::
size_t
&
x
)
{
assert
(
this
->
min
>=
x
);
assert
(
this
->
max
>=
x
);
this
->
min
-=
x
;
this
->
max
-=
x
;
if
(
this
->
opt
!=
0
)
{
assert
(
this
->
opt
>=
x
);
this
->
opt
-=
x
;
}
return
*
this
;
}
bool
operator
==
(
const
shape
::
dynamic_dimension
&
x
,
const
shape
::
dynamic_dimension
&
y
)
bool
operator
==
(
const
shape
::
dynamic_dimension
&
x
,
const
shape
::
dynamic_dimension
&
y
)
{
{
// don't check opt if both are fixed
// don't check opt if both are fixed
...
@@ -521,6 +546,31 @@ std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x)
...
@@ -521,6 +546,31 @@ std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x)
return
os
;
return
os
;
}
}
bool
operator
==
(
const
shape
::
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
)
{
return
x
.
min
==
y
and
x
.
max
==
y
;
}
bool
operator
==
(
const
std
::
size_t
&
x
,
const
shape
::
dynamic_dimension
&
y
)
{
return
y
==
x
;
}
bool
operator
!=
(
const
shape
::
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
)
{
return
not
(
x
==
y
);
}
bool
operator
!=
(
const
std
::
size_t
&
x
,
const
shape
::
dynamic_dimension
&
y
)
{
return
not
(
x
==
y
);
}
shape
::
dynamic_dimension
operator
+
(
const
shape
::
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
)
{
auto
dd
=
x
;
return
dd
+=
y
;
}
shape
::
dynamic_dimension
operator
+
(
const
std
::
size_t
&
x
,
const
shape
::
dynamic_dimension
&
y
)
{
return
y
+
x
;
}
shape
::
dynamic_dimension
operator
-
(
const
shape
::
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
)
{
auto
dd
=
x
;
return
dd
-=
y
;
}
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
)
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
)
{
{
if
(
x
.
dynamic
()
and
y
.
dynamic
())
if
(
x
.
dynamic
()
and
y
.
dynamic
())
...
...
src/simplify_algebra.cpp
View file @
25e8cf0b
...
@@ -1065,11 +1065,23 @@ struct find_split_reshape
...
@@ -1065,11 +1065,23 @@ struct find_split_reshape
return
;
return
;
}
}
// Only want to apply this optimization if each split output is followed by
// a contiguous op and a reshape
if
(
std
::
any_of
(
split_outputs
.
begin
(),
split_outputs
.
end
(),
[](
auto
i
)
{
if
(
i
->
outputs
().
size
()
==
1
)
{
auto
cont
=
i
->
outputs
().
front
();
return
cont
->
outputs
().
size
()
!=
1
;
}
return
false
;
}))
{
return
;
}
std
::
vector
<
instruction_ref
>
vec_rsp
(
split_outputs
.
size
());
std
::
vector
<
instruction_ref
>
vec_rsp
(
split_outputs
.
size
());
std
::
transform
(
split_outputs
.
begin
(),
split_outputs
.
end
(),
vec_rsp
.
begin
(),
[](
auto
i
)
{
std
::
transform
(
split_outputs
.
begin
(),
split_outputs
.
end
(),
vec_rsp
.
begin
(),
[](
auto
i
)
{
assert
(
i
->
outputs
().
size
()
==
1
);
auto
cont
=
i
->
outputs
().
front
();
auto
cont
=
i
->
outputs
().
front
();
assert
(
cont
->
outputs
().
size
()
==
1
);
return
cont
->
outputs
().
front
();
return
cont
->
outputs
().
front
();
});
});
...
...
src/simplify_reshapes.cpp
View file @
25e8cf0b
...
@@ -763,16 +763,23 @@ struct find_transpose_slice
...
@@ -763,16 +763,23 @@ struct find_transpose_slice
// Compute axis before transpose to use for unsqueeze
// Compute axis before transpose to use for unsqueeze
auto
perm
=
ins
->
get_operator
().
to_value
()[
"permutation"
].
to_vector
<
int64_t
>
();
auto
perm
=
ins
->
get_operator
().
to_value
()[
"permutation"
].
to_vector
<
int64_t
>
();
auto
preaxis
=
std
::
find
(
perm
.
begin
(),
perm
.
end
(),
axis
)
-
perm
.
begin
();
auto
preaxis
=
std
::
find
(
perm
.
begin
(),
perm
.
end
(),
axis
)
-
perm
.
begin
();
// Make unsqeeze
// Make unsqueeze
std
::
vector
<
int64_t
>
steps
(
sdistance
.
size
());
std
::
transform
(
slice
.
axes
.
begin
(),
slice
.
axes
.
end
(),
sdistance
.
begin
(),
steps
.
begin
(),
[
&
](
const
auto
ax
,
const
auto
sdis
)
{
return
ins
->
get_shape
().
lens
().
at
(
ax
)
/
sdis
;
});
auto
unsqueeze
=
m
.
insert_instruction
(
auto
unsqueeze
=
m
.
insert_instruction
(
ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
preaxis
}},
{
"steps"
,
s
distance
}}),
ins
->
inputs
());
ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
preaxis
}},
{
"steps"
,
s
teps
}}),
ins
->
inputs
());
// Make transpose
// Make transpose
std
::
transform
(
perm
.
begin
(),
perm
.
end
(),
perm
.
begin
(),
[
&
](
auto
i
)
{
std
::
transform
(
perm
.
begin
(),
perm
.
end
(),
perm
.
begin
(),
[
&
](
auto
i
)
{
if
(
i
>
preaxis
)
if
(
i
>
=
preaxis
)
return
i
+
1
;
return
i
+
1
;
return
i
;
return
i
;
});
});
perm
.
insert
(
perm
.
begin
(),
preaxis
+
1
);
perm
.
insert
(
perm
.
begin
(),
preaxis
);
auto
transpose
=
auto
transpose
=
m
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
unsqueeze
);
m
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
unsqueeze
);
// Slice and squeeze
// Slice and squeeze
...
...
src/targets/gpu/compile_hip.cpp
View file @
25e8cf0b
...
@@ -185,7 +185,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
...
@@ -185,7 +185,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
options
.
push_back
(
"-fno-gpu-rdc"
);
options
.
push_back
(
"-fno-gpu-rdc"
);
options
.
push_back
(
" -O"
+
string_value_of
(
MIGRAPHX_GPU_OPTIMIZE
{},
"3"
));
options
.
push_back
(
" -O"
+
string_value_of
(
MIGRAPHX_GPU_OPTIMIZE
{},
"3"
));
options
.
push_back
(
"-Wno-cuda-compat"
);
options
.
push_back
(
"-Wno-cuda-compat"
);
options
.
push_back
(
"--
cuda-gpu
-arch="
+
arch
);
options
.
push_back
(
"--
offload
-arch="
+
arch
);
prog
.
compile
(
options
);
prog
.
compile
(
options
);
return
{
prog
.
get_code_obj
()};
return
{
prog
.
get_code_obj
()};
}
}
...
@@ -237,7 +237,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
...
@@ -237,7 +237,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
}
}
else
if
(
is_hip_clang_compiler
())
else
if
(
is_hip_clang_compiler
())
{
{
params
+=
" --
cuda-gpu
-arch="
+
arch
;
params
+=
" --
offload
-arch="
+
arch
;
params
+=
" --cuda-device-only"
;
params
+=
" --cuda-device-only"
;
params
+=
" -O"
+
string_value_of
(
MIGRAPHX_GPU_OPTIMIZE
{},
"3"
)
+
" "
;
params
+=
" -O"
+
string_value_of
(
MIGRAPHX_GPU_OPTIMIZE
{},
"3"
)
+
" "
;
}
}
...
...
src/targets/gpu/hip.cpp
View file @
25e8cf0b
...
@@ -196,12 +196,21 @@ argument to_gpu(const argument& arg, bool host)
...
@@ -196,12 +196,21 @@ argument to_gpu(const argument& arg, bool host)
argument
from_gpu
(
const
argument
&
arg
)
argument
from_gpu
(
const
argument
&
arg
)
{
{
argument
result
;
argument
result
;
arg
.
visit
([
&
](
auto
x
)
{
arg
.
visit
(
using
type
=
typename
decltype
(
x
)
::
value_type
;
[
&
](
auto
x
)
{
auto
v
=
read_from_gpu
<
type
>
(
arg
.
data
(),
x
.
get_shape
().
bytes
()
/
sizeof
(
type
));
using
type
=
typename
decltype
(
x
)
::
value_type
;
// cppcheck-suppress returnDanglingLifetime
auto
v
=
read_from_gpu
<
type
>
(
arg
.
data
(),
x
.
get_shape
().
bytes
()
/
sizeof
(
type
));
result
=
{
x
.
get_shape
(),
[
v
]()
mutable
{
return
v
.
data
();
}};
// cppcheck-suppress returnDanglingLifetime
});
result
=
{
x
.
get_shape
(),
[
v
]()
mutable
{
return
v
.
data
();
}};
},
[
&
](
const
auto
&
xs
)
{
std
::
vector
<
argument
>
args
;
std
::
transform
(
xs
.
begin
(),
xs
.
end
(),
std
::
back_inserter
(
args
),
[
&
](
auto
x
)
{
return
from_gpu
(
x
);
});
result
=
argument
{
args
};
});
return
result
;
return
result
;
}
}
...
...
src/targets/gpu/include/migraphx/gpu/hip.hpp
View file @
25e8cf0b
...
@@ -105,7 +105,7 @@ struct hip_copy_to_gpu
...
@@ -105,7 +105,7 @@ struct hip_copy_to_gpu
std
::
string
name
()
const
{
return
"hip::copy_to_gpu"
;
}
std
::
string
name
()
const
{
return
"hip::copy_to_gpu"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
,
2
);
check_shapes
{
inputs
,
*
this
}.
has
(
1
,
2
)
.
same_type
()
;
return
inputs
.
at
(
0
);
return
inputs
.
at
(
0
);
}
}
argument
compute
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
argument
compute
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
...
@@ -131,7 +131,7 @@ struct hip_copy_from_gpu
...
@@ -131,7 +131,7 @@ struct hip_copy_from_gpu
std
::
string
name
()
const
{
return
"hip::copy_from_gpu"
;
}
std
::
string
name
()
const
{
return
"hip::copy_from_gpu"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
,
2
);
check_shapes
{
inputs
,
*
this
}.
has
(
1
,
2
)
.
same_type
()
;
return
inputs
.
at
(
0
);
return
inputs
.
at
(
0
);
}
}
argument
argument
...
@@ -159,7 +159,7 @@ struct hip_copy
...
@@ -159,7 +159,7 @@ struct hip_copy
std
::
string
name
()
const
{
return
"hip::copy"
;
}
std
::
string
name
()
const
{
return
"hip::copy"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
check_shapes
{
inputs
,
*
this
}.
has
(
2
)
.
same_type
()
;
return
inputs
.
at
(
1
);
return
inputs
.
at
(
1
);
}
}
argument
compute
(
context
&
ctx
,
const
shape
&
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
ctx
,
const
shape
&
,
std
::
vector
<
argument
>
args
)
const
...
...
src/targets/gpu/jit/gather.cpp
0 → 100644
View file @
25e8cf0b
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
// NOLINTNEXTLINE
static
const
char
*
const
gather_kernel
=
R"__migraphx__(
#include <migraphx/kernels/gather.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void gather_kernel(void* in_data, void* in_indices, void* output)
{
make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
gather<${axis}>(xs...);
});
}
}
} // namespace migraphx
)__migraphx__"
;
struct
gather_compiler
:
compiler
<
gather_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"gather"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
hip_compile_options
options
;
const
auto
&
out_s
=
inputs
.
back
();
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
out_s
.
elements
()));
options
.
inputs
=
inputs
;
options
.
output
=
out_s
;
options
.
kernel_name
=
"gather_kernel"
;
options
.
virtual_inputs
=
inputs
;
auto
axis
=
v
.
at
(
"axis"
).
to
<
std
::
string
>
();
auto
src
=
interpolate_string
(
gather_kernel
,
{{
"axis"
,
axis
}});
return
compile_hip_code_object
(
src
,
options
);
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
op
.
to_value
()));
}
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/jit/mlir.cpp
View file @
25e8cf0b
...
@@ -24,7 +24,6 @@
...
@@ -24,7 +24,6 @@
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/mlir.hpp>
#include <migraphx/gpu/mlir.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
...
src/targets/gpu/jit/reduce.cpp
View file @
25e8cf0b
...
@@ -156,16 +156,25 @@ struct reduce_compiler : compiler<reduce_compiler>
...
@@ -156,16 +156,25 @@ struct reduce_compiler : compiler<reduce_compiler>
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
{
value
v
=
value
::
object
{};
value
v
=
value
::
object
{};
auto
reduce_elements
=
get_reduce_elements
(
ins
->
inputs
());
if
(
op
.
name
()
==
"reduce_sum"
)
if
(
op
.
name
()
==
"reduce_sum"
)
{
{
v
[
"reduction"
]
=
"op::sum{}"
;
v
[
"reduction"
]
=
"op::sum{}"
;
}
}
else
if
(
op
.
name
()
==
"reduce_mean"
)
else
if
(
op
.
name
()
==
"reduce_mean"
)
{
{
v
[
"reduction"
]
=
"op::sum{}"
;
auto
reduce_elements
=
get_reduce_elements
(
ins
->
inputs
());
v
[
"write"
]
=
"op::mean{"
+
std
::
to_string
(
reduce_elements
)
+
"}"
;
auto
reduce_type
=
ins
->
inputs
().
front
()
->
get_shape
().
type
();
v
[
"reduction"
]
=
"op::sum{}"
;
std
::
string
mean
=
"op::mean{"
+
std
::
to_string
(
reduce_elements
)
+
"}"
;
// Use float accumulator when reduction size is too large for half
if
(
reduce_type
==
shape
::
half_type
and
reduce_elements
>
16384
)
v
[
"read"
]
=
"compose("
+
mean
+
", op::convert_to<float>{})"
;
else
if
(
contains
({
shape
::
float_type
,
shape
::
half_type
,
shape
::
double_type
},
reduce_type
))
v
[
"read"
]
=
mean
;
else
v
[
"write"
]
=
mean
;
}
}
else
if
(
op
.
name
()
==
"reduce_max"
)
else
if
(
op
.
name
()
==
"reduce_max"
)
{
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
View file @
25e8cf0b
...
@@ -187,6 +187,14 @@ constexpr auto fold(F f)
...
@@ -187,6 +187,14 @@ constexpr auto fold(F f)
return
[
=
](
auto
&&
...
xs
)
{
return
fold_impl
(
f
,
static_cast
<
decltype
(
xs
)
&&>
(
xs
)...);
};
return
[
=
](
auto
&&
...
xs
)
{
return
fold_impl
(
f
,
static_cast
<
decltype
(
xs
)
&&>
(
xs
)...);
};
}
}
template
<
class
...
Fs
>
constexpr
auto
compose
(
Fs
...
fs
)
{
return
fold
([](
auto
f
,
auto
g
)
{
return
[
=
](
auto
&&
...
xs
)
{
return
f
(
g
(
static_cast
<
decltype
(
xs
)
>
(
xs
)...));
};
})(
fs
...);
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
constexpr
auto
pack
(
Ts
...
xs
)
constexpr
auto
pack
(
Ts
...
xs
)
{
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/gather.hpp
0 → 100644
View file @
25e8cf0b
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_GATHER_HPP
#define MIGRAPHX_GUARD_KERNELS_GATHER_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/shape.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/tensor_view.hpp>
namespace
migraphx
{
template
<
int
Axis
,
class
Input
,
class
Indices
>
constexpr
auto
gather_shape
(
Input
input
,
Indices
indices
)
{
auto
lengths
=
input
.
lens
;
lengths
[
Axis
]
=
indices
.
elements
();
return
make_shape
(
lengths
,
input
.
strides
);
}
template
<
int
Axis
,
class
Input
,
class
Indices
,
class
Output
>
__device__
void
gather
(
Input
input
,
Indices
indices
,
Output
output
)
{
auto
ind
=
make_index
();
auto
axis_dim_size
=
input
.
get_shape
().
lens
[
Axis
];
constexpr
auto
out_comp
=
gather_shape
<
Axis
>
(
get_shape_c
<
Input
>
{},
get_shape_c
<
Indices
>
{});
ind
.
global_stride
(
output
.
get_shape
().
elements
(),
[
&
](
auto
i
)
{
auto
idx
=
out_comp
.
multi
(
i
);
auto
in_index
=
indices
[
idx
[
Axis
]];
auto
new_in_index
=
(
in_index
<
0
)
?
in_index
+
axis_dim_size
:
in_index
;
idx
[
Axis
]
=
new_in_index
;
output
[
i
]
=
input
[
idx
];
});
}
}
// namespace migraphx
#endif
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
View file @
25e8cf0b
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#define MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/vec.hpp>
#include <migraphx/kernels/print.hpp>
#include <migraphx/kernels/print.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
View file @
25e8cf0b
...
@@ -132,9 +132,14 @@ MIGRAPHX_DEVICE_MATH_FOR(float, fmod, ::fmodf)
...
@@ -132,9 +132,14 @@ MIGRAPHX_DEVICE_MATH_FOR(float, fmod, ::fmodf)
// Builtin half functions
// Builtin half functions
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
abs
,
::
__habs
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
abs
,
::
__habs
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
ceil
,
::
hceil
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
cos
,
::
hcos
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
exp
,
::
hexp
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
exp
,
::
hexp
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
floor
,
::
hfloor
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
isnan
,
::
__hisnan
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
log
,
::
hlog
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
log
,
::
hlog
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
rsqrt
,
::
hrsqrt
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
rsqrt
,
::
hrsqrt
)
// MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sin, ::hsin)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
sqrt
,
::
hsqrt
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
sqrt
,
::
hsqrt
)
// Use float to compute half overload
// Use float to compute half overload
...
@@ -144,16 +149,11 @@ MIGRAPHX_DEVICE_MATH_HALF(asin, ::asin)
...
@@ -144,16 +149,11 @@ MIGRAPHX_DEVICE_MATH_HALF(asin, ::asin)
MIGRAPHX_DEVICE_MATH_HALF
(
asinh
,
::
asinh
)
MIGRAPHX_DEVICE_MATH_HALF
(
asinh
,
::
asinh
)
MIGRAPHX_DEVICE_MATH_HALF
(
atan
,
::
atan
)
MIGRAPHX_DEVICE_MATH_HALF
(
atan
,
::
atan
)
MIGRAPHX_DEVICE_MATH_HALF
(
atanh
,
::
atanh
)
MIGRAPHX_DEVICE_MATH_HALF
(
atanh
,
::
atanh
)
MIGRAPHX_DEVICE_MATH_HALF
(
ceil
,
::
ceil
)
MIGRAPHX_DEVICE_MATH_HALF
(
cos
,
::
cos
)
MIGRAPHX_DEVICE_MATH_HALF
(
cosh
,
::
cosh
)
MIGRAPHX_DEVICE_MATH_HALF
(
cosh
,
::
cosh
)
MIGRAPHX_DEVICE_MATH_HALF
(
erf
,
::
erf
)
MIGRAPHX_DEVICE_MATH_HALF
(
erf
,
::
erf
)
MIGRAPHX_DEVICE_MATH_HALF
(
floor
,
::
floor
)
MIGRAPHX_DEVICE_MATH_HALF
(
isnan
,
::
isnan
)
MIGRAPHX_DEVICE_MATH_HALF
(
pow
,
::
pow
)
MIGRAPHX_DEVICE_MATH_HALF
(
pow
,
::
pow
)
MIGRAPHX_DEVICE_MATH_HALF
(
remainder
,
::
remainder
)
MIGRAPHX_DEVICE_MATH_HALF
(
remainder
,
::
remainder
)
MIGRAPHX_DEVICE_MATH_HALF
(
round
,
::
round
)
MIGRAPHX_DEVICE_MATH_HALF
(
round
,
::
round
)
MIGRAPHX_DEVICE_MATH_HALF
(
sin
,
::
sin
)
MIGRAPHX_DEVICE_MATH_HALF
(
sinh
,
::
sinh
)
MIGRAPHX_DEVICE_MATH_HALF
(
sinh
,
::
sinh
)
MIGRAPHX_DEVICE_MATH_HALF
(
tan
,
::
tan
)
MIGRAPHX_DEVICE_MATH_HALF
(
tan
,
::
tan
)
MIGRAPHX_DEVICE_MATH_HALF
(
tanh
,
::
tanh
)
MIGRAPHX_DEVICE_MATH_HALF
(
tanh
,
::
tanh
)
...
@@ -166,19 +166,19 @@ MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod)
...
@@ -166,19 +166,19 @@ MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod)
// at this time are: exp2, exp10, log2, log10, isinf
// at this time are: exp2, exp10, log2, log10, isinf
MIGRAPHX_DEVICE_MATH_HALF2
(
abs
,
::
__habs2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
abs
,
::
__habs2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
ceil
,
::
h2ceil
)
MIGRAPHX_DEVICE_MATH_HALF2
(
ceil
,
::
h2ceil
)
MIGRAPHX_DEVICE_MATH_HALF2
(
floor
,
::
h2floor
)
MIGRAPHX_DEVICE_MATH_HALF2
(
sin
,
::
h2sin
)
MIGRAPHX_DEVICE_MATH_HALF2
(
cos
,
::
h2cos
)
MIGRAPHX_DEVICE_MATH_HALF2
(
cos
,
::
h2cos
)
MIGRAPHX_DEVICE_MATH_HALF2
(
exp
,
::
h2exp
)
MIGRAPHX_DEVICE_MATH_HALF2
(
exp
,
::
h2exp
)
MIGRAPHX_DEVICE_MATH_HALF2
(
exp2
,
::
h2exp2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
exp10
,
::
h2exp10
)
MIGRAPHX_DEVICE_MATH_HALF2
(
exp10
,
::
h2exp10
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log2
,
::
h2log2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
exp2
,
::
h2exp2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
floor
,
::
h2floor
)
MIGRAPHX_DEVICE_MATH_HALF2
(
isinf
,
::
__hisinf2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
isnan
,
::
__hisnan2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log
,
::
h2log
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log
,
::
h2log
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log10
,
::
h2log10
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log10
,
::
h2log10
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log2
,
::
h2log2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
rsqrt
,
::
h2rsqrt
)
MIGRAPHX_DEVICE_MATH_HALF2
(
rsqrt
,
::
h2rsqrt
)
// MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin)
MIGRAPHX_DEVICE_MATH_HALF2
(
sqrt
,
::
h2sqrt
)
MIGRAPHX_DEVICE_MATH_HALF2
(
sqrt
,
::
h2sqrt
)
MIGRAPHX_DEVICE_MATH_HALF2
(
isinf
,
::
__hisinf2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
isnan
,
::
__hisnan2
)
template
<
class
T
,
class
U
>
template
<
class
T
,
class
U
>
constexpr
auto
where
(
bool
cond
,
const
T
&
a
,
const
U
&
b
)
constexpr
auto
where
(
bool
cond
,
const
T
&
a
,
const
U
&
b
)
...
@@ -218,6 +218,14 @@ constexpr auto min(const T& a, const U& b)
...
@@ -218,6 +218,14 @@ constexpr auto min(const T& a, const U& b)
return
min
<
common_type_t
<
T
,
U
>>
(
a
,
b
);
return
min
<
common_type_t
<
T
,
U
>>
(
a
,
b
);
}
}
// Sin for half is broken on hip, so use cos instead
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_same
<
vec_type
<
T
>,
half
>
{})
>
constexpr
T
sin
(
T
x
)
{
constexpr
const
T
shift
=
M_PI_2
;
return
migraphx
::
cos
(
shift
-
x
);
}
MIGRAPHX_DEVICE_MATH_VEC
(
abs
)
MIGRAPHX_DEVICE_MATH_VEC
(
abs
)
MIGRAPHX_DEVICE_MATH_VEC
(
acos
)
MIGRAPHX_DEVICE_MATH_VEC
(
acos
)
MIGRAPHX_DEVICE_MATH_VEC
(
acosh
)
MIGRAPHX_DEVICE_MATH_VEC
(
acosh
)
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp
View file @
25e8cf0b
...
@@ -56,6 +56,16 @@ struct id
...
@@ -56,6 +56,16 @@ struct id
}
}
};
};
template
<
class
T
>
struct
convert_to
{
template
<
class
U
>
MIGRAPHX_DEVICE_CONSTEXPR
auto
operator
()(
U
x
)
const
{
return
convert
<
T
>
(
x
);
}
};
struct
mean
struct
mean
{
{
index_int
item_num
=
1
;
index_int
item_num
=
1
;
...
...
src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
View file @
25e8cf0b
...
@@ -33,38 +33,6 @@
...
@@ -33,38 +33,6 @@
namespace
migraphx
{
namespace
migraphx
{
template
<
class
T
>
struct
implicit_conversion_op
{
T
x
;
template
<
index_int
N
,
class
U
>
constexpr
operator
vec
<
U
,
N
>
()
const
{
if
constexpr
(
vec_size
<
T
>
()
==
0
)
{
return
x
;
}
else
{
static_assert
(
vec_size
<
T
>
()
==
N
,
"Vector mismatch size"
);
return
__builtin_convertvector
(
x
,
vec
<
U
,
N
>
);
}
}
template
<
class
U
>
constexpr
operator
U
()
const
{
return
x
;
}
};
template
<
class
T
>
constexpr
implicit_conversion_op
<
T
>
implicit_conversion
(
T
x
)
{
return
{
x
};
}
template
<
class
F
,
class
T
,
class
...
Ts
>
template
<
class
F
,
class
T
,
class
...
Ts
>
__device__
void
pointwise_tensor
(
index
idx
,
F
f
,
T
out
,
Ts
...
xs
)
__device__
void
pointwise_tensor
(
index
idx
,
F
f
,
T
out
,
Ts
...
xs
)
{
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp
View file @
25e8cf0b
...
@@ -128,6 +128,7 @@ struct shape
...
@@ -128,6 +128,7 @@ struct shape
result
[
0
]
=
tidx
;
result
[
0
]
=
tidx
;
return
result
;
return
result
;
}
}
/// Convert multi-index into a single index
/// Convert multi-index into a single index
constexpr
index_int
single
(
index_array
idx
)
const
constexpr
index_int
single
(
index_array
idx
)
const
{
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
View file @
25e8cf0b
...
@@ -185,5 +185,37 @@ constexpr auto vec_reduce(T x, Op op)
...
@@ -185,5 +185,37 @@ constexpr auto vec_reduce(T x, Op op)
}
}
}
}
template
<
class
T
>
struct
implicit_conversion_op
{
T
x
;
template
<
index_int
N
,
class
U
>
constexpr
operator
vec
<
U
,
N
>
()
const
{
if
constexpr
(
vec_size
<
T
>
()
==
0
)
{
return
x
;
}
else
{
static_assert
(
vec_size
<
T
>
()
==
N
,
"Vector mismatch size"
);
return
__builtin_convertvector
(
x
,
vec
<
U
,
N
>
);
}
}
template
<
class
U
>
constexpr
operator
U
()
const
{
return
x
;
}
};
template
<
class
T
>
constexpr
implicit_conversion_op
<
T
>
implicit_conversion
(
T
x
)
{
return
{
x
};
}
}
// namespace migraphx
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP
#endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP
src/targets/gpu/lowering.cpp
View file @
25e8cf0b
...
@@ -90,7 +90,6 @@ struct miopen_apply
...
@@ -90,7 +90,6 @@ struct miopen_apply
add_extend_op
(
"argmax"
);
add_extend_op
(
"argmax"
);
add_extend_op
(
"argmin"
);
add_extend_op
(
"argmin"
);
add_extend_op
(
"gather"
);
add_extend_op
(
"logsoftmax"
);
add_extend_op
(
"logsoftmax"
);
add_extend_op
(
"lrn"
);
add_extend_op
(
"lrn"
);
add_extend_op
(
"multinomial"
);
add_extend_op
(
"multinomial"
);
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment