Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1612d8f3
Unverified
Commit
1612d8f3
authored
Aug 14, 2023
by
Chris Austen
Committed by
GitHub
Aug 14, 2023
Browse files
Merge branch 'develop' into enable_navi_32_ci
parents
9d3fb0b5
3f98e71d
Changes
68
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
86 additions
and
68 deletions
+86
-68
src/include/migraphx/algorithm.hpp
src/include/migraphx/algorithm.hpp
+3
-2
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+27
-22
src/include/migraphx/op/convolution.hpp
src/include/migraphx/op/convolution.hpp
+1
-1
src/include/migraphx/op/if_op.hpp
src/include/migraphx/op/if_op.hpp
+1
-1
src/include/migraphx/op/loop.hpp
src/include/migraphx/op/loop.hpp
+3
-3
src/instruction.cpp
src/instruction.cpp
+1
-1
src/module.cpp
src/module.cpp
+4
-5
src/onnx/include/migraphx/onnx/onnx_parser.hpp
src/onnx/include/migraphx/onnx/onnx_parser.hpp
+1
-0
src/onnx/onnx_parser.cpp
src/onnx/onnx_parser.cpp
+12
-9
src/onnx/parse_constant_of_shape.cpp
src/onnx/parse_constant_of_shape.cpp
+2
-3
src/onnx/parse_randomuniform_ops.cpp
src/onnx/parse_randomuniform_ops.cpp
+1
-1
src/program.cpp
src/program.cpp
+4
-4
src/py/CMakeLists.txt
src/py/CMakeLists.txt
+1
-0
src/py/include/migraphx/py.hpp
src/py/include/migraphx/py.hpp
+2
-1
src/py/py_loader.cpp
src/py/py_loader.cpp
+1
-1
src/rewrite_quantization.cpp
src/rewrite_quantization.cpp
+5
-7
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+6
-5
src/sqlite.cpp
src/sqlite.cpp
+1
-0
src/targets/cpu/gemm.cpp
src/targets/cpu/gemm.cpp
+5
-1
src/targets/cpu/include/migraphx/cpu/dnnl.hpp
src/targets/cpu/include/migraphx/cpu/dnnl.hpp
+5
-1
No files found.
src/include/migraphx/algorithm.hpp
View file @
1612d8f3
...
...
@@ -26,6 +26,8 @@
#include <algorithm>
#include <numeric>
#include <string>
#include <vector>
#include <migraphx/config.hpp>
namespace
migraphx
{
...
...
@@ -100,8 +102,7 @@ inline size_t levenshtein_distance(const std::string& s1, const std::string& s2)
std
::
vector
<
size_t
>
d
(
l2
+
1
);
for
(
size_t
j
=
1
;
j
<=
l2
;
j
++
)
d
[
j
]
=
j
;
std
::
iota
(
d
.
begin
(),
d
.
end
(),
0
);
for
(
size_t
i
=
1
;
i
<=
l1
;
i
++
)
{
...
...
src/include/migraphx/check_shapes.hpp
View file @
1612d8f3
...
...
@@ -34,21 +34,37 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
// Check that deduced type is incrementable, dereferencable, and comparable
template
<
class
,
class
=
void
>
struct
is_iterator
{
};
template
<
class
T
>
struct
is_iterator
<
T
,
std
::
void_t
<
decltype
(
++
std
::
declval
<
T
&>
()),
decltype
(
*
std
::
declval
<
T
&>
()),
decltype
(
std
::
declval
<
T
&>
()
==
std
::
declval
<
T
&>
())
>>
:
std
::
true_type
{
};
template
<
class
Iterator
>
struct
check_shapes
{
const
shape
*
begin
;
const
shape
*
end
;
static_assert
(
is_iterator
<
Iterator
>
{},
"CHECK_SHAPES: Deduced type must be an iterator"
);
Iterator
begin
;
Iterator
end
;
std
::
string
name
;
bool
dynamic_allowed
;
check_shapes
(
const
shape
*
b
,
const
shape
*
e
,
const
std
::
string
&
n
,
const
bool
d
=
false
)
check_shapes
(
Iterator
b
,
Iterator
e
,
const
std
::
string
&
n
,
const
bool
d
=
false
)
:
begin
(
b
),
end
(
e
),
name
(
n
),
dynamic_allowed
(
d
)
{
check_dynamic
();
}
template
<
class
Op
>
check_shapes
(
const
shape
*
b
,
const
shape
*
e
,
const
Op
&
op
,
const
bool
d
=
false
)
check_shapes
(
Iterator
b
,
Iterator
e
,
const
Op
&
op
,
const
bool
d
=
false
)
:
begin
(
b
),
end
(
e
),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
{
check_dynamic
();
...
...
@@ -56,7 +72,7 @@ struct check_shapes
template
<
class
Op
>
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
Op
&
op
,
const
bool
d
=
false
)
:
begin
(
s
.
data
()),
end
(
s
.
data
()
+
s
.
size
()),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
:
begin
(
s
.
begin
()),
end
(
s
.
end
()),
name
(
op
.
name
()),
dynamic_allowed
(
d
)
{
check_dynamic
();
}
...
...
@@ -81,8 +97,6 @@ struct check_shapes
{
if
(
begin
==
end
)
return
0
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
return
end
-
begin
;
}
...
...
@@ -131,8 +145,6 @@ struct check_shapes
*/
const
check_shapes
&
only_dims
(
std
::
size_t
n
)
const
{
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
begin
!=
end
)
{
if
(
begin
->
max_lens
().
size
()
!=
n
)
...
...
@@ -148,8 +160,6 @@ struct check_shapes
*/
const
check_shapes
&
max_ndims
(
std
::
size_t
n
)
const
{
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
begin
!=
end
)
{
if
(
begin
->
max_lens
().
size
()
>
n
)
...
...
@@ -166,8 +176,6 @@ struct check_shapes
*/
const
check_shapes
&
min_ndims
(
std
::
size_t
n
)
const
{
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
begin
!=
end
)
{
if
(
begin
->
max_lens
().
size
()
<
n
)
...
...
@@ -330,8 +338,6 @@ struct check_shapes
{
if
(
begin
==
end
)
return
true
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
auto
&&
key
=
f
(
*
begin
);
return
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
f
(
s
)
==
key
;
});
}
...
...
@@ -341,8 +347,6 @@ struct check_shapes
{
if
(
begin
==
end
)
return
true
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
return
std
::
all_of
(
begin
,
end
,
p
);
}
...
...
@@ -351,17 +355,13 @@ struct check_shapes
{
if
(
begin
==
end
)
return
false
;
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
return
std
::
any_of
(
begin
,
end
,
p
);
}
const
shape
*
get
(
long
i
)
const
Iterator
get
(
long
i
)
const
{
if
(
i
>=
size
())
MIGRAPHX_THROW
(
prefix
()
+
"Accessing shape out of bounds"
);
assert
(
begin
!=
nullptr
);
assert
(
end
!=
nullptr
);
if
(
i
<
0
)
return
end
-
i
;
return
begin
+
i
;
...
...
@@ -394,6 +394,11 @@ struct check_shapes
}
};
// Deduction guide for std::vector constructor
template
<
class
Op
>
check_shapes
(
const
std
::
vector
<
shape
>&
,
const
Op
&
,
bool
d
=
false
)
->
check_shapes
<
std
::
vector
<
shape
>::
const_iterator
>
;
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/op/convolution.hpp
View file @
1612d8f3
...
...
@@ -82,7 +82,7 @@ struct convolution
const
auto
input_ndim
=
inputs
[
0
].
ndim
();
const
auto
padding_size
=
padding
.
size
();
if
(
input_ndim
!=
padding_size
/
2
+
2
&&
input_ndim
!=
padding_size
+
2
)
if
(
input_ndim
!=
padding_size
/
2
+
2
and
input_ndim
!=
padding_size
+
2
)
{
MIGRAPHX_THROW
(
"CONVOLUTION: input and attribute size mismatch!"
);
}
...
...
src/include/migraphx/op/if_op.hpp
View file @
1612d8f3
...
...
@@ -71,7 +71,7 @@ struct if_op
std
::
unordered_map
<
std
::
string
,
argument
>
params
;
std
::
set
<
std
::
string
>
pnames
;
for
(
const
auto
&
smod
:
mods
)
for
(
const
_module_ref
smod
:
mods
)
{
auto
names
=
smod
->
get_parameter_names
();
pnames
.
insert
(
names
.
begin
(),
names
.
end
());
...
...
src/include/migraphx/op/loop.hpp
View file @
1612d8f3
...
...
@@ -59,9 +59,9 @@ struct loop
MIGRAPHX_THROW
(
"LOOP: operator should have one submodule."
);
}
const
auto
&
mod
=
mods
.
front
();
auto
mod_out_shapes
=
mod
->
get_output_shapes
();
auto
dep_param_num
=
inputs
.
size
()
-
2
;
const
_module_ref
mod
=
mods
.
front
();
auto
mod_out_shapes
=
mod
->
get_output_shapes
();
auto
dep_param_num
=
inputs
.
size
()
-
2
;
// first item of the mod output shapes is condition used in loop,
// which is not needed to compute output shape
...
...
src/instruction.cpp
View file @
1612d8f3
...
...
@@ -389,7 +389,7 @@ void instruction::print(std::ostream& os,
if
(
not
ins
->
module_inputs
().
empty
())
{
std
::
string
delim
=
", ["
;
for
(
auto
&
&
mod_arg
:
ins
->
module_inputs
())
for
(
const
const_module_ref
&
mod_arg
:
ins
->
module_inputs
())
{
os
<<
delim
<<
mod_arg
->
name
();
delim
=
", "
;
...
...
src/module.cpp
View file @
1612d8f3
...
...
@@ -873,12 +873,11 @@ module::print_py(std::ostream& os,
if
(
ins
->
name
()
==
"@literal"
)
{
os
<<
mname
<<
".add_literal("
;
bool
use_abs
=
false
;
ins
->
get_literal
().
visit
([
&
](
auto
v
)
{
use_abs
=
std
::
none_of
(
v
.
begin
(),
v
.
end
(),
[](
auto
x
)
{
return
x
<
0
;
});
});
const
bool
use_abs
=
false
;
// Disable abs for now
use_abs
=
false
;
// ins->get_literal().visit([&](auto v) {
// use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
// });
if
(
use_abs
)
os
<<
"migraphx.abs_literal("
;
os
<<
"migraphx.generate_argument("
;
...
...
src/onnx/include/migraphx/onnx/onnx_parser.hpp
View file @
1612d8f3
...
...
@@ -117,6 +117,7 @@ struct onnx_parser
parse_graph
(
module
*
mod
,
const
onnx
::
GraphProto
&
graph
,
bool
inlining
=
false
);
literal
parse_value
(
const
onnx
::
AttributeProto
&
attr
)
const
;
literal
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
const
;
shape
parse_type
(
const
onnx
::
TypeProto
&
t
)
const
;
shape
parse_type
(
const
onnx
::
TypeProto
&
t
,
const
std
::
vector
<
std
::
size_t
>&
input_dims
)
const
;
};
...
...
src/onnx/onnx_parser.cpp
View file @
1612d8f3
...
...
@@ -357,10 +357,9 @@ parse_inputs(const onnx_parser& parser,
}
shape
s
;
std
::
vector
<
std
::
size_t
>
dims
;
if
(
parser
.
map_input_dims
.
count
(
name
)
>
0
)
{
dims
=
parser
.
map_input_dims
.
at
(
name
);
std
::
vector
<
std
::
size_t
>
dims
=
parser
.
map_input_dims
.
at
(
name
);
s
=
parser
.
parse_type
(
input
.
type
(),
dims
);
}
else
if
(
parser
.
map_dyn_input_dims
.
count
(
name
)
>
0
)
...
...
@@ -370,7 +369,7 @@ parse_inputs(const onnx_parser& parser,
}
else
{
s
=
parser
.
parse_type
(
input
.
type
()
,
dims
);
s
=
parser
.
parse_type
(
input
.
type
());
}
mod_insts
[
name
]
=
mod
->
add_parameter
(
name
,
s
);
}
...
...
@@ -553,14 +552,9 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
}
MIGRAPHX_THROW
(
"PARSE_TENSOR: Invalid tensor type"
);
}
shape
onnx_parser
::
parse_type
(
const
onnx
::
TypeProto
&
t
,
const
std
::
vector
<
std
::
size_t
>&
input_dims
)
const
shape
onnx_parser
::
parse_type
(
const
onnx
::
TypeProto
&
t
)
const
{
shape
::
type_t
shape_type
=
get_type
(
t
.
tensor_type
().
elem_type
());
if
(
not
input_dims
.
empty
())
{
return
{
shape_type
,
input_dims
};
}
std
::
vector
<
shape
::
dynamic_dimension
>
dynamic_dims
;
auto
&&
tensor_dims
=
t
.
tensor_type
().
shape
().
dim
();
...
...
@@ -590,6 +584,15 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
return
shape_from_dyn_dims
(
shape_type
,
dynamic_dims
);
}
shape
onnx_parser
::
parse_type
(
const
onnx
::
TypeProto
&
t
,
const
std
::
vector
<
std
::
size_t
>&
input_dims
)
const
{
shape
::
type_t
shape_type
=
get_type
(
t
.
tensor_type
().
elem_type
());
if
(
input_dims
.
empty
())
return
{
shape_type
};
return
{
shape_type
,
input_dims
};
}
shape
::
type_t
get_type
(
int
dtype
)
{
switch
(
dtype
)
...
...
src/onnx/parse_constant_of_shape.cpp
View file @
1612d8f3
...
...
@@ -55,9 +55,6 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
l_val
=
literal
({
shape
::
float_type
,
{
1
},
{
0
}},
{
0.0
f
});
}
// input is empty, output is a scalar
auto
type
=
l_val
.
get_shape
().
type
();
if
(
args
.
empty
())
{
MIGRAPHX_THROW
(
"ConstantOfShape : must have 1 input!"
);
...
...
@@ -65,6 +62,8 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
else
{
migraphx
::
shape
s
;
// input is empty, output is a scalar
auto
type
=
l_val
.
get_shape
().
type
();
// empty input tensor, output is a scalar
if
(
args
[
0
]
->
get_shape
().
elements
()
==
0
)
{
...
...
src/onnx/parse_randomuniform_ops.cpp
View file @
1612d8f3
...
...
@@ -96,7 +96,7 @@ struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops>
if
(
contains
(
info
.
attributes
,
"seed"
))
gen
.
seed
(
info
.
attributes
.
at
(
"seed"
).
f
());
std
::
uniform_real_distribution
<>
d
(
high
,
low
);
std
::
uniform_real_distribution
<>
d
(
low
,
high
);
std
::
vector
<
double
>
rand_vals
(
out_shape
.
elements
());
std
::
generate
(
rand_vals
.
begin
(),
rand_vals
.
end
(),
[
&
]()
{
return
d
(
gen
);
});
...
...
src/program.cpp
View file @
1612d8f3
...
...
@@ -223,7 +223,7 @@ void program::compile(const std::vector<target>& targets, std::vector<compile_op
// Gather all the target roots
std
::
unordered_multimap
<
std
::
size_t
,
module_ref
>
roots
;
auto
mods
=
this
->
get_modules
();
for
(
auto
*
mod
:
mods
)
for
(
const
auto
*
mod
:
mods
)
{
for
(
const
auto
&
ins
:
*
mod
)
{
...
...
@@ -548,7 +548,7 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
ins_out
[
x
]
=
ss
.
str
();
});
ret
=
generic_eval
(
*
this
,
contexts
,
std
::
move
(
params
),
[
&
](
instruction_ref
ins
,
auto
f
)
{
auto
&
ctx
=
contexts
[
ins
->
get_target_id
()];
const
auto
&
ctx
=
contexts
[
ins
->
get_target_id
()];
ctx
.
finish
();
std
::
cout
<<
"Run instruction: "
<<
ins_out
.
at
(
ins
)
<<
std
::
endl
;
timer
t
{};
...
...
@@ -728,7 +728,7 @@ static void mod_from_val(module_ref mod,
std
::
back_inserter
(
module_inputs
),
[
&
](
const
value
&
i
)
{
return
map_mods
.
at
(
i
.
to
<
std
::
string
>
());
});
for
(
auto
&
smod
:
module_inputs
)
for
(
const
auto
&
smod
:
module_inputs
)
{
mod_from_val
(
smod
,
v
,
instructions
,
map_mods
);
}
...
...
@@ -1186,7 +1186,7 @@ void program::remove_unused_modules()
std
::
vector
<
module
*>
unused
;
generic_get_unused_modules
(
impl
->
modules
,
generic_get_modules
(
this
->
get_main_module
()),
std
::
back_inserter
(
unused
));
for
(
auto
*
m
:
unused
)
for
(
const
auto
*
m
:
unused
)
this
->
remove_module
(
m
->
name
());
}
...
...
src/py/CMakeLists.txt
View file @
1612d8f3
...
...
@@ -24,6 +24,7 @@
option
(
MIGRAPHX_ENABLE_PYTHON
"Enable python bindings"
ON
)
add_library
(
migraphx_py py_loader.cpp
)
migraphx_generate_export_header
(
migraphx_py
)
target_include_directories
(
migraphx_py PRIVATE include
)
target_link_libraries
(
migraphx_py PUBLIC migraphx
)
rocm_install_targets
(
TARGETS migraphx_py INCLUDE include
)
...
...
src/py/include/migraphx/py.hpp
View file @
1612d8f3
...
...
@@ -26,11 +26,12 @@
#include <migraphx/config.hpp>
#include <migraphx/program.hpp>
#include <migraphx/py/export.h>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
program
load_py
(
const
std
::
string
&
filename
);
MIGRAPHX_PY_EXPORT
program
load_py
(
const
std
::
string
&
filename
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/py/py_loader.cpp
View file @
1612d8f3
...
...
@@ -64,7 +64,7 @@ static dynamic_loader py_lib()
return
lib
;
}
program
load_py
(
const
std
::
string
&
filename
)
MIGRAPHX_PY_EXPORT
program
load_py
(
const
std
::
string
&
filename
)
{
static
auto
f
=
py_lib
().
get_function
<
program
(
const
std
::
string
&
)
>
(
"migraphx_load_py"
);
return
f
(
filename
);
...
...
src/rewrite_quantization.cpp
View file @
1612d8f3
...
...
@@ -28,6 +28,7 @@
#include <migraphx/tune_axis.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/common.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -61,13 +62,10 @@ void apply_quantizelinear(module& m, instruction_ref ins)
max_quant
=
qt
.
max
();
min_quant
=
qt
.
min
();
});
auto
s
=
add_zero_point
->
get_shape
();
std
::
vector
<
int
>
min_data
(
s
.
elements
(),
min_quant
);
std
::
vector
<
int
>
max_data
(
s
.
elements
(),
max_quant
);
auto
min_arg
=
m
.
add_literal
(
literal
(
s
,
min_data
));
auto
max_arg
=
m
.
add_literal
(
literal
(
s
,
max_data
));
auto
saturate
=
m
.
insert_instruction
(
ins
,
make_op
(
"clip"
),
add_zero_point
,
min_arg
,
max_arg
);
auto
s
=
add_zero_point
->
get_shape
();
auto
min_arg
=
m
.
add_literal
(
literal
{
shape
{
s
.
type
()},
{
min_quant
}});
auto
max_arg
=
m
.
add_literal
(
literal
{
shape
{
s
.
type
()},
{
max_quant
}});
auto
saturate
=
insert_common_op
(
m
,
ins
,
make_op
(
"clip"
),
{
add_zero_point
,
min_arg
,
max_arg
});
m
.
replace_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
ins
->
get_shape
().
type
()}}),
saturate
);
}
...
...
src/simplify_algebra.cpp
View file @
1612d8f3
...
...
@@ -1095,8 +1095,9 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
};
};
auto
dots
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"dot"
));
auto
qdots
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"quant_dot"
));
auto
convs
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"convolution"
));
return
(
dots
>=
2
or
convs
>=
2
);
return
(
dots
>=
2
or
convs
>=
2
or
qdots
>=
2
);
}
struct
find_conv_dot_horiz_fusion
...
...
@@ -1110,7 +1111,7 @@ struct find_conv_dot_horiz_fusion
auto
pred
=
[](
auto
i
,
auto
j
)
{
if
(
i
->
get_operator
()
!=
j
->
get_operator
())
return
false
;
if
(
not
contains
({
"dot"
,
"convolution"
},
i
->
name
()))
if
(
not
contains
({
"quant_dot"
,
"dot"
,
"convolution"
},
i
->
name
()))
return
true
;
auto
x
=
i
->
inputs
()[
1
]
->
get_shape
().
lens
();
auto
y
=
j
->
inputs
()[
1
]
->
get_shape
().
lens
();
...
...
@@ -1118,7 +1119,7 @@ struct find_conv_dot_horiz_fusion
return
false
;
// Check that non-axes match
int
axis
=
1
;
if
(
i
->
name
()
==
"dot"
)
if
(
i
->
name
()
==
"dot"
or
i
->
name
()
==
"quant_dot"
)
{
axis
=
x
.
size
()
-
1
;
}
...
...
@@ -1129,7 +1130,7 @@ struct find_conv_dot_horiz_fusion
if
(
std
::
distance
(
start
,
last
)
<
2
)
return
;
auto
&&
name
=
(
*
start
)
->
name
();
if
(
not
contains
({
"dot"
,
"convolution"
},
name
))
if
(
not
contains
({
"quant_dot"
,
"dot"
,
"convolution"
},
name
))
return
;
auto
op
=
(
*
start
)
->
get_operator
();
int
group
=
1
;
...
...
@@ -1144,7 +1145,7 @@ struct find_conv_dot_horiz_fusion
start
,
last
,
std
::
back_inserter
(
args
),
[
&
](
auto
x
)
{
return
x
->
inputs
().
at
(
1
);
});
int
axis
=
1
;
int
concat_axis
=
0
;
if
(
name
==
"dot"
)
if
(
name
==
"dot"
or
name
==
"quant_dot"
)
{
axis
=
int
(
args
.
front
()
->
get_shape
().
lens
().
size
()
-
1
);
concat_axis
=
axis
;
...
...
src/sqlite.cpp
View file @
1612d8f3
...
...
@@ -48,6 +48,7 @@ struct sqlite_impl
template
<
class
F
>
void
exec
(
const
char
*
sql
,
F
f
)
{
// cppcheck-suppress constParameterPointer
auto
callback
=
[](
void
*
obj
,
auto
...
xs
)
->
int
{
try
{
...
...
src/targets/cpu/gemm.cpp
View file @
1612d8f3
...
...
@@ -43,7 +43,11 @@ struct dnnl_gemm : dnnl_extend_op<dnnl_gemm, dnnl::matmul, op::dot>
MIGRAPHX_DNNL_PREFIX
(
ARG_BIAS
)};
}
void
required
(
const
check_shapes
&
cs
)
const
{
cs
.
not_broadcasted
();
}
template
<
class
T
>
void
required
(
const
check_shapes
<
T
>&
cs
)
const
{
cs
.
not_broadcasted
();
}
dnnl
::
matmul
::
desc
get_desc
(
const
std
::
unordered_map
<
int
,
dnnl
::
memory
::
desc
>&
m
)
const
{
...
...
src/targets/cpu/include/migraphx/cpu/dnnl.hpp
View file @
1612d8f3
...
...
@@ -400,7 +400,11 @@ struct dnnl_extend_op : dnnl_op<Derived, Primitive>
}
// dnnl has some issues with non-packed inputs
void
required
(
const
check_shapes
&
cs
)
const
{
cs
.
packed_or_broadcasted
();
}
template
<
class
T
>
void
required
(
const
check_shapes
<
T
>&
cs
)
const
{
cs
.
packed_or_broadcasted
();
}
std
::
string
name
()
const
{
return
"dnnl::"
+
op
.
name
();
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment