Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c5d8c71c
Commit
c5d8c71c
authored
Aug 16, 2022
by
turneram
Browse files
Merge remote-tracking branch 'origin/develop' into rewrite-fast-gelu
parents
b878f78f
0e17a724
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
111 additions
and
27 deletions
+111
-27
src/api/include/migraphx/migraphx.hpp
src/api/include/migraphx/migraphx.hpp
+18
-18
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp
+2
-2
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
...gets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
+5
-3
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+0
-4
test/verify/test_softmax_large1.cpp
test/verify/test_softmax_large1.cpp
+43
-0
test/verify/test_softmax_large2.cpp
test/verify/test_softmax_large2.cpp
+43
-0
No files found.
src/api/include/migraphx/migraphx.hpp
View file @
c5d8c71c
...
@@ -517,7 +517,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
...
@@ -517,7 +517,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
shape
(
const
migraphx_shape
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
shape
(
const
migraphx_shape
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
shape
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
shape
)
/// Construct a scalar shape
/// Construct a scalar shape
shape
(
migraphx_shape_datatype_t
type
)
shape
(
migraphx_shape_datatype_t
type
)
...
@@ -601,7 +601,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
...
@@ -601,7 +601,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{
{
argument
()
{}
argument
()
{}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
argument
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
argument
)
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
argument
(
const
migraphx_argument
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
argument
(
const
migraphx_argument
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
...
@@ -655,7 +655,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target)
...
@@ -655,7 +655,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target)
{
{
target
()
{}
target
()
{}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
target
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
target
)
/// Construct a target from its name
/// Construct a target from its name
target
(
const
char
*
name
)
{
this
->
make_handle
(
&
migraphx_target_create
,
name
);
}
target
(
const
char
*
name
)
{
this
->
make_handle
(
&
migraphx_target_create
,
name
);
}
...
@@ -665,7 +665,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
...
@@ -665,7 +665,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{
{
program_parameter_shapes
()
{}
program_parameter_shapes
()
{}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program_parameter_shapes
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program_parameter_shapes
)
size_t
size
()
const
size_t
size
()
const
{
{
...
@@ -695,7 +695,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
...
@@ -695,7 +695,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
/// A class to construct the inputs parameters for a program
/// A class to construct the inputs parameters for a program
struct
program_parameters
:
MIGRAPHX_HANDLE_BASE
(
program_parameters
)
struct
program_parameters
:
MIGRAPHX_HANDLE_BASE
(
program_parameters
)
{
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program_parameters
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program_parameters
)
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
MIGRAPHX_DEPRECATED
(
"Contructor without lifetime annotation is deprecated."
)
program_parameters
(
migraphx_program_parameters
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
program_parameters
(
migraphx_program_parameters
*
p
)
{
this
->
set_handle
(
p
,
borrow
{});
}
...
@@ -722,7 +722,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
...
@@ -722,7 +722,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
struct
arguments
:
MIGRAPHX_HANDLE_BASE
(
arguments
),
array_base
<
arguments
>
struct
arguments
:
MIGRAPHX_HANDLE_BASE
(
arguments
),
array_base
<
arguments
>
{
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
arguments
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
arguments
)
size_t
size
()
const
size_t
size
()
const
{
{
...
@@ -741,7 +741,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
...
@@ -741,7 +741,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
struct
shapes
:
MIGRAPHX_HANDLE_BASE
(
shapes
),
array_base
<
shapes
>
struct
shapes
:
MIGRAPHX_HANDLE_BASE
(
shapes
),
array_base
<
shapes
>
{
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
shapes
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
shapes
)
size_t
size
()
const
size_t
size
()
const
{
{
...
@@ -760,7 +760,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
...
@@ -760,7 +760,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
struct
operation
:
MIGRAPHX_HANDLE_BASE
(
operation
)
struct
operation
:
MIGRAPHX_HANDLE_BASE
(
operation
)
{
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
operation
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
operation
)
template
<
class
...
Ts
>
template
<
class
...
Ts
>
operation
(
const
char
*
name
,
const
char
*
attributes
=
nullptr
,
Ts
...
xs
)
operation
(
const
char
*
name
,
const
char
*
attributes
=
nullptr
,
Ts
...
xs
)
...
@@ -778,12 +778,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
...
@@ -778,12 +778,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
struct
instruction
:
MIGRAPHX_CONST_HANDLE_BASE
(
instruction
)
struct
instruction
:
MIGRAPHX_CONST_HANDLE_BASE
(
instruction
)
{
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
instruction
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
instruction
)
};
};
struct
instructions
:
MIGRAPHX_HANDLE_BASE
(
instructions
)
struct
instructions
:
MIGRAPHX_HANDLE_BASE
(
instructions
)
{
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
instructions
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
instructions
)
template
<
class
...
Ts
>
template
<
class
...
Ts
>
instructions
(
Ts
...
xs
)
instructions
(
Ts
...
xs
)
...
@@ -797,7 +797,7 @@ struct module;
...
@@ -797,7 +797,7 @@ struct module;
struct
modules
:
MIGRAPHX_HANDLE_BASE
(
modules
)
struct
modules
:
MIGRAPHX_HANDLE_BASE
(
modules
)
{
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
modules
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
modules
)
template
<
class
...
Ts
>
template
<
class
...
Ts
>
modules
(
Ts
...
xs
)
modules
(
Ts
...
xs
)
...
@@ -911,7 +911,7 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
...
@@ -911,7 +911,7 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
{
{
compile_options
()
{
this
->
make_handle
(
&
migraphx_compile_options_create
);
}
compile_options
()
{
this
->
make_handle
(
&
migraphx_compile_options_create
);
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
compile_options
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
compile_options
)
/// For targets with offloaded memory(such as the gpu), this will insert
/// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the
/// instructions during compilation to copy the input parameters to the
...
@@ -935,7 +935,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
...
@@ -935,7 +935,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{
{
program
()
{
this
->
make_handle
(
&
migraphx_program_create
);
}
program
()
{
this
->
make_handle
(
&
migraphx_program_create
);
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
program
)
/// Compile the program for a specific target to be ran on
/// Compile the program for a specific target to be ran on
void
compile
(
const
target
&
ptarget
,
const
compile_options
&
poptions
)
const
void
compile
(
const
target
&
ptarget
,
const
compile_options
&
poptions
)
const
...
@@ -1021,7 +1021,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
...
@@ -1021,7 +1021,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
// options for migraphx file format options
// options for migraphx file format options
struct
file_options
:
MIGRAPHX_HANDLE_BASE
(
file_options
)
struct
file_options
:
MIGRAPHX_HANDLE_BASE
(
file_options
)
{
{
MIGRAPHX_HANDLE_CONSTRUCTOR
(
file_options
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
file_options
)
file_options
()
{
this
->
make_handle
(
&
migraphx_file_options_create
);
}
file_options
()
{
this
->
make_handle
(
&
migraphx_file_options_create
);
}
// set file format
// set file format
...
@@ -1063,7 +1063,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
...
@@ -1063,7 +1063,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{
{
onnx_options
()
{
this
->
make_handle
(
&
migraphx_onnx_options_create
);
}
onnx_options
()
{
this
->
make_handle
(
&
migraphx_onnx_options_create
);
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
onnx_options
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
onnx_options
)
/// Make onnx parser treat an inputs with a certain dimensions
/// Make onnx parser treat an inputs with a certain dimensions
void
set_input_parameter_shape
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
size_t
>
dim
)
void
set_input_parameter_shape
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
size_t
>
dim
)
...
@@ -1145,7 +1145,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
...
@@ -1145,7 +1145,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
{
{
tf_options
()
{
this
->
make_handle
(
&
migraphx_tf_options_create
);
}
tf_options
()
{
this
->
make_handle
(
&
migraphx_tf_options_create
);
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
tf_options
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
tf_options
)
/// Make tf parser treat an inputs with a certain dimensions
/// Make tf parser treat an inputs with a certain dimensions
void
set_input_parameter_shape
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
size_t
>
dim
)
void
set_input_parameter_shape
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
size_t
>
dim
)
...
@@ -1198,7 +1198,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
...
@@ -1198,7 +1198,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
{
{
quantize_op_names
()
{
this
->
make_handle
(
&
migraphx_quantize_op_names_create
);
}
quantize_op_names
()
{
this
->
make_handle
(
&
migraphx_quantize_op_names_create
);
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
quantize_op_names
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
quantize_op_names
)
void
add
(
const
std
::
string
&
name
)
void
add
(
const
std
::
string
&
name
)
{
{
...
@@ -1223,7 +1223,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
...
@@ -1223,7 +1223,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
{
{
quantize_int8_options
()
{
this
->
make_handle
(
&
migraphx_quantize_int8_options_create
);
}
quantize_int8_options
()
{
this
->
make_handle
(
&
migraphx_quantize_int8_options_create
);
}
MIGRAPHX_HANDLE_CONSTRUCTOR
(
quantize_int8_options
)
;
MIGRAPHX_HANDLE_CONSTRUCTOR
(
quantize_int8_options
)
/// Add an operator that should be quantized
/// Add an operator that should be quantized
void
add_op_name
(
const
std
::
string
&
name
)
void
add_op_name
(
const
std
::
string
&
name
)
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp
View file @
c5d8c71c
...
@@ -90,7 +90,7 @@ struct lowest
...
@@ -90,7 +90,7 @@ struct lowest
template
<
class
T
>
template
<
class
T
>
constexpr
operator
T
()
const
constexpr
operator
T
()
const
{
{
return
numeric_lowest
<
T
>
();
return
numeric_lowest
<
vec_type
<
T
>
>
();
}
}
};
};
...
@@ -99,7 +99,7 @@ struct highest
...
@@ -99,7 +99,7 @@ struct highest
template
<
class
T
>
template
<
class
T
>
constexpr
operator
T
()
const
constexpr
operator
T
()
const
{
{
return
numeric_max
<
T
>
();
return
numeric_max
<
vec_type
<
T
>
>
();
}
}
};
};
}
// namespace migraphx
}
// namespace migraphx
...
...
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
View file @
c5d8c71c
...
@@ -192,9 +192,13 @@ struct common_type<T, U, Us...>
...
@@ -192,9 +192,13 @@ struct common_type<T, U, Us...>
template
<
class
...
Ts
>
template
<
class
...
Ts
>
using
common_type_t
=
typename
common_type
<
Ts
...
>::
type
;
using
common_type_t
=
typename
common_type
<
Ts
...
>::
type
;
#define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>
constexpr
unsigned
long
int_max
(
unsigned
long
n
)
{
return
(
1u
<<
(
n
*
8
))
-
1
;
}
constexpr
unsigned
long
int_max
(
unsigned
long
n
)
{
return
(
1u
<<
(
n
*
8
))
-
1
;
}
template
<
class
T
>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_integral
<
T
>{}
or
is_floating_point
<
T
>
{}
or
is_same
<
T
,
migraphx
::
half
>
{})
>
constexpr
T
numeric_max
()
constexpr
T
numeric_max
()
{
{
if
constexpr
(
is_integral
<
T
>
{})
if
constexpr
(
is_integral
<
T
>
{})
...
@@ -230,8 +234,6 @@ constexpr T numeric_lowest()
...
@@ -230,8 +234,6 @@ constexpr T numeric_lowest()
}
}
}
}
#define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>
}
// namespace migraphx
}
// namespace migraphx
#endif
#endif
test/onnx/onnx_test.cpp
View file @
c5d8c71c
...
@@ -954,13 +954,11 @@ TEST_CASE(conv_dynamic_img_same_upper)
...
@@ -954,13 +954,11 @@ TEST_CASE(conv_dynamic_img_same_upper)
TEST_CASE
(
conv_dynamic_kernel_same_lower
)
TEST_CASE
(
conv_dynamic_kernel_same_lower
)
{
{
std
::
cout
<<
"here1
\n
"
;
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
l0
=
mm
->
add_parameter
(
"0"
,
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
5
,
5
}});
auto
l0
=
mm
->
add_parameter
(
"0"
,
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
5
,
5
}});
auto
l1
=
mm
->
add_parameter
(
auto
l1
=
mm
->
add_parameter
(
"1"
,
{
migraphx
::
shape
::
float_type
,
{{
1
,
1
,
0
},
{
3
,
3
,
0
},
{
2
,
4
,
0
},
{
2
,
4
,
0
}}});
"1"
,
{
migraphx
::
shape
::
float_type
,
{{
1
,
1
,
0
},
{
3
,
3
,
0
},
{
2
,
4
,
0
},
{
2
,
4
,
0
}}});
std
::
cout
<<
"here2
\n
"
;
auto
c0
=
mm
->
add_instruction
(
auto
c0
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"convolution"
,
migraphx
::
make_op
(
"convolution"
,
{{
"padding"
,
{
0
,
0
}},
{{
"padding"
,
{
0
,
0
}},
...
@@ -970,12 +968,10 @@ TEST_CASE(conv_dynamic_kernel_same_lower)
...
@@ -970,12 +968,10 @@ TEST_CASE(conv_dynamic_kernel_same_lower)
{
"use_dynamic_same_auto_pad"
,
true
}}),
{
"use_dynamic_same_auto_pad"
,
true
}}),
l0
,
l0
,
l1
);
l1
);
std
::
cout
<<
"here3
\n
"
;
mm
->
add_return
({
c0
});
mm
->
add_return
({
c0
});
migraphx
::
onnx_options
options
;
migraphx
::
onnx_options
options
;
options
.
default_dyn_dim_value
=
{
2
,
4
,
0
};
options
.
default_dyn_dim_value
=
{
2
,
4
,
0
};
std
::
cout
<<
"here
\n
"
;
auto
prog
=
migraphx
::
parse_onnx
(
"conv_dynamic_kernel_same_lower_test.onnx"
,
options
);
auto
prog
=
migraphx
::
parse_onnx
(
"conv_dynamic_kernel_same_lower_test.onnx"
,
options
);
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
}
}
...
...
test/verify/test_softmax_large1.cpp
0 → 100644
View file @
c5d8c71c
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
struct
test_softmax_large1
:
verify_program
<
test_softmax_large1
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
x
=
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
4
}});
auto
large
=
mm
->
add_literal
({
migraphx
::
shape
{
migraphx
::
shape
::
float_type
},
{
10000
}});
auto
add
=
migraphx
::
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"add"
),
{
x
,
large
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"softmax"
,
{{
"axis"
,
-
1
}}),
add
);
return
p
;
}
};
test/verify/test_softmax_large2.cpp
0 → 100644
View file @
c5d8c71c
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
struct
test_softmax_large2
:
verify_program
<
test_softmax_large2
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
x
=
mm
->
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
4
}});
auto
large
=
mm
->
add_literal
({
migraphx
::
shape
{
migraphx
::
shape
::
float_type
},
{
-
10000
}});
auto
add
=
migraphx
::
add_common_op
(
*
mm
,
migraphx
::
make_op
(
"add"
),
{
x
,
large
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"softmax"
,
{{
"axis"
,
-
1
}}),
add
);
return
p
;
}
};
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment