Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5ec8f913
Commit
5ec8f913
authored
Sep 13, 2022
by
Ted Themistokleous
Committed by
Ted Themistokleous
Sep 13, 2022
Browse files
Merge branch 'develop' into simplify_1_mul_div_ops
parents
32d69e8e
d78bcdfb
Changes
183
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3985 additions
and
2898 deletions
+3985
-2898
.github/workflows/ci.yaml
.github/workflows/ci.yaml
+1
-0
.github/workflows/performance.yaml
.github/workflows/performance.yaml
+1
-3
CMakeLists.txt
CMakeLists.txt
+1
-1
cppcheck.rules
cppcheck.rules
+2
-2
examples/migraphx/cpp_parse_load_save/parse_load_save.cpp
examples/migraphx/cpp_parse_load_save/parse_load_save.cpp
+2
-2
examples/migraphx/custom_op_rocblas_kernel/custom_op_rocblas_kernel.cpp
...phx/custom_op_rocblas_kernel/custom_op_rocblas_kernel.cpp
+1
-1
examples/vision/cpp_mnist/mnist_inference.cpp
examples/vision/cpp_mnist/mnist_inference.cpp
+6
-6
src/CMakeLists.txt
src/CMakeLists.txt
+1
-1
src/api/include/migraphx/migraphx.hpp
src/api/include/migraphx/migraphx.hpp
+22
-22
src/apply_alpha_beta.cpp
src/apply_alpha_beta.cpp
+1
-1
src/driver/alexnet.cpp
src/driver/alexnet.cpp
+99
-128
src/driver/inceptionv3.cpp
src/driver/inceptionv3.cpp
+2468
-1772
src/driver/resnet50.cpp
src/driver/resnet50.cpp
+1355
-935
src/eliminate_concat.cpp
src/eliminate_concat.cpp
+1
-1
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+1
-1
src/file_buffer.cpp
src/file_buffer.cpp
+1
-1
src/include/migraphx/allocation_model.hpp
src/include/migraphx/allocation_model.hpp
+2
-2
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+16
-15
src/include/migraphx/concat_opt.hpp
src/include/migraphx/concat_opt.hpp
+2
-2
src/include/migraphx/context.hpp
src/include/migraphx/context.hpp
+2
-2
No files found.
.github/workflows/ci.yaml
View file @
5ec8f913
...
@@ -53,6 +53,7 @@ jobs:
...
@@ -53,6 +53,7 @@ jobs:
CXX=/opt/rocm/llvm/bin/clang++ CC=/opt/rocm/llvm/bin/clang cmake \
CXX=/opt/rocm/llvm/bin/clang++ CC=/opt/rocm/llvm/bin/clang cmake \
-DMIGRAPHX_ENABLE_GPU=On \
-DMIGRAPHX_ENABLE_GPU=On \
-DMIGRAPHX_ENABLE_CPU=On \
-DMIGRAPHX_ENABLE_CPU=On \
-DMIGRAPHX_ENABLE_FPGA=On \
-DROCM_ENABLE_GH_ANNOTATIONS=On \
-DROCM_ENABLE_GH_ANNOTATIONS=On \
-DCLANG_TIDY_DEPEND_ON_TARGET=Off \
-DCLANG_TIDY_DEPEND_ON_TARGET=Off \
-DCLANG_TIDY_CACHE=/data/tidy-cache \
-DCLANG_TIDY_CACHE=/data/tidy-cache \
...
...
.github/workflows/performance.yaml
View file @
5ec8f913
name
:
MIGraphX Performance Tests
name
:
MIGraphX Performance Tests
on
:
on
:
push
:
branches
:
[
develop
]
pull_request
:
pull_request
:
branches
:
[
develop
]
branches
:
[
develop
]
types
:
[
opened
,
synchronize
,
closed
]
schedule
:
schedule
:
-
cron
:
"
0
5
*
*
1-6"
-
cron
:
"
0
5
*
*
1-6"
...
...
CMakeLists.txt
View file @
5ec8f913
...
@@ -63,7 +63,7 @@ set(CMAKE_EXTRA_INCLUDE_FILES)
...
@@ -63,7 +63,7 @@ set(CMAKE_EXTRA_INCLUDE_FILES)
include
(
ROCMSetupVersion
)
include
(
ROCMSetupVersion
)
rocm_setup_version
(
VERSION 2.
3
)
rocm_setup_version
(
VERSION 2.
4
)
set
(
MIGRAPHX_SO_VERSION
${
PROJECT_VERSION_MAJOR
}
.
${
PROJECT_VERSION_MINOR
}
)
set
(
MIGRAPHX_SO_VERSION
${
PROJECT_VERSION_MAJOR
}
.
${
PROJECT_VERSION_MINOR
}
)
option
(
BUILD_SHARED_LIBS
"Build as a shared library"
ON
)
option
(
BUILD_SHARED_LIBS
"Build as a shared library"
ON
)
...
...
cppcheck.rules
View file @
5ec8f913
...
@@ -107,7 +107,7 @@
...
@@ -107,7 +107,7 @@
<summary>
Use make_shared or make_unique instead of new
</summary>
<summary>
Use make_shared or make_unique instead of new
</summary>
</message>
</message>
</rule>
</rule>
<!--
<rule>
<rule>
<tokenlist>
raw
</tokenlist>
<tokenlist>
raw
</tokenlist>
<pattern>
<![CDATA[ \|\| ]]>
</pattern>
<pattern>
<![CDATA[ \|\| ]]>
</pattern>
<message>
<message>
...
@@ -124,7 +124,7 @@
...
@@ -124,7 +124,7 @@
<severity>
style
</severity>
<severity>
style
</severity>
<summary>
Use 'not' instead of !
</summary>
<summary>
Use 'not' instead of !
</summary>
</message>
</message>
</rule>
-->
</rule>
<!-- <rule>
<!-- <rule>
<tokenlist>raw</tokenlist>
<tokenlist>raw</tokenlist>
<pattern><![CDATA[] (__device__ |__host__ )+(\(|{)]]></pattern>
<pattern><![CDATA[] (__device__ |__host__ )+(\(|{)]]></pattern>
...
...
examples/migraphx/cpp_parse_load_save/parse_load_save.cpp
View file @
5ec8f913
...
@@ -53,8 +53,8 @@ int main(int argc, char** argv)
...
@@ -53,8 +53,8 @@ int main(int argc, char** argv)
migraphx
::
program
p
;
migraphx
::
program
p
;
if
(
cmdOptionExists
(
argv
+
2
,
argv
+
argc
,
"--parse"
)
||
if
(
cmdOptionExists
(
argv
+
2
,
argv
+
argc
,
"--parse"
)
or
!
cmdOptionExists
(
argv
+
2
,
argv
+
argc
,
"--load"
))
not
cmdOptionExists
(
argv
+
2
,
argv
+
argc
,
"--load"
))
{
{
std
::
cout
<<
"Parsing ONNX File"
<<
std
::
endl
;
std
::
cout
<<
"Parsing ONNX File"
<<
std
::
endl
;
migraphx
::
onnx_options
options
;
migraphx
::
onnx_options
options
;
...
...
examples/migraphx/custom_op_rocblas_kernel/custom_op_rocblas_kernel.cpp
View file @
5ec8f913
...
@@ -72,7 +72,7 @@ struct sscal_custom_op final : migraphx::experimental_custom_op_base
...
@@ -72,7 +72,7 @@ struct sscal_custom_op final : migraphx::experimental_custom_op_base
{
{
throw
std
::
runtime_error
(
"sscal_custom_op must have 2 input arguments"
);
throw
std
::
runtime_error
(
"sscal_custom_op must have 2 input arguments"
);
}
}
if
(
inputs
[
0
].
lengths
().
size
()
!=
1
||
inputs
[
0
].
lengths
()[
0
]
!=
1
)
if
(
inputs
[
0
].
lengths
().
size
()
!=
1
or
inputs
[
0
].
lengths
()[
0
]
!=
1
)
{
{
throw
std
::
runtime_error
(
"first input argument to sscal_custom_op must be a scalar"
);
throw
std
::
runtime_error
(
"first input argument to sscal_custom_op must be a scalar"
);
}
}
...
...
examples/vision/cpp_mnist/mnist_inference.cpp
View file @
5ec8f913
...
@@ -51,16 +51,16 @@ int main(int argc, char** argv)
...
@@ -51,16 +51,16 @@ int main(int argc, char** argv)
char
**
begin
=
argv
+
1
;
char
**
begin
=
argv
+
1
;
char
**
end
=
argv
+
argc
;
char
**
end
=
argv
+
argc
;
const
bool
CPU
=
(
std
::
find
(
begin
,
end
,
std
::
string
(
"-c"
))
!=
end
)
||
const
bool
CPU
=
(
std
::
find
(
begin
,
end
,
std
::
string
(
"-c"
))
!=
end
)
or
std
::
find
(
begin
,
end
,
std
::
string
(
"--cpu"
))
!=
end
;
std
::
find
(
begin
,
end
,
std
::
string
(
"--cpu"
))
!=
end
;
const
bool
GPU
=
std
::
find
(
begin
,
end
,
std
::
string
(
"-g"
))
!=
end
||
const
bool
GPU
=
std
::
find
(
begin
,
end
,
std
::
string
(
"-g"
))
!=
end
or
std
::
find
(
begin
,
end
,
std
::
string
(
"--gpu"
))
!=
end
;
std
::
find
(
begin
,
end
,
std
::
string
(
"--gpu"
))
!=
end
;
const
bool
FP16
=
std
::
find
(
begin
,
end
,
std
::
string
(
"-f"
))
!=
end
||
const
bool
FP16
=
std
::
find
(
begin
,
end
,
std
::
string
(
"-f"
))
!=
end
or
std
::
find
(
begin
,
end
,
std
::
string
(
"--fp16"
))
!=
end
;
std
::
find
(
begin
,
end
,
std
::
string
(
"--fp16"
))
!=
end
;
const
bool
INT8
=
std
::
find
(
begin
,
end
,
std
::
string
(
"-i"
))
!=
end
||
const
bool
INT8
=
std
::
find
(
begin
,
end
,
std
::
string
(
"-i"
))
!=
end
or
std
::
find
(
begin
,
end
,
std
::
string
(
"--int8"
))
!=
end
;
std
::
find
(
begin
,
end
,
std
::
string
(
"--int8"
))
!=
end
;
const
bool
CALIB
=
std
::
find
(
begin
,
end
,
std
::
string
(
"--cal"
))
!=
end
;
const
bool
CALIB
=
std
::
find
(
begin
,
end
,
std
::
string
(
"--cal"
))
!=
end
;
const
bool
PRINT
=
std
::
find
(
begin
,
end
,
std
::
string
(
"-p"
))
!=
end
||
const
bool
PRINT
=
std
::
find
(
begin
,
end
,
std
::
string
(
"-p"
))
!=
end
or
std
::
find
(
begin
,
end
,
std
::
string
(
"--print"
))
!=
end
;
std
::
find
(
begin
,
end
,
std
::
string
(
"--print"
))
!=
end
;
migraphx
::
program
prog
;
migraphx
::
program
prog
;
...
@@ -182,7 +182,7 @@ void read_nth_digit(const int n, std::vector<float>& digit)
...
@@ -182,7 +182,7 @@ void read_nth_digit(const int n, std::vector<float>& digit)
const
int
HEIGHT
=
28
;
const
int
HEIGHT
=
28
;
const
int
WIDTH
=
28
;
const
int
WIDTH
=
28
;
if
(
!
file
.
is_open
())
if
(
not
file
.
is_open
())
{
{
return
;
return
;
}
}
...
...
src/CMakeLists.txt
View file @
5ec8f913
...
@@ -82,6 +82,7 @@ add_library(migraphx
...
@@ -82,6 +82,7 @@ add_library(migraphx
simplify_qdq.cpp
simplify_qdq.cpp
sqlite.cpp
sqlite.cpp
rewrite_batchnorm.cpp
rewrite_batchnorm.cpp
rewrite_gelu.cpp
rewrite_pooling.cpp
rewrite_pooling.cpp
rewrite_quantization.cpp
rewrite_quantization.cpp
rewrite_rnn.cpp
rewrite_rnn.cpp
...
@@ -90,7 +91,6 @@ add_library(migraphx
...
@@ -90,7 +91,6 @@ add_library(migraphx
shape.cpp
shape.cpp
simplify_algebra.cpp
simplify_algebra.cpp
simplify_reshapes.cpp
simplify_reshapes.cpp
target_assignments.cpp
tmp_dir.cpp
tmp_dir.cpp
value.cpp
value.cpp
verify_args.cpp
verify_args.cpp
...
...
src/api/include/migraphx/migraphx.hpp
View file @
5ec8f913
...
@@ -517,7 +517,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
...
@@ -517,7 +517,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
shape
(
const
migraphx_shape
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
shape
(
const
migraphx_shape
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
shape
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
shape
)
/// Construct a scalar shape
/// Construct a scalar shape
shape
(
migraphx_shape_datatype_t
type
)
shape
(
migraphx_shape_datatype_t
type
)
...
@@ -588,7 +588,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
...
@@ -588,7 +588,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return
pout
;
return
pout
;
}
}
friend
bool
operator
!=
(
const
shape
&
px
,
const
shape
&
py
)
{
return
!
(
px
==
py
);
}
friend
bool
operator
!=
(
const
shape
&
px
,
const
shape
&
py
)
{
return
not
(
px
==
py
);
}
};
};
/**
/**
...
@@ -601,7 +601,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
...
@@ -601,7 +601,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{
{
argument
()
{}
argument
()
{}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
argument
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
argument
)
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
argument
(
const
migraphx_argument
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
argument
(
const
migraphx_argument
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
...
@@ -647,7 +647,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
...
@@ -647,7 +647,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
return
pout
;
return
pout
;
}
}
friend
bool
operator
!=
(
const
argument
&
px
,
const
argument
&
py
)
{
return
!
(
px
==
py
);
}
friend
bool
operator
!=
(
const
argument
&
px
,
const
argument
&
py
)
{
return
not
(
px
==
py
);
}
};
};
/// A target for compilation
/// A target for compilation
...
@@ -655,7 +655,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target)
...
@@ -655,7 +655,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target)
{
{
target
()
{}
target
()
{}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
target
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
target
)
/// Construct a target from its name
/// Construct a target from its name
target
(
const
char
*
name
)
{
this
->
make_handle
(
&
migraphx_target_create
,
name
);
}
target
(
const
char
*
name
)
{
this
->
make_handle
(
&
migraphx_target_create
,
name
);
}
...
@@ -665,7 +665,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
...
@@ -665,7 +665,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{
{
program_parameter_shapes
()
{}
program_parameter_shapes
()
{}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program_parameter_shapes
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program_parameter_shapes
)
size_t
size
()
const
size_t
size
()
const
{
{
...
@@ -684,7 +684,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
...
@@ -684,7 +684,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
std
::
vector
<
const
char
*>
names
()
const
std
::
vector
<
const
char
*>
names
()
const
{
{
std
::
vector
<
const
char
*>
result
(
this
->
size
());
std
::
vector
<
const
char
*>
result
(
this
->
size
());
if
(
!
result
.
empty
())
if
(
not
result
.
empty
())
{
{
call
(
&
migraphx_program_parameter_shapes_names
,
result
.
data
(),
this
->
get_handle_ptr
());
call
(
&
migraphx_program_parameter_shapes_names
,
result
.
data
(),
this
->
get_handle_ptr
());
}
}
...
@@ -695,7 +695,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
...
@@ -695,7 +695,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
/// A class to construct the inputs parameters for a program
/// A class to construct the inputs parameters for a program
struct
program_parameters
:
MIGRAPHX_HANDLE_BASE
(
program_parameters
)
struct
program_parameters
:
MIGRAPHX_HANDLE_BASE
(
program_parameters
)
{
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program_parameters
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program_parameters
)
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
program_parameters
(
migraphx_program_parameters
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
program_parameters
(
migraphx_program_parameters
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
...
@@ -722,7 +722,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
...
@@ -722,7 +722,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
struct
arguments
:
MIGRAPHX_HANDLE_BASE
(
arguments
),
array_base
<
arguments
>
struct
arguments
:
MIGRAPHX_HANDLE_BASE
(
arguments
),
array_base
<
arguments
>
{
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
arguments
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
arguments
)
size_t
size
()
const
size_t
size
()
const
{
{
...
@@ -741,7 +741,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
...
@@ -741,7 +741,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
struct
shapes
:
MIGRAPHX_HANDLE_BASE
(
shapes
),
array_base
<
shapes
>
struct
shapes
:
MIGRAPHX_HANDLE_BASE
(
shapes
),
array_base
<
shapes
>
{
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
shapes
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
shapes
)
size_t
size
()
const
size_t
size
()
const
{
{
...
@@ -760,7 +760,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
...
@@ -760,7 +760,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
struct
operation
:
MIGRAPHX_HANDLE_BASE
(
operation
)
struct
operation
:
MIGRAPHX_HANDLE_BASE
(
operation
)
{
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
operation
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
operation
)
template
<
class
...
Ts
>
template
<
class
...
Ts
>
operation
(
const
char
*
name
,
const
char
*
attributes
=
nullptr
,
Ts
...
xs
)
operation
(
const
char
*
name
,
const
char
*
attributes
=
nullptr
,
Ts
...
xs
)
...
@@ -778,12 +778,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
...
@@ -778,12 +778,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
struct
instruction
:
MIGRAPHX_CONST_HANDLE_BASE
(
instruction
)
struct
instruction
:
MIGRAPHX_CONST_HANDLE_BASE
(
instruction
)
{
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
instruction
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
instruction
)
};
};
struct
instructions
:
MIGRAPHX_HANDLE_BASE
(
instructions
)
struct
instructions
:
MIGRAPHX_HANDLE_BASE
(
instructions
)
{
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
instructions
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
instructions
)
template
<
class
...
Ts
>
template
<
class
...
Ts
>
instructions
(
Ts
...
xs
)
instructions
(
Ts
...
xs
)
...
@@ -797,7 +797,7 @@ struct module;
...
@@ -797,7 +797,7 @@ struct module;
struct
modules
:
MIGRAPHX_HANDLE_BASE
(
modules
)
struct
modules
:
MIGRAPHX_HANDLE_BASE
(
modules
)
{
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
modules
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
modules
)
template
<
class
...
Ts
>
template
<
class
...
Ts
>
modules
(
Ts
...
xs
)
modules
(
Ts
...
xs
)
...
@@ -911,7 +911,7 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
...
@@ -911,7 +911,7 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
{
{
compile_options
()
{
this
->
make_handle
(
&
migraphx_compile_options_create
);
}
compile_options
()
{
this
->
make_handle
(
&
migraphx_compile_options_create
);
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
compile_options
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
compile_options
)
/// For targets with offloaded memory(such as the gpu), this will insert
/// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the
/// instructions during compilation to copy the input parameters to the
...
@@ -935,7 +935,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
...
@@ -935,7 +935,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{
{
program
()
{
this
->
make_handle
(
&
migraphx_program_create
);
}
program
()
{
this
->
make_handle
(
&
migraphx_program_create
);
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program
)
/// Compile the program for a specific target to be ran on
/// Compile the program for a specific target to be ran on
void
compile
(
const
target
&
ptarget
,
const
compile_options
&
poptions
)
const
void
compile
(
const
target
&
ptarget
,
const
compile_options
&
poptions
)
const
...
@@ -1015,13 +1015,13 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
...
@@ -1015,13 +1015,13 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
return
module
{
p_modu
,
this
->
share_handle
()};
return
module
{
p_modu
,
this
->
share_handle
()};
}
}
friend
bool
operator
!=
(
const
program
&
px
,
const
program
&
py
)
{
return
!
(
px
==
py
);
}
friend
bool
operator
!=
(
const
program
&
px
,
const
program
&
py
)
{
return
not
(
px
==
py
);
}
};
};
// options for migraphx file format options
// options for migraphx file format options
struct
file_options
:
MIGRAPHX_HANDLE_BASE
(
file_options
)
struct
file_options
:
MIGRAPHX_HANDLE_BASE
(
file_options
)
{
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
file_options
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
file_options
)
file_options
()
{
this
->
make_handle
(
&
migraphx_file_options_create
);
}
file_options
()
{
this
->
make_handle
(
&
migraphx_file_options_create
);
}
// set file format
// set file format
...
@@ -1063,7 +1063,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
...
@@ -1063,7 +1063,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{
{
onnx_options
()
{
this
->
make_handle
(
&
migraphx_onnx_options_create
);
}
onnx_options
()
{
this
->
make_handle
(
&
migraphx_onnx_options_create
);
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
onnx_options
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
onnx_options
)
/// Make onnx parser treat an inputs with a certain dimensions
/// Make onnx parser treat an inputs with a certain dimensions
void
set_input_parameter_shape
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
size_t
>
dim
)
void
set_input_parameter_shape
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
size_t
>
dim
)
...
@@ -1145,7 +1145,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
...
@@ -1145,7 +1145,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
{
{
tf_options
()
{
this
->
make_handle
(
&
migraphx_tf_options_create
);
}
tf_options
()
{
this
->
make_handle
(
&
migraphx_tf_options_create
);
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
tf_options
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
tf_options
)
/// Make tf parser treat an inputs with a certain dimensions
/// Make tf parser treat an inputs with a certain dimensions
void
set_input_parameter_shape
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
size_t
>
dim
)
void
set_input_parameter_shape
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
size_t
>
dim
)
...
@@ -1198,7 +1198,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
...
@@ -1198,7 +1198,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
{
{
quantize_op_names
()
{
this
->
make_handle
(
&
migraphx_quantize_op_names_create
);
}
quantize_op_names
()
{
this
->
make_handle
(
&
migraphx_quantize_op_names_create
);
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
quantize_op_names
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
quantize_op_names
)
void
add
(
const
std
::
string
&
name
)
void
add
(
const
std
::
string
&
name
)
{
{
...
@@ -1223,7 +1223,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
...
@@ -1223,7 +1223,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
{
{
quantize_int8_options
()
{
this
->
make_handle
(
&
migraphx_quantize_int8_options_create
);
}
quantize_int8_options
()
{
this
->
make_handle
(
&
migraphx_quantize_int8_options_create
);
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
quantize_int8_options
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
quantize_int8_options
)
/// Add an operator that should be quantized
/// Add an operator that should be quantized
void
add_op_name
(
const
std
::
string
&
name
)
void
add_op_name
(
const
std
::
string
&
name
)
...
...
src/apply_alpha_beta.cpp
View file @
5ec8f913
...
@@ -39,7 +39,7 @@ instruction_ref insert_apply_alpha_beta(module& m,
...
@@ -39,7 +39,7 @@ instruction_ref insert_apply_alpha_beta(module& m,
auto
a
=
args
[
0
];
auto
a
=
args
[
0
];
auto
b
=
args
[
1
];
auto
b
=
args
[
1
];
auto
input_type
=
a
->
get_shape
().
type
();
auto
input_type
=
a
->
get_shape
().
type
();
if
(
!
float_equal
(
alpha
.
at
<
float
>
(
0
),
1.0
))
if
(
not
float_equal
(
alpha
.
at
<
float
>
(
0
),
1.0
))
{
{
auto
alpha_literal
=
m
.
add_literal
(
alpha
);
auto
alpha_literal
=
m
.
add_literal
(
alpha
);
a
=
insert_common_op
(
m
,
pos
,
migraphx
::
make_op
(
"mul"
),
{
alpha_literal
,
a
});
a
=
insert_common_op
(
m
,
pos
,
migraphx
::
make_op
(
"mul"
),
{
alpha_literal
,
a
});
...
...
src/driver/alexnet.cpp
View file @
5ec8f913
...
@@ -25,13 +25,10 @@
...
@@ -25,13 +25,10 @@
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/json.hpp>
#include "models.hpp"
#include "models.hpp"
namespace
migraphx
{
namespace
migraphx
{
namespace
driver
{
namespace
driver
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
migraphx
::
program
alexnet
(
unsigned
batch
)
// NOLINT(readability-function-size)
migraphx
::
program
alexnet
(
unsigned
batch
)
// NOLINT(readability-function-size)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
@@ -42,179 +39,153 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
...
@@ -42,179 +39,153 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
}},
1
)));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
}},
1
)));
auto
x_main_module_2
=
mmain
->
add_literal
(
migraphx
::
abs
(
auto
x_main_module_2
=
mmain
->
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
}},
2
)));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
}},
2
)));
auto
x_
input_1
=
mmain
->
add_parameter
(
auto
x_
0
=
mmain
->
add_parameter
(
"
input.1
"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
batch
,
3
,
224
,
224
}});
"
0
"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
batch
,
3
,
224
,
224
}});
auto
x_main_module_4
=
mmain
->
add_literal
(
auto
x_main_module_4
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4096
,
4096
}},
3
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1000
}},
3
));
auto
x_main_module_5
=
mmain
->
add_literal
(
auto
x_main_module_5
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4096
}},
4
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1000
,
4096
}},
4
));
auto
x_main_module_6
=
mmain
->
add_literal
(
auto
x_main_module_6
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4096
,
9216
}},
5
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4096
}},
5
));
auto
x_main_module_7
=
mmain
->
add_literal
(
auto
x_main_module_7
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4096
}},
6
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4096
,
4096
}},
6
));
auto
x_main_module_8
=
mmain
->
add_literal
(
auto
x_main_module_8
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1000
,
4096
}},
7
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4096
}},
7
));
auto
x_main_module_9
=
mmain
->
add_literal
(
auto
x_main_module_9
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1000
}},
8
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4096
,
9216
}},
8
));
auto
x_main_module_10
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
auto
x_main_module_10
=
mmain
->
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
256
,
384
,
3
,
3
}},
9
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
256
}},
9
));
auto
x_main_module_11
=
mmain
->
add_literal
(
auto
x_main_module_11
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
256
}},
10
));
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
256
,
256
,
3
,
3
}},
10
));
auto
x_main_module_12
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
auto
x_main_module_12
=
mmain
->
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
384
,
192
,
3
,
3
}},
11
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
256
}},
11
));
auto
x_main_module_13
=
mmain
->
add_literal
(
auto
x_main_module_13
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
384
}},
12
));
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
256
,
384
,
3
,
3
}},
12
));
auto
x_main_module_14
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
auto
x_main_module_14
=
mmain
->
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
192
,
64
,
5
,
5
}},
13
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
384
}},
13
));
auto
x_main_module_15
=
mmain
->
add_literal
(
auto
x_main_module_15
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
192
}},
14
));
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
384
,
192
,
3
,
3
}},
14
));
auto
x_main_module_16
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
auto
x_main_module_16
=
mmain
->
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
256
,
256
,
3
,
3
}},
15
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
192
}},
15
));
auto
x_main_module_17
=
mmain
->
add_literal
(
auto
x_main_module_17
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
256
}},
16
));
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
192
,
64
,
5
,
5
}},
16
));
auto
x_main_module_18
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
auto
x_main_module_18
=
mmain
->
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
64
,
3
,
11
,
11
}},
17
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
64
}},
17
));
auto
x_main_module_19
=
mmain
->
add_literal
(
auto
x_main_module_19
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
64
}},
18
));
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
64
,
3
,
11
,
11
}},
18
));
auto
x_main_module_20
=
mmain
->
add_instruction
(
auto
x_main_module_20
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_json_op
(
"convolution"
,
"convolution"
,
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4,"
migraphx
::
from_json_string
(
"4],use_dynamic_same_auto_pad:0}"
),
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4,4]}"
)),
x_0
,
x_input_1
,
x_main_module_18
);
auto
x_main_module_21
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
migraphx
::
from_json_string
(
"{axis:1,out_lens:[1,64,55,55]}"
)),
x_main_module_19
);
x_main_module_19
);
auto
x_main_module_21
=
mmain
->
add_instruction
(
migraphx
::
make_json_op
(
"broadcast"
,
"{axis:1,out_lens:[1,64,55,55]}"
),
x_main_module_18
);
auto
x_main_module_22
=
auto
x_main_module_22
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_20
,
x_main_module_21
);
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_20
,
x_main_module_21
);
auto
x_main_module_23
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_22
);
auto
x_main_module_23
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_22
);
auto
x_main_module_24
=
mmain
->
add_instruction
(
auto
x_main_module_24
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_
json_
op
(
"pooling"
,
"pooling"
,
migraphx
::
from_json_string
(
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"
),
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"
)),
x_main_module_23
);
x_main_module_23
);
auto
x_main_module_25
=
mmain
->
add_instruction
(
auto
x_main_module_25
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_json_op
(
"convolution"
,
"convolution"
,
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1,"
migraphx
::
from_json_string
(
"1],use_dynamic_same_auto_pad:0}"
),
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1,1]}"
)),
x_main_module_24
,
x_main_module_24
,
x_main_module_1
4
);
x_main_module_1
7
);
auto
x_main_module_26
=
mmain
->
add_instruction
(
auto
x_main_module_26
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
migraphx
::
make_json_op
(
"broadcast"
,
"{axis:1,out_lens:[1,192,27,27]}"
),
x_main_module_16
);
migraphx
::
from_json_string
(
"{axis:1,out_lens:[1,192,27,27]}"
)),
x_main_module_15
);
auto
x_main_module_27
=
auto
x_main_module_27
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_25
,
x_main_module_26
);
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_25
,
x_main_module_26
);
auto
x_main_module_28
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_27
);
auto
x_main_module_28
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_27
);
auto
x_main_module_29
=
mmain
->
add_instruction
(
auto
x_main_module_29
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_
json_
op
(
"pooling"
,
"pooling"
,
migraphx
::
from_json_string
(
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"
),
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"
)),
x_main_module_28
);
x_main_module_28
);
auto
x_main_module_30
=
mmain
->
add_instruction
(
auto
x_main_module_30
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_json_op
(
"convolution"
,
"convolution"
,
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
migraphx
::
from_json_string
(
"1],use_dynamic_same_auto_pad:0}"
),
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}"
)),
x_main_module_29
,
x_main_module_29
,
x_main_module_1
2
);
x_main_module_1
5
);
auto
x_main_module_31
=
mmain
->
add_instruction
(
auto
x_main_module_31
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
migraphx
::
make_json_op
(
"broadcast"
,
"{axis:1,out_lens:[1,384,13,13]}"
),
x_main_module_14
);
migraphx
::
from_json_string
(
"{axis:1,out_lens:[1,384,13,13]}"
)),
x_main_module_13
);
auto
x_main_module_32
=
auto
x_main_module_32
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_30
,
x_main_module_31
);
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_30
,
x_main_module_31
);
auto
x_main_module_33
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_32
);
auto
x_main_module_33
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_32
);
auto
x_main_module_34
=
mmain
->
add_instruction
(
auto
x_main_module_34
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_json_op
(
"convolution"
,
"convolution"
,
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
migraphx
::
from_json_string
(
"1],use_dynamic_same_auto_pad:0}"
),
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}"
)),
x_main_module_33
,
x_main_module_33
,
x_main_module_1
0
);
x_main_module_1
3
);
auto
x_main_module_35
=
mmain
->
add_instruction
(
auto
x_main_module_35
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
migraphx
::
make_json_op
(
"broadcast"
,
"{axis:1,out_lens:[1,256,13,13]}"
),
x_main_module_12
);
migraphx
::
from_json_string
(
"{axis:1,out_lens:[1,256,13,13]}"
)),
x_main_module_11
);
auto
x_main_module_36
=
auto
x_main_module_36
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_34
,
x_main_module_35
);
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_34
,
x_main_module_35
);
auto
x_main_module_37
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_36
);
auto
x_main_module_37
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_36
);
auto
x_main_module_38
=
mmain
->
add_instruction
(
auto
x_main_module_38
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_json_op
(
"convolution"
,
"convolution"
,
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
migraphx
::
from_json_string
(
"1],use_dynamic_same_auto_pad:0}"
),
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}"
)),
x_main_module_37
,
x_main_module_37
,
x_main_module_1
6
);
x_main_module_1
1
);
auto
x_main_module_39
=
mmain
->
add_instruction
(
auto
x_main_module_39
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
migraphx
::
make_json_op
(
"broadcast"
,
"{axis:1,out_lens:[1,256,13,13]}"
),
x_main_module_10
);
migraphx
::
from_json_string
(
"{axis:1,out_lens:[1,256,13,13]}"
)),
x_main_module_17
);
auto
x_main_module_40
=
auto
x_main_module_40
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_38
,
x_main_module_39
);
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_38
,
x_main_module_39
);
auto
x_main_module_41
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_40
);
auto
x_main_module_41
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_40
);
auto
x_main_module_42
=
mmain
->
add_instruction
(
auto
x_main_module_42
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_
json_
op
(
"pooling"
,
"pooling"
,
migraphx
::
from_json_string
(
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"
),
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"
)),
x_main_module_41
);
x_main_module_41
);
auto
x_main_module_43
=
mmain
->
add_instruction
(
auto
x_main_module_43
=
migraphx
::
make_op
(
"reshape"
,
migraphx
::
from_json_string
(
"{dims:[1,9216]}"
)),
mmain
->
add_instruction
(
migraphx
::
make_json_op
(
"flatten"
,
"{axis:1}"
),
x_main_module_42
);
x_main_module_42
);
auto
x_main_module_44
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"identity"
),
x_main_module_43
);
auto
x_main_module_44
=
mmain
->
add_instruction
(
auto
x_main_module_45
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
migraphx
::
from_json_string
(
"{permutation:[1,0]}"
)),
migraphx
::
make_json_op
(
"transpose"
,
"{permutation:[1,0]}"
),
x_main_module_9
);
x_main_module_6
);
auto
x_main_module_46
=
auto
x_main_module_45
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
x_main_module_44
,
x_main_module_45
);
mmain
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
x_main_module_43
,
x_main_module_44
);
auto
x_main_module_46
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
migraphx
::
from_json_string
(
"{out_lens:[1,4096]}"
)),
x_main_module_7
);
auto
x_main_module_47
=
mmain
->
add_instruction
(
auto
x_main_module_47
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
migraphx
::
from_json_string
(
"{out_lens:[1,4096]}"
)),
migraphx
::
make_json_op
(
"multibroadcast"
,
"{out_lens:[1,4096]}"
),
x_main_module_8
);
x_main_module_2
);
auto
x_main_module_48
=
mmain
->
add_instruction
(
auto
x_main_module_48
=
migraphx
::
make_json_op
(
"multibroadcast"
,
"{out_lens:[1,4096]}"
),
x_main_module_2
);
mmain
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
x_main_module_46
,
x_main_module_47
);
auto
x_main_module_49
=
auto
x_main_module_49
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_45
,
x_main_module_48
);
mmain
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
x_main_module_47
,
x_main_module_48
);
auto
x_main_module_50
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_49
);
auto
x_main_module_50
=
auto
x_main_module_51
=
mmain
->
add_instruction
(
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_46
,
x_main_module_49
);
migraphx
::
make_op
(
"transpose"
,
migraphx
::
from_json_string
(
"{permutation:[1,0]}"
)),
auto
x_main_module_51
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_50
);
x_main_module_4
);
auto
x_main_module_52
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"identity"
),
x_main_module_51
);
auto
x_main_module_52
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
x_main_module_50
,
x_main_module_51
);
auto
x_main_module_53
=
mmain
->
add_instruction
(
auto
x_main_module_53
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
migraphx
::
from_json_string
(
"{out_lens:[1,4096]}"
)),
migraphx
::
make_json_op
(
"transpose"
,
"{permutation:[1,0]}"
),
x_main_module_7
);
x_main_module_5
);
auto
x_main_module_54
=
auto
x_main_module_54
=
mmain
->
add_instruction
(
mmain
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
x_main_module_52
,
x_main_module_53
);
migraphx
::
make_op
(
"multibroadcast"
,
migraphx
::
from_json_string
(
"{out_lens:[1,4096]}"
)),
auto
x_main_module_55
=
mmain
->
add_instruction
(
x_main_module_1
);
migraphx
::
make_json_op
(
"multibroadcast"
,
"{out_lens:[1,4096]}"
),
x_main_module_6
);
auto
x_main_module_55
=
auto
x_main_module_56
=
mmain
->
add_instruction
(
mmain
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
x_main_module_53
,
x_main_module_54
);
migraphx
::
make_json_op
(
"multibroadcast"
,
"{out_lens:[1,4096]}"
),
x_main_module_1
);
auto
x_main_module_56
=
auto
x_main_module_57
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_52
,
x_main_module_55
);
mmain
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
x_main_module_55
,
x_main_module_56
);
auto
x_main_module_57
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_56
);
auto
x_main_module_58
=
auto
x_main_module_58
=
mmain
->
add_instruction
(
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_54
,
x_main_module_57
);
migraphx
::
make_op
(
"transpose"
,
migraphx
::
from_json_string
(
"{permutation:[1,0]}"
)),
auto
x_main_module_59
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_58
);
x_main_module_8
);
auto
x_main_module_59
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
x_main_module_57
,
x_main_module_58
);
auto
x_main_module_60
=
mmain
->
add_instruction
(
auto
x_main_module_60
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
migraphx
::
from_json_string
(
"{out_lens:[1,1000]}"
)),
migraphx
::
make_json_op
(
"transpose"
,
"{permutation:[1,0]}"
),
x_main_module_5
);
x_main_module_9
);
auto
x_main_module_61
=
auto
x_main_module_61
=
mmain
->
add_instruction
(
mmain
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
x_main_module_59
,
x_main_module_60
);
migraphx
::
make_op
(
"multibroadcast"
,
migraphx
::
from_json_string
(
"{out_lens:[1,1000]}"
)),
auto
x_main_module_62
=
mmain
->
add_instruction
(
x_main_module_0
);
migraphx
::
make_json_op
(
"multibroadcast"
,
"{out_lens:[1,1000]}"
),
x_main_module_4
);
auto
x_main_module_62
=
auto
x_main_module_63
=
mmain
->
add_instruction
(
mmain
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
x_main_module_60
,
x_main_module_61
);
migraphx
::
make_json_op
(
"multibroadcast"
,
"{out_lens:[1,1000]}"
),
x_main_module_0
);
auto
x_main_module_63
=
auto
x_main_module_64
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_59
,
x_main_module_62
);
mmain
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
x_main_module_62
,
x_main_module_63
);
mmain
->
add_return
({
x_main_module_63
});
auto
x_main_module_65
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_61
,
x_main_module_64
);
mmain
->
add_return
({
x_main_module_65
});
return
p
;
return
p
;
}
}
...
...
src/driver/inceptionv3.cpp
View file @
5ec8f913
This source diff could not be displayed because it is too large. You can
view the blob
instead.
src/driver/resnet50.cpp
View file @
5ec8f913
This diff is collapsed.
Click to expand it.
src/eliminate_concat.cpp
View file @
5ec8f913
...
@@ -60,7 +60,7 @@ void eliminate_concat::apply(module& m) const
...
@@ -60,7 +60,7 @@ void eliminate_concat::apply(module& m) const
auto
lens
=
ins
->
inputs
().
front
()
->
get_shape
().
lens
();
auto
lens
=
ins
->
inputs
().
front
()
->
get_shape
().
lens
();
auto
concat_op
=
concat_opt
.
get_concat
(
ins
->
get_operator
());
auto
concat_op
=
concat_opt
.
get_concat
(
ins
->
get_operator
());
std
::
size_t
axis_index
=
tune_axis
(
lens
.
size
(),
concat_op
.
axis
,
concat_op
.
name
());
std
::
size_t
axis_index
=
tune_axis
(
lens
.
size
(),
concat_op
.
axis
,
concat_op
.
name
());
if
(
axis_index
==
0
||
if
(
axis_index
==
0
or
std
::
all_of
(
lens
.
begin
(),
lens
.
begin
()
+
axis_index
,
[](
auto
x
)
{
return
x
==
1
;
}))
std
::
all_of
(
lens
.
begin
(),
lens
.
begin
()
+
axis_index
,
[](
auto
x
)
{
return
x
==
1
;
}))
{
{
// Last input should be an allocation
// Last input should be an allocation
...
...
src/eliminate_contiguous.cpp
View file @
5ec8f913
...
@@ -71,7 +71,7 @@ static bool try_compute_shape(instruction_ref ins,
...
@@ -71,7 +71,7 @@ static bool try_compute_shape(instruction_ref ins,
return
(
arg
==
ins
)
?
new_shape
:
arg
->
get_shape
();
return
(
arg
==
ins
)
?
new_shape
:
arg
->
get_shape
();
});
});
if
(
!
try_compute_shape
(
output
,
input_shapes
,
mods
))
if
(
not
try_compute_shape
(
output
,
input_shapes
,
mods
))
{
{
return
false
;
return
false
;
}
}
...
...
src/file_buffer.cpp
View file @
5ec8f913
...
@@ -39,7 +39,7 @@ T generic_read_file(const std::string& filename)
...
@@ -39,7 +39,7 @@ T generic_read_file(const std::string& filename)
is
.
seekg
(
0
,
std
::
ios
::
beg
);
is
.
seekg
(
0
,
std
::
ios
::
beg
);
T
buffer
(
size
,
0
);
T
buffer
(
size
,
0
);
if
(
!
is
.
read
(
&
buffer
[
0
],
size
))
if
(
not
is
.
read
(
&
buffer
[
0
],
size
))
MIGRAPHX_THROW
(
"Error reading file: "
+
filename
);
MIGRAPHX_THROW
(
"Error reading file: "
+
filename
);
return
buffer
;
return
buffer
;
}
}
...
...
src/include/migraphx/allocation_model.hpp
View file @
5ec8f913
...
@@ -205,7 +205,7 @@ struct allocation_model
...
@@ -205,7 +205,7 @@ struct allocation_model
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
:
private_detail_te_value
(
std
::
move
(
value
))
{
{
...
@@ -267,7 +267,7 @@ struct allocation_model
...
@@ -267,7 +267,7 @@ struct allocation_model
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
...
src/include/migraphx/check_shapes.hpp
View file @
5ec8f913
...
@@ -101,7 +101,7 @@ struct check_shapes
...
@@ -101,7 +101,7 @@ struct check_shapes
const
check_shapes
&
nelements
(
std
::
size_t
n
)
const
const
check_shapes
&
nelements
(
std
::
size_t
n
)
const
{
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
if
(
not
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes must have only "
+
std
::
to_string
(
n
)
+
" elements"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes must have only "
+
std
::
to_string
(
n
)
+
" elements"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -164,7 +164,7 @@ struct check_shapes
...
@@ -164,7 +164,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
same_shape
()
const
const
check_shapes
&
same_shape
()
const
{
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
;
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes do not match"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes do not match"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -174,7 +174,7 @@ struct check_shapes
...
@@ -174,7 +174,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
same_type
()
const
const
check_shapes
&
same_type
()
const
{
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
type
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
type
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Types do not match"
);
MIGRAPHX_THROW
(
prefix
()
+
"Types do not match"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -184,10 +184,10 @@ struct check_shapes
...
@@ -184,10 +184,10 @@ struct check_shapes
*/
*/
const
check_shapes
&
same_dims
()
const
const
check_shapes
&
same_dims
()
const
{
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Dimensions do not match"
);
MIGRAPHX_THROW
(
prefix
()
+
"Dimensions do not match"
);
if
(
this
->
any_of
([
&
](
const
shape
&
s
)
{
return
s
.
dynamic
();
}))
if
(
this
->
any_of
([
&
](
const
shape
&
s
)
{
return
s
.
dynamic
();
}))
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
min_lens
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
min_lens
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Min dynamic dimensions do not match"
);
MIGRAPHX_THROW
(
prefix
()
+
"Min dynamic dimensions do not match"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -197,7 +197,7 @@ struct check_shapes
...
@@ -197,7 +197,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
same_ndims
()
const
const
check_shapes
&
same_ndims
()
const
{
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
().
size
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
().
size
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Number of dimensions do not match"
);
MIGRAPHX_THROW
(
prefix
()
+
"Number of dimensions do not match"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -207,7 +207,7 @@ struct check_shapes
...
@@ -207,7 +207,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
standard
()
const
const
check_shapes
&
standard
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not in standard layout"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not in standard layout"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -217,7 +217,7 @@ struct check_shapes
...
@@ -217,7 +217,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
standard_or_scalar
()
const
const
check_shapes
&
standard_or_scalar
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
()
or
s
.
scalar
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
()
or
s
.
scalar
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not a scalar or in standard layout"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not a scalar or in standard layout"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -227,7 +227,7 @@ struct check_shapes
...
@@ -227,7 +227,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
packed
()
const
const
check_shapes
&
packed
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not packed"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not packed"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -237,7 +237,7 @@ struct check_shapes
...
@@ -237,7 +237,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
packed_or_broadcasted
()
const
const
check_shapes
&
packed_or_broadcasted
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
()
or
s
.
broadcasted
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
()
or
s
.
broadcasted
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not packed nor broadcasted"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not packed nor broadcasted"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -247,7 +247,7 @@ struct check_shapes
...
@@ -247,7 +247,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
tuple_type
()
const
const
check_shapes
&
tuple_type
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
type
()
==
shape
::
tuple_type
;
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
type
()
==
shape
::
tuple_type
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not tuple!"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not tuple!"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -257,7 +257,7 @@ struct check_shapes
...
@@ -257,7 +257,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
not_transposed
()
const
const
check_shapes
&
not_transposed
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
transposed
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
transposed
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are transposed"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are transposed"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -267,7 +267,7 @@ struct check_shapes
...
@@ -267,7 +267,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
not_broadcasted
()
const
const
check_shapes
&
not_broadcasted
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
broadcasted
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
broadcasted
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are broadcasted"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are broadcasted"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -278,7 +278,7 @@ struct check_shapes
...
@@ -278,7 +278,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
elements
(
std
::
size_t
n
)
const
const
check_shapes
&
elements
(
std
::
size_t
n
)
const
{
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
if
(
not
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Wrong number of elements"
);
MIGRAPHX_THROW
(
prefix
()
+
"Wrong number of elements"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -288,7 +288,8 @@ struct check_shapes
...
@@ -288,7 +288,8 @@ struct check_shapes
*/
*/
const
check_shapes
&
batch_not_transposed
()
const
const
check_shapes
&
batch_not_transposed
()
const
{
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
batch_not_transposed_strides
(
s
.
strides
());
}))
if
(
not
this
->
all_of
(
[
&
](
const
shape
&
s
)
{
return
batch_not_transposed_strides
(
s
.
strides
());
}))
MIGRAPHX_THROW
(
prefix
()
+
"Batch size is transposed"
);
MIGRAPHX_THROW
(
prefix
()
+
"Batch size is transposed"
);
return
*
this
;
return
*
this
;
}
}
...
...
src/include/migraphx/concat_opt.hpp
View file @
5ec8f913
...
@@ -183,7 +183,7 @@ struct concat_optimization
...
@@ -183,7 +183,7 @@ struct concat_optimization
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
:
private_detail_te_value
(
std
::
move
(
value
))
{
{
...
@@ -233,7 +233,7 @@ struct concat_optimization
...
@@ -233,7 +233,7 @@ struct concat_optimization
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
...
src/include/migraphx/context.hpp
View file @
5ec8f913
...
@@ -246,7 +246,7 @@ struct context
...
@@ -246,7 +246,7 @@ struct context
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
:
private_detail_te_value
(
std
::
move
(
value
))
{
{
...
@@ -306,7 +306,7 @@ struct context
...
@@ -306,7 +306,7 @@ struct context
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
...
Prev
1
2
3
4
5
…
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment