Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e3e487a1
Commit
e3e487a1
authored
Sep 07, 2022
by
Khalique Ahmed
Browse files
Merge branch 'develop' of
https://github.com/ROCmSoftwarePlatform/AMDMIGraphX
into layernorm_eps
parents
b0946229
60aa0e48
Changes
102
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
98 additions
and
37 deletions
+98
-37
src/program.cpp
src/program.cpp
+2
-2
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+1
-1
src/quantization.cpp
src/quantization.cpp
+1
-1
src/rewrite_gelu.cpp
src/rewrite_gelu.cpp
+59
-0
src/rewrite_pooling.cpp
src/rewrite_pooling.cpp
+3
-3
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+7
-7
src/shape.cpp
src/shape.cpp
+2
-2
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+3
-3
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+2
-2
src/targets/cpu/binary.cpp
src/targets/cpu/binary.cpp
+1
-1
src/targets/fpga/subgraph.cpp
src/targets/fpga/subgraph.cpp
+1
-1
src/targets/gpu/device/include/migraphx/gpu/device/array.hpp
src/targets/gpu/device/include/migraphx/gpu/device/array.hpp
+1
-1
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
+5
-3
src/targets/gpu/device/multinomial.cpp
src/targets/gpu/device/multinomial.cpp
+1
-1
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp
...argets/gpu/kernels/include/migraphx/kernels/algorithm.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/integral_constant.hpp
...pu/kernels/include/migraphx/kernels/integral_constant.hpp
+3
-3
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+1
-1
src/targets/gpu/mlir.cpp
src/targets/gpu/mlir.cpp
+2
-2
No files found.
src/program.cpp
View file @
e3e487a1
...
@@ -78,11 +78,11 @@ program& program::operator=(program p)
...
@@ -78,11 +78,11 @@ program& program::operator=(program p)
void
program
::
assign
(
const
program
&
p
)
void
program
::
assign
(
const
program
&
p
)
{
{
if
(
!
impl
)
if
(
not
impl
)
{
{
impl
=
std
::
make_unique
<
program_impl
>
();
impl
=
std
::
make_unique
<
program_impl
>
();
}
}
else
if
(
!
impl
->
modules
.
empty
())
else
if
(
not
impl
->
modules
.
empty
())
{
{
impl
->
modules
.
clear
();
impl
->
modules
.
clear
();
}
}
...
...
src/py/migraphx_py.cpp
View file @
e3e487a1
...
@@ -83,7 +83,7 @@ void visit_py(T x, F f)
...
@@ -83,7 +83,7 @@ void visit_py(T x, F f)
{
{
f
(
x
.
template
cast
<
bool
>());
f
(
x
.
template
cast
<
bool
>());
}
}
else
if
(
py
::
isinstance
<
py
::
int_
>
(
x
)
||
py
::
hasattr
(
x
,
"__index__"
))
else
if
(
py
::
isinstance
<
py
::
int_
>
(
x
)
or
py
::
hasattr
(
x
,
"__index__"
))
{
{
f
(
x
.
template
cast
<
int
>());
f
(
x
.
template
cast
<
int
>());
}
}
...
...
src/quantization.cpp
View file @
e3e487a1
...
@@ -70,7 +70,7 @@ void quantize_int8(program& prog,
...
@@ -70,7 +70,7 @@ void quantize_int8(program& prog,
{
{
std
::
set
<
std
::
string
>
op_names
=
{
"convolution"
,
"dot"
};
std
::
set
<
std
::
string
>
op_names
=
{
"convolution"
,
"dot"
};
std
::
set
<
std
::
string
>
input_ins_names
(
ins_names
.
begin
(),
ins_names
.
end
());
std
::
set
<
std
::
string
>
input_ins_names
(
ins_names
.
begin
(),
ins_names
.
end
());
if
(
!
std
::
includes
(
if
(
not
std
::
includes
(
op_names
.
begin
(),
op_names
.
end
(),
input_ins_names
.
begin
(),
input_ins_names
.
end
()))
op_names
.
begin
(),
op_names
.
end
(),
input_ins_names
.
begin
(),
input_ins_names
.
end
()))
{
{
MIGRAPHX_THROW
(
"QUANTIZE_INT8: only support DOT and CONVOLUTION operation"
);
MIGRAPHX_THROW
(
"QUANTIZE_INT8: only support DOT and CONVOLUTION operation"
);
...
...
src/rewrite_gelu.cpp
0 → 100644
View file @
e3e487a1
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/rewrite_gelu.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/match/gelu_erf.hpp>
#include <migraphx/common.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
find_gelu_erf
{
auto
matcher
()
const
{
return
match
::
gelu_erf
();
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
x
=
r
.
instructions
[
"x"
];
if
(
x
->
get_shape
().
type
()
!=
migraphx
::
shape
::
half_type
)
return
;
auto
lit
=
m
.
add_literal
(
literal
{
shape
{
x
->
get_shape
().
type
()},
{
1.702
f
}});
auto
mul
=
insert_common_op
(
m
,
ins
,
make_op
(
"mul"
),
{
x
,
lit
});
auto
sig
=
m
.
insert_instruction
(
ins
,
make_op
(
"neg"
),
mul
);
sig
=
m
.
insert_instruction
(
ins
,
make_op
(
"exp"
),
sig
);
auto
one
=
m
.
add_literal
(
literal
{
shape
{
x
->
get_shape
().
type
()},
{
1.0
f
}});
sig
=
insert_common_op
(
m
,
ins
,
make_op
(
"add"
),
{
sig
,
one
});
sig
=
m
.
insert_instruction
(
ins
,
make_op
(
"div"
),
x
,
sig
);
m
.
replace_instruction
(
ins
,
sig
);
}
};
void
rewrite_gelu
::
apply
(
module
&
m
)
const
{
match
::
find_matches
(
m
,
find_gelu_erf
{});
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/rewrite_pooling.cpp
View file @
e3e487a1
...
@@ -47,12 +47,12 @@ void rewrite_pooling::apply(module& m) const
...
@@ -47,12 +47,12 @@ void rewrite_pooling::apply(module& m) const
if
(
not
s
.
standard
())
if
(
not
s
.
standard
())
continue
;
continue
;
auto
&&
op
=
any_cast
<
op
::
pooling
>
(
ins
->
get_operator
());
auto
&&
op
=
any_cast
<
op
::
pooling
>
(
ins
->
get_operator
());
if
(
!
std
::
all_of
(
op
.
padding
.
begin
(),
op
.
padding
.
end
(),
[](
auto
i
)
{
return
i
==
0
;
}))
if
(
not
std
::
all_of
(
op
.
padding
.
begin
(),
op
.
padding
.
end
(),
[](
auto
i
)
{
return
i
==
0
;
}))
continue
;
continue
;
if
(
!
std
::
all_of
(
op
.
stride
.
begin
(),
op
.
stride
.
end
(),
[](
auto
i
)
{
return
i
==
1
;
}))
if
(
not
std
::
all_of
(
op
.
stride
.
begin
(),
op
.
stride
.
end
(),
[](
auto
i
)
{
return
i
==
1
;
}))
continue
;
continue
;
auto
lens
=
s
.
lens
();
auto
lens
=
s
.
lens
();
if
(
!
std
::
equal
(
lens
.
begin
()
+
2
,
lens
.
end
(),
op
.
lengths
.
begin
(),
op
.
lengths
.
end
()))
if
(
not
std
::
equal
(
lens
.
begin
()
+
2
,
lens
.
end
(),
op
.
lengths
.
begin
(),
op
.
lengths
.
end
()))
continue
;
continue
;
std
::
int64_t
n
=
s
.
lens
()[
0
];
std
::
int64_t
n
=
s
.
lens
()[
0
];
std
::
int64_t
c
=
s
.
lens
()[
1
];
std
::
int64_t
c
=
s
.
lens
()[
1
];
...
...
src/rewrite_rnn.cpp
View file @
e3e487a1
...
@@ -214,7 +214,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
...
@@ -214,7 +214,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
ih
=
m
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
ih
=
m
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
}
}
if
(
!
is_forward
and
variable_seq_len
)
if
(
not
is_forward
and
variable_seq_len
)
{
{
args
[
0
]
=
args
[
0
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"rnn_var_sl_shift_sequence"
),
args
[
0
],
seq_lens
);
m
.
insert_instruction
(
ins
,
make_op
(
"rnn_var_sl_shift_sequence"
),
args
[
0
],
seq_lens
);
...
@@ -520,7 +520,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
...
@@ -520,7 +520,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
ih
=
m
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
ih
=
m
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
}
}
if
(
!
is_forward
and
variable_seq_len
)
if
(
not
is_forward
and
variable_seq_len
)
{
{
args
[
0
]
=
args
[
0
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"rnn_var_sl_shift_sequence"
),
args
[
0
],
seq_lens
);
m
.
insert_instruction
(
ins
,
make_op
(
"rnn_var_sl_shift_sequence"
),
args
[
0
],
seq_lens
);
...
@@ -977,7 +977,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
...
@@ -977,7 +977,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
pph
=
args
[
7
];
pph
=
args
[
7
];
}
}
if
(
!
is_forward
and
variable_seq_len
)
if
(
not
is_forward
and
variable_seq_len
)
{
{
args
[
0
]
=
args
[
0
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"rnn_var_sl_shift_sequence"
),
args
[
0
],
seq_lens
);
m
.
insert_instruction
(
ins
,
make_op
(
"rnn_var_sl_shift_sequence"
),
args
[
0
],
seq_lens
);
...
@@ -1294,11 +1294,11 @@ bool rewrite_rnn::is_variable_seq_lens(const module& m, instruction_ref seq_lens
...
@@ -1294,11 +1294,11 @@ bool rewrite_rnn::is_variable_seq_lens(const module& m, instruction_ref seq_lens
std
::
vector
<
int64_t
>
vec_lens
;
std
::
vector
<
int64_t
>
vec_lens
;
arg_lens
.
visit
([
&
](
auto
l
)
{
vec_lens
.
assign
(
l
.
begin
(),
l
.
end
());
});
arg_lens
.
visit
([
&
](
auto
l
)
{
vec_lens
.
assign
(
l
.
begin
(),
l
.
end
());
});
int64_t
l
=
0
;
int64_t
l
=
0
;
if
(
!
vec_lens
.
empty
())
if
(
not
vec_lens
.
empty
())
{
{
l
=
vec_lens
[
0
];
l
=
vec_lens
[
0
];
}
}
if
(
!
std
::
all_of
(
vec_lens
.
begin
(),
vec_lens
.
end
(),
[
&
](
auto
v
)
{
return
v
==
l
;
}))
if
(
not
std
::
all_of
(
vec_lens
.
begin
(),
vec_lens
.
end
(),
[
&
](
auto
v
)
{
return
v
==
l
;
}))
{
{
is_var_lens
=
true
;
is_var_lens
=
true
;
}
}
...
@@ -1318,7 +1318,7 @@ rewrite_rnn::get_seq_len(const module& m, instruction_ref input, instruction_ref
...
@@ -1318,7 +1318,7 @@ rewrite_rnn::get_seq_len(const module& m, instruction_ref input, instruction_ref
bool
is_var_lens
=
is_variable_seq_lens
(
m
,
seq_lens
);
bool
is_var_lens
=
is_variable_seq_lens
(
m
,
seq_lens
);
auto
input_shape
=
input
->
get_shape
();
auto
input_shape
=
input
->
get_shape
();
auto
length
=
input_shape
.
lens
()[
0
];
auto
length
=
input_shape
.
lens
()[
0
];
if
(
!
is_var_lens
and
seq_lens
!=
m
.
end
())
if
(
not
is_var_lens
and
seq_lens
!=
m
.
end
())
{
{
auto
arg_len
=
seq_lens
->
eval
();
auto
arg_len
=
seq_lens
->
eval
();
std
::
vector
<
std
::
size_t
>
vec_lens
;
std
::
vector
<
std
::
size_t
>
vec_lens
;
...
@@ -1387,7 +1387,7 @@ void rewrite_rnn::replace_last_cell_output(module& m,
...
@@ -1387,7 +1387,7 @@ void rewrite_rnn::replace_last_cell_output(module& m,
if
(
variable_seq_len
)
if
(
variable_seq_len
)
{
{
if
(
!
ins_outputs
.
empty
())
if
(
not
ins_outputs
.
empty
())
{
{
cell_outputs
=
m
.
insert_instruction
(
cell_outputs
=
m
.
insert_instruction
(
std
::
next
(
ins
),
std
::
next
(
ins
),
...
...
src/shape.cpp
View file @
e3e487a1
...
@@ -477,7 +477,7 @@ bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimensio
...
@@ -477,7 +477,7 @@ bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimensio
bool
operator
!=
(
const
shape
::
dynamic_dimension
&
x
,
const
shape
::
dynamic_dimension
&
y
)
bool
operator
!=
(
const
shape
::
dynamic_dimension
&
x
,
const
shape
::
dynamic_dimension
&
y
)
{
{
return
!
(
x
==
y
);
return
not
(
x
==
y
);
}
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
::
dynamic_dimension
&
x
)
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
::
dynamic_dimension
&
x
)
{
{
...
@@ -497,7 +497,7 @@ bool operator==(const shape& x, const shape& y)
...
@@ -497,7 +497,7 @@ bool operator==(const shape& x, const shape& y)
x
.
strides
()
==
y
.
strides
()
and
x
.
sub_shapes
()
==
y
.
sub_shapes
());
x
.
strides
()
==
y
.
strides
()
and
x
.
sub_shapes
()
==
y
.
sub_shapes
());
}
}
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
)
{
return
!
(
x
==
y
);
}
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
)
{
return
not
(
x
==
y
);
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
&
x
)
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
&
x
)
{
{
...
...
src/simplify_algebra.cpp
View file @
e3e487a1
...
@@ -787,7 +787,7 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
...
@@ -787,7 +787,7 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
};
};
auto
dots
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"dot"
));
auto
dots
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"dot"
));
auto
convs
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"convolution"
));
auto
convs
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"convolution"
));
return
!
(
dots
<
2
and
convs
<
2
);
return
not
(
dots
<
2
and
convs
<
2
);
}
}
struct
find_conv_dot_horiz_fusion
struct
find_conv_dot_horiz_fusion
...
@@ -969,7 +969,7 @@ struct find_split_reshape
...
@@ -969,7 +969,7 @@ struct find_split_reshape
// all outputs are reshape and of the same shape
// all outputs are reshape and of the same shape
auto
dims
=
any_cast
<
op
::
reshape
>
(
rsp
->
get_operator
()).
dims
;
auto
dims
=
any_cast
<
op
::
reshape
>
(
rsp
->
get_operator
()).
dims
;
if
(
!
same_ops
(
vec_rsp
))
if
(
not
same_ops
(
vec_rsp
))
{
{
return
;
return
;
}
}
...
@@ -1052,7 +1052,7 @@ struct find_split_transpose
...
@@ -1052,7 +1052,7 @@ struct find_split_transpose
// all transpose are the same
// all transpose are the same
auto
perm
=
any_cast
<
op
::
transpose
>
(
trans
->
get_operator
()).
dims
;
auto
perm
=
any_cast
<
op
::
transpose
>
(
trans
->
get_operator
()).
dims
;
if
(
!
same_ops
(
vec_trans
))
if
(
not
same_ops
(
vec_trans
))
{
{
return
;
return
;
}
}
...
...
src/simplify_reshapes.cpp
View file @
e3e487a1
...
@@ -99,7 +99,7 @@ struct find_reshaper
...
@@ -99,7 +99,7 @@ struct find_reshaper
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
while
(
is_reshaper
(
reshapes
.
back
()))
while
(
is_reshaper
(
reshapes
.
back
()))
{
{
assert
(
!
reshapes
.
back
()
->
inputs
().
empty
());
assert
(
not
reshapes
.
back
()
->
inputs
().
empty
());
assert
(
m
.
has_instruction
(
reshapes
.
back
()
->
inputs
().
front
()));
assert
(
m
.
has_instruction
(
reshapes
.
back
()
->
inputs
().
front
()));
auto
input
=
reshapes
.
back
()
->
inputs
().
front
();
auto
input
=
reshapes
.
back
()
->
inputs
().
front
();
reshapes
.
push_back
(
input
);
reshapes
.
push_back
(
input
);
...
@@ -288,7 +288,7 @@ struct find_concat_transpose
...
@@ -288,7 +288,7 @@ struct find_concat_transpose
auto
permutation
=
find_permutation
(
s
);
auto
permutation
=
find_permutation
(
s
);
// permutation should be the same for all inputs
// permutation should be the same for all inputs
if
(
!
std
::
all_of
(
trans_inputs
.
begin
(),
trans_inputs
.
end
(),
[
&
](
auto
in
)
{
if
(
not
std
::
all_of
(
trans_inputs
.
begin
(),
trans_inputs
.
end
(),
[
&
](
auto
in
)
{
return
(
find_permutation
(
in
->
get_shape
())
==
permutation
);
return
(
find_permutation
(
in
->
get_shape
())
==
permutation
);
}))
}))
{
{
...
...
src/targets/cpu/binary.cpp
View file @
e3e487a1
...
@@ -49,7 +49,7 @@ struct dnnl_binary : dnnl_op<dnnl_binary, dnnl::binary>
...
@@ -49,7 +49,7 @@ struct dnnl_binary : dnnl_op<dnnl_binary, dnnl::binary>
auto
s0
=
inputs
.
at
(
0
);
auto
s0
=
inputs
.
at
(
0
);
auto
s1
=
inputs
.
at
(
1
);
auto
s1
=
inputs
.
at
(
1
);
auto
r
=
s0
;
auto
r
=
s0
;
if
(
s0
!=
s1
or
!
s0
.
packed
())
if
(
s0
!=
s1
or
not
s0
.
packed
())
{
{
r
=
shape
{
s0
.
type
(),
s0
.
lens
()};
r
=
shape
{
s0
.
type
(),
s0
.
lens
()};
}
}
...
...
src/targets/fpga/subgraph.cpp
View file @
e3e487a1
...
@@ -95,7 +95,7 @@ void subgraph::apply(module_pass_manager& mpm) const
...
@@ -95,7 +95,7 @@ void subgraph::apply(module_pass_manager& mpm) const
for
(
auto
it
:
iterator_for
(
mod
))
for
(
auto
it
:
iterator_for
(
mod
))
{
{
// assuming we want all the params/literals as inputs to the FPGA submodule
// assuming we want all the params/literals as inputs to the FPGA submodule
if
(
migraphx
::
starts_with
(
it
->
name
(),
"@param"
)
||
if
(
migraphx
::
starts_with
(
it
->
name
(),
"@param"
)
or
migraphx
::
starts_with
(
it
->
name
(),
"@literal"
))
migraphx
::
starts_with
(
it
->
name
(),
"@literal"
))
{
{
literal_inputs
.
push_back
(
it
);
literal_inputs
.
push_back
(
it
);
...
...
src/targets/gpu/device/include/migraphx/gpu/device/array.hpp
View file @
e3e487a1
...
@@ -131,7 +131,7 @@ struct hip_array
...
@@ -131,7 +131,7 @@ struct hip_array
friend
MIGRAPHX_DEVICE_CONSTEXPR
bool
operator
!=
(
const
hip_array
&
x
,
const
hip_array
&
y
)
friend
MIGRAPHX_DEVICE_CONSTEXPR
bool
operator
!=
(
const
hip_array
&
x
,
const
hip_array
&
y
)
{
{
return
!
(
x
==
y
);
return
not
(
x
==
y
);
}
}
// This uses the product order rather than lexical order
// This uses the product order rather than lexical order
friend
MIGRAPHX_DEVICE_CONSTEXPR
bool
operator
<
(
const
hip_array
&
x
,
const
hip_array
&
y
)
friend
MIGRAPHX_DEVICE_CONSTEXPR
bool
operator
<
(
const
hip_array
&
x
,
const
hip_array
&
y
)
...
...
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
View file @
e3e487a1
...
@@ -117,12 +117,13 @@ template <class V, class F, class... Ts>
...
@@ -117,12 +117,13 @@ template <class V, class F, class... Ts>
void
hip_visit_all_impl
(
const
shape
&
s
,
F
f
,
V
&&
v
,
Ts
&&
...
xs
)
void
hip_visit_all_impl
(
const
shape
&
s
,
F
f
,
V
&&
v
,
Ts
&&
...
xs
)
{
{
std
::
initializer_list
<
migraphx
::
shape
::
type_t
>
types
=
{
get_shape
(
xs
).
type
()...};
std
::
initializer_list
<
migraphx
::
shape
::
type_t
>
types
=
{
get_shape
(
xs
).
type
()...};
if
(
!
std
::
all_of
(
if
(
not
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
migraphx
::
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
types
.
begin
(),
types
.
end
(),
[
&
](
migraphx
::
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
MIGRAPHX_THROW
(
"Types must be the same"
);
MIGRAPHX_THROW
(
"Types must be the same"
);
std
::
initializer_list
<
index_int
>
ranks
=
{
std
::
initializer_list
<
index_int
>
ranks
=
{
static_cast
<
index_int
>
(
get_shape
(
xs
).
lens
().
size
())...};
static_cast
<
index_int
>
(
get_shape
(
xs
).
lens
().
size
())...};
if
(
!
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
index_int
r
)
{
return
r
==
s
.
lens
().
size
();
}))
if
(
not
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
index_int
r
)
{
return
r
==
s
.
lens
().
size
();
}))
MIGRAPHX_THROW
(
"Ranks must be the same"
);
MIGRAPHX_THROW
(
"Ranks must be the same"
);
visit_tensor_size
(
s
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
visit_tensor_size
(
s
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
s
.
visit_type
(
hip_visitor
([
&
](
auto
as
)
{
v
(
f
(
xs
,
ndim
,
as
)...);
}));
s
.
visit_type
(
hip_visitor
([
&
](
auto
as
)
{
v
(
f
(
xs
,
ndim
,
as
)...);
}));
...
@@ -134,7 +135,8 @@ void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs)
...
@@ -134,7 +135,8 @@ void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs)
{
{
std
::
initializer_list
<
index_int
>
ranks
=
{
std
::
initializer_list
<
index_int
>
ranks
=
{
static_cast
<
index_int
>
(
get_shape
(
xs
).
lens
().
size
())...};
static_cast
<
index_int
>
(
get_shape
(
xs
).
lens
().
size
())...};
if
(
!
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
index_int
r
)
{
return
r
==
s
.
lens
().
size
();
}))
if
(
not
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
index_int
r
)
{
return
r
==
s
.
lens
().
size
();
}))
MIGRAPHX_THROW
(
"Ranks must be the same"
);
MIGRAPHX_THROW
(
"Ranks must be the same"
);
visit_tensor_size
(
s
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
v
(
f
(
xs
,
ndim
)...);
});
visit_tensor_size
(
s
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
v
(
f
(
xs
,
ndim
)...);
});
}
}
...
...
src/targets/gpu/device/multinomial.cpp
View file @
e3e487a1
...
@@ -47,7 +47,7 @@ constexpr Iterator upper_bound(Iterator first, Iterator last, const T& value)
...
@@ -47,7 +47,7 @@ constexpr Iterator upper_bound(Iterator first, Iterator last, const T& value)
it
=
first
;
it
=
first
;
step
=
count
/
2
;
step
=
count
/
2
;
std
::
advance
(
it
,
step
);
std
::
advance
(
it
,
step
);
if
(
!
(
value
<
*
it
))
if
(
not
(
value
<
*
it
))
{
{
first
=
++
it
;
first
=
++
it
;
count
-=
step
+
1
;
count
-=
step
+
1
;
...
...
src/targets/gpu/gemm_impl.cpp
View file @
e3e487a1
...
@@ -112,7 +112,7 @@ void gemm_impl(context& ctx,
...
@@ -112,7 +112,7 @@ void gemm_impl(context& ctx,
bool
compute_fp32
)
bool
compute_fp32
)
{
{
const
bool
is_3inputs
=
(
args
.
size
()
==
4
);
const
bool
is_3inputs
=
(
args
.
size
()
==
4
);
if
(
!
is_3inputs
)
if
(
not
is_3inputs
)
{
{
beta
=
0
;
beta
=
0
;
}
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp
View file @
e3e487a1
...
@@ -163,7 +163,7 @@ constexpr Iterator1 search(Iterator1 first, Iterator1 last, Iterator2 s_first, I
...
@@ -163,7 +163,7 @@ constexpr Iterator1 search(Iterator1 first, Iterator1 last, Iterator2 s_first, I
{
{
return
last
;
return
last
;
}
}
if
(
!
(
*
it
==
*
s_it
))
if
(
not
(
*
it
==
*
s_it
))
{
{
break
;
break
;
}
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
View file @
e3e487a1
...
@@ -212,7 +212,7 @@ struct array
...
@@ -212,7 +212,7 @@ struct array
return
true
;
return
true
;
}
}
friend
constexpr
bool
operator
!=
(
const
array
&
x
,
const
array
&
y
)
{
return
!
(
x
==
y
);
}
friend
constexpr
bool
operator
!=
(
const
array
&
x
,
const
array
&
y
)
{
return
not
(
x
==
y
);
}
// This uses the product order rather than lexical order
// This uses the product order rather than lexical order
friend
constexpr
bool
operator
<
(
const
array
&
x
,
const
array
&
y
)
friend
constexpr
bool
operator
<
(
const
array
&
x
,
const
array
&
y
)
{
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/integral_constant.hpp
View file @
e3e487a1
...
@@ -73,10 +73,10 @@ MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(!=)
...
@@ -73,10 +73,10 @@ MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(!=)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP
(
&
)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP
(
&
)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP
(
^
)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP
(
^
)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP
(
|
)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP
(
|
)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP
(
&&
)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP
(
and
)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP
(
||
)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP
(
or
)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP
(
!
)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP
(
not
)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP
(
~
)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP
(
~
)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP
(
+
)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP
(
+
)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP
(
-
)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP
(
-
)
...
...
src/targets/gpu/lowering.cpp
View file @
e3e487a1
...
@@ -341,7 +341,7 @@ struct miopen_apply
...
@@ -341,7 +341,7 @@ struct miopen_apply
catch
(
migraphx
::
exception
&
)
catch
(
migraphx
::
exception
&
)
{
{
// In case no solver supports the default format, retry using the other format.
// In case no solver supports the default format, retry using the other format.
compile_quant_conv_with_format
(
!
int8_x4_format
);
compile_quant_conv_with_format
(
not
int8_x4_format
);
}
}
auto
args
=
ins
->
inputs
();
auto
args
=
ins
->
inputs
();
...
...
src/targets/gpu/mlir.cpp
View file @
e3e487a1
...
@@ -78,7 +78,7 @@ struct mlir_handle
...
@@ -78,7 +78,7 @@ struct mlir_handle
friend
bool
operator
==
(
ptr
x
,
ptr
y
)
{
return
x
.
get_value
()
==
y
.
get_value
();
}
friend
bool
operator
==
(
ptr
x
,
ptr
y
)
{
return
x
.
get_value
()
==
y
.
get_value
();
}
friend
bool
operator
!=
(
ptr
x
,
ptr
y
)
{
return
!
(
x
==
y
);
}
friend
bool
operator
!=
(
ptr
x
,
ptr
y
)
{
return
not
(
x
==
y
);
}
T
obj
{};
T
obj
{};
};
};
...
@@ -503,7 +503,7 @@ struct mlir_program
...
@@ -503,7 +503,7 @@ struct mlir_program
pp
=
pp
=
problem_params
{
ins
->
get_operator
(),
to_shapes
(
ins
->
inputs
()),
ins
->
get_shape
()};
problem_params
{
ins
->
get_operator
(),
to_shapes
(
ins
->
inputs
()),
ins
->
get_shape
()};
std
::
string
tuned
=
get_tune_params
();
std
::
string
tuned
=
get_tune_params
();
if
(
!
tuned
.
empty
())
if
(
not
tuned
.
empty
())
ops
.
add_attributes
({{
"perf_config"
,
tuned
}});
ops
.
add_attributes
({{
"perf_config"
,
tuned
}});
// check if HW supports xdlops
// check if HW supports xdlops
if
(
contains
(
get_xdlops_archs
(),
target_name
))
if
(
contains
(
get_xdlops_archs
(),
target_name
))
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment