Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
032af369
Commit
032af369
authored
Dec 10, 2021
by
Paul
Browse files
Merge branch 'develop' into mlir-c
parents
b406a418
46b0c33b
Changes
27
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
290 additions
and
32 deletions
+290
-32
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
+120
-0
src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
...argets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
+1
-0
src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
...argets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
+9
-6
test/fuse_pointwise.cpp
test/fuse_pointwise.cpp
+29
-0
test/verify/test_hsqrt.cpp
test/verify/test_hsqrt.cpp
+19
-0
tools/roctx.py
tools/roctx.py
+1
-1
tools/test_runner.py
tools/test_runner.py
+111
-25
No files found.
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
0 → 100644
View file @
032af369
#ifndef MIGRAPHX_GUARD_KERNELS_MATH_HPP
#define MIGRAPHX_GUARD_KERNELS_MATH_HPP
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/vec.hpp>
#include <migraphx/kernels/functional.hpp>
#include <hip/hip_fp16.h>
#include <hip/math_functions.h>
namespace
migraphx
{
namespace
math
{
constexpr
float
as_float
(
migraphx
::
half
x
)
{
return
x
;
}
template
<
class
T
>
constexpr
T
as_float
(
T
x
)
{
return
x
;
}
}
// namespace math
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH(name, fname) \
template <class... Ts> \
auto __device__ name(Ts... xs) MIGRAPHX_RETURNS(fname(xs...))
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FOR(type, name, fname) \
template <class... Ts> \
auto __device__ name(type x, Ts... xs) MIGRAPHX_RETURNS(fname(x, xs...))
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \
template <class... Ts> \
auto __device__ name(migraphx::half x, Ts... xs) \
MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))
MIGRAPHX_DEVICE_MATH
(
abs
,
::
abs
)
MIGRAPHX_DEVICE_MATH
(
acos
,
::
acos
)
MIGRAPHX_DEVICE_MATH
(
acosh
,
::
acosh
)
MIGRAPHX_DEVICE_MATH
(
asin
,
::
asin
)
MIGRAPHX_DEVICE_MATH
(
asinh
,
::
asinh
)
MIGRAPHX_DEVICE_MATH
(
atan
,
::
atan
)
MIGRAPHX_DEVICE_MATH
(
atanh
,
::
atanh
)
MIGRAPHX_DEVICE_MATH
(
ceil
,
::
ceil
)
MIGRAPHX_DEVICE_MATH
(
cos
,
::
cos
)
MIGRAPHX_DEVICE_MATH
(
cosh
,
::
cosh
)
MIGRAPHX_DEVICE_MATH
(
erf
,
::
erf
)
MIGRAPHX_DEVICE_MATH
(
exp
,
::
exp
)
MIGRAPHX_DEVICE_MATH
(
floor
,
::
floor
)
MIGRAPHX_DEVICE_MATH
(
log
,
::
log
)
MIGRAPHX_DEVICE_MATH
(
pow
,
::
pow
)
MIGRAPHX_DEVICE_MATH
(
round
,
::
round
)
MIGRAPHX_DEVICE_MATH
(
rsqrt
,
::
rsqrt
)
MIGRAPHX_DEVICE_MATH
(
sin
,
::
sin
)
MIGRAPHX_DEVICE_MATH
(
sinh
,
::
sinh
)
MIGRAPHX_DEVICE_MATH
(
sqrt
,
::
sqrt
)
MIGRAPHX_DEVICE_MATH
(
tan
,
::
tan
)
MIGRAPHX_DEVICE_MATH
(
tanh
,
::
tanh
)
// Float overloads
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
acos
,
::
acosf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
acosh
,
::
acoshf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
asin
,
::
asinf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
asinh
,
::
asinhf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
atan
,
::
atanf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
atanh
,
::
atanhf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
cos
,
::
cosf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
cosh
,
::
coshf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
rsqrt
,
::
rsqrtf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
sin
,
::
sinf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
sinh
,
::
sinhf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
tan
,
::
tanf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
tanh
,
::
tanhf
)
// Builtin half functions
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
abs
,
::
__habs
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
exp
,
::
hexp
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
log
,
::
hlog
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
rsqrt
,
::
hrsqrt
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
sqrt
,
::
hsqrt
)
// Use float to compute half overload
MIGRAPHX_DEVICE_MATH_HALF
(
acos
,
::
acos
)
MIGRAPHX_DEVICE_MATH_HALF
(
acosh
,
::
acosh
)
MIGRAPHX_DEVICE_MATH_HALF
(
asin
,
::
asin
)
MIGRAPHX_DEVICE_MATH_HALF
(
asinh
,
::
asinh
)
MIGRAPHX_DEVICE_MATH_HALF
(
atan
,
::
atan
)
MIGRAPHX_DEVICE_MATH_HALF
(
atanh
,
::
atanh
)
MIGRAPHX_DEVICE_MATH_HALF
(
ceil
,
::
ceil
)
MIGRAPHX_DEVICE_MATH_HALF
(
cos
,
::
cos
)
MIGRAPHX_DEVICE_MATH_HALF
(
cosh
,
::
cosh
)
MIGRAPHX_DEVICE_MATH_HALF
(
erf
,
::
erf
)
MIGRAPHX_DEVICE_MATH_HALF
(
floor
,
::
floor
)
MIGRAPHX_DEVICE_MATH_HALF
(
pow
,
::
pow
)
MIGRAPHX_DEVICE_MATH_HALF
(
round
,
::
round
)
MIGRAPHX_DEVICE_MATH_HALF
(
sin
,
::
sin
)
MIGRAPHX_DEVICE_MATH_HALF
(
sinh
,
::
sinh
)
MIGRAPHX_DEVICE_MATH_HALF
(
tan
,
::
tan
)
MIGRAPHX_DEVICE_MATH_HALF
(
tanh
,
::
tanh
)
template
<
class
T
,
class
U
>
constexpr
auto
&
max
(
const
T
&
a
,
const
U
&
b
)
{
return
(
a
<
b
)
?
b
:
a
;
}
template
<
class
T
,
class
U
>
constexpr
auto
&
min
(
const
T
&
a
,
const
U
&
b
)
{
return
(
a
>
b
)
?
b
:
a
;
}
template
<
class
T
,
class
U
>
constexpr
T
convert
(
U
x
)
{
return
x
;
}
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_MATH_HPP
src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
View file @
032af369
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/math.hpp>
#include <migraphx/kernels/preload.hpp>
#include <migraphx/kernels/preload.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/args.hpp>
#include <migraphx/kernels/args.hpp>
...
...
src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
View file @
032af369
...
@@ -75,7 +75,7 @@ constexpr index_int find_vector_axis(Shapes... ss)
...
@@ -75,7 +75,7 @@ constexpr index_int find_vector_axis(Shapes... ss)
index_int
axis
=
0
;
index_int
axis
=
0
;
bool
b
=
false
;
bool
b
=
false
;
by
([
&
](
auto
s
)
{
by
([
&
](
auto
s
)
{
if
(
s
.
broadcasted
()
or
b
)
if
(
b
)
return
;
return
;
auto
it
=
find
(
s
.
strides
.
begin
(),
s
.
strides
.
end
(),
1
);
auto
it
=
find
(
s
.
strides
.
begin
(),
s
.
strides
.
end
(),
1
);
if
(
it
==
s
.
strides
.
end
())
if
(
it
==
s
.
strides
.
end
())
...
@@ -89,14 +89,17 @@ constexpr index_int find_vector_axis(Shapes... ss)
...
@@ -89,14 +89,17 @@ constexpr index_int find_vector_axis(Shapes... ss)
template
<
index_int
N
,
class
Axis
,
class
...
Shapes
>
template
<
index_int
N
,
class
Axis
,
class
...
Shapes
>
constexpr
auto
is_vectorizable
(
Axis
axis
,
Shapes
...
ss
)
constexpr
auto
is_vectorizable
(
Axis
axis
,
Shapes
...
ss
)
{
{
return
(((
ss
.
lens
[
axis
]
%
N
)
==
0
and
(
ss
.
strides
[
axis
]
==
1
or
ss
.
strides
[
axis
]
==
0
))
and
return
(((
ss
.
lens
[
axis
]
%
N
)
==
0
and
ss
.
strides
[
axis
]
==
1
)
and
...);
...);
}
}
template
<
index_int
N
,
class
...
Shape
s
>
template
<
index_int
N
,
class
Shape
>
constexpr
bool
is_vectorizable
(
Shape
s
...
s
s
)
constexpr
bool
is_vectorizable
(
Shape
s
)
{
{
return
(
is_vectorizable
<
N
>
(
ss
,
find_vector_axis
(
ss
))
and
...);
auto
it
=
find
(
s
.
strides
.
begin
(),
s
.
strides
.
end
(),
1
);
if
(
it
==
s
.
strides
.
end
())
return
false
;
auto
axis
=
it
-
s
.
strides
.
begin
();
return
(
s
.
lens
[
axis
]
%
N
)
==
0
and
s
.
strides
[
axis
]
==
1
;
}
}
template
<
class
P
>
template
<
class
P
>
...
...
test/fuse_pointwise.cpp
View file @
032af369
...
@@ -73,6 +73,35 @@ TEST_CASE(double_add)
...
@@ -73,6 +73,35 @@ TEST_CASE(double_add)
EXPECT
(
p1
.
sort
()
==
p2
.
sort
());
EXPECT
(
p1
.
sort
()
==
p2
.
sort
());
}
}
TEST_CASE
(
double_add_without_return
)
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
migraphx
::
program
p1
;
{
auto
*
mm
=
p1
.
get_main_module
();
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s
);
auto
z
=
mm
->
add_parameter
(
"z"
,
s
);
auto
add1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x
,
y
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
add1
,
z
);
}
run_pass
(
p1
);
migraphx
::
program
p2
;
{
auto
*
mm
=
p2
.
get_main_module
();
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s
);
auto
z
=
mm
->
add_parameter
(
"z"
,
s
);
auto
fadd
=
add_pointwise
(
p2
,
"main:pointwise0"
,
{
x
,
y
,
z
},
[
=
](
auto
*
pm
,
const
auto
&
inputs
)
{
auto
add1
=
pm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
inputs
[
0
],
inputs
[
1
]);
return
pm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
add1
,
inputs
[
2
]);
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"identity"
),
fadd
);
}
EXPECT
(
p1
.
sort
()
==
p2
.
sort
());
}
TEST_CASE
(
used_twice_not_fused
)
TEST_CASE
(
used_twice_not_fused
)
{
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
...
...
test/verify/test_hsqrt.cpp
0 → 100644
View file @
032af369
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_hsqrt
:
verify_program
<
test_hsqrt
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
half_type
,
{
2
,
3
,
4
,
6
}};
auto
param
=
mm
->
add_parameter
(
"x"
,
s
);
auto
param_abs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"abs"
),
param
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"sqrt"
),
param_abs
);
return
p
;
}
};
tools/roctx.py
View file @
032af369
...
@@ -21,7 +21,7 @@ def parse_args():
...
@@ -21,7 +21,7 @@ def parse_args():
description
=
"Parser for MIGraphX ROCTX Markers"
)
description
=
"Parser for MIGraphX ROCTX Markers"
)
parser
.
add_argument
(
'--json-path'
,
parser
.
add_argument
(
'--json-path'
,
type
=
str
,
type
=
str
,
metavar
=
'json
_
path'
,
metavar
=
'json
-
path'
,
help
=
'Path to json file'
)
help
=
'Path to json file'
)
parser
.
add_argument
(
'--out'
,
parser
.
add_argument
(
'--out'
,
type
=
str
,
type
=
str
,
...
...
tools/test_runner.py
View file @
032af369
import
os
import
os
,
sys
import
numpy
as
np
import
numpy
as
np
import
argparse
import
argparse
import
onnx
import
onnx
...
@@ -54,36 +54,112 @@ def read_pb_file(filename):
...
@@ -54,36 +54,112 @@ def read_pb_file(filename):
tensor
.
ParseFromString
(
data_str
)
tensor
.
ParseFromString
(
data_str
)
np_array
=
numpy_helper
.
to_array
(
tensor
)
np_array
=
numpy_helper
.
to_array
(
tensor
)
return
np_array
return
tensor
.
name
,
np_array
def
wrapup_inputs
(
io_folder
,
parameter_names
):
def
wrapup_inputs
(
io_folder
,
param_names
):
index
=
0
param_map
=
{}
param_map
=
{}
for
param_name
in
parameter_names
:
data_array
=
[]
file_name
=
io_folder
+
'/input_'
+
str
(
index
)
+
'.pb'
name_array
=
[]
data
=
read_pb_file
(
file_name
)
for
i
in
range
(
len
(
param_names
)):
param_map
[
param_name
]
=
data
file_name
=
io_folder
+
'/input_'
+
str
(
i
)
+
'.pb'
index
=
index
+
1
name
,
data
=
read_pb_file
(
file_name
)
param_map
[
name
]
=
data
data_array
.
append
(
data
)
if
name
:
name_array
.
append
(
name
)
if
len
(
name_array
)
<
len
(
data_array
):
param_map
=
{}
for
i
in
range
(
len
(
param_names
)):
param_map
[
param_names
[
i
]]
=
data_array
[
i
]
return
param_map
for
name
in
param_names
:
if
not
name
in
param_map
.
keys
():
print
(
"Input {} does not exist!"
.
format
(
name
))
sys
.
exit
()
return
param_map
return
param_map
def
read_outputs
(
io_folder
,
out_n
um
):
def
read_outputs
(
io_folder
,
out_n
ames
):
outputs
=
[]
outputs
=
[]
for
i
in
range
(
out_num
):
data_array
=
[]
name_array
=
[]
for
i
in
range
(
len
(
out_names
)):
file_name
=
io_folder
+
'/output_'
+
str
(
i
)
+
'.pb'
file_name
=
io_folder
+
'/output_'
+
str
(
i
)
+
'.pb'
data
=
read_pb_file
(
file_name
)
name
,
data
=
read_pb_file
(
file_name
)
outputs
.
append
(
data
)
data_array
.
append
(
data
)
if
name
:
name_array
.
append
(
name
)
if
len
(
name_array
)
<
len
(
data_array
):
return
data_array
for
name
in
out_names
:
index
=
name_array
.
index
(
name
)
outputs
.
append
(
data_array
[
index
])
return
outputs
return
outputs
def
model_parameter_names
(
model_file_name
):
with
open
(
model_file_name
,
'rb'
)
as
pfile
:
data_str
=
pfile
.
read
()
model_proto
=
onnx
.
ModelProto
()
model_proto
.
ParseFromString
(
data_str
)
init_names
=
set
([(
i
.
name
)
for
i
in
model_proto
.
graph
.
initializer
])
param_names
=
[
input
.
name
for
input
in
model_proto
.
graph
.
input
if
input
.
name
not
in
init_names
]
return
param_names
def
model_output_names
(
model_file_name
):
with
open
(
model_file_name
,
'rb'
)
as
pfile
:
data_str
=
pfile
.
read
()
model_proto
=
onnx
.
ModelProto
()
model_proto
.
ParseFromString
(
data_str
)
output_names
=
[
out
.
name
for
out
in
model_proto
.
graph
.
output
]
return
output_names
def
get_input_shapes
(
sample_case
,
param_names
):
param_shape_map
=
{}
name_array
=
[]
shape_array
=
[]
for
i
in
range
(
len
(
param_names
)):
file_name
=
sample_case
+
'/input_'
+
str
(
i
)
+
'.pb'
name
,
data
=
read_pb_file
(
file_name
)
param_shape_map
[
name
]
=
data
.
shape
shape_array
.
append
(
data
.
shape
)
if
name
:
name_array
.
append
(
name
)
if
len
(
name_array
)
<
len
(
shape_array
):
param_shape_map
=
{}
for
i
in
range
(
len
(
param_names
)):
param_shape_map
[
param_names
[
i
]]
=
shape_array
[
i
]
return
param_shape_map
for
name
in
param_names
:
if
not
name
in
param_shape_map
:
print
(
"Input {} does not exist!"
.
format
(
name
))
sys
.
exit
()
return
param_shape_map
def
run_one_case
(
model
,
param_map
):
def
run_one_case
(
model
,
param_map
):
# convert np array to model argument
# convert np array to model argument
pp
=
{}
pp
=
{}
for
key
,
val
in
param_map
.
items
():
for
key
,
val
in
param_map
.
items
():
print
(
"input = {}"
.
format
(
val
))
pp
[
key
]
=
migraphx
.
argument
(
val
)
pp
[
key
]
=
migraphx
.
argument
(
val
)
# run the model
# run the model
...
@@ -106,12 +182,11 @@ def check_correctness(gold_outputs, outputs, rtol=1e-3, atol=1e-3):
...
@@ -106,12 +182,11 @@ def check_correctness(gold_outputs, outputs, rtol=1e-3, atol=1e-3):
out_num
=
len
(
gold_outputs
)
out_num
=
len
(
gold_outputs
)
ret
=
True
ret
=
True
for
i
in
range
(
out_num
):
for
i
in
range
(
out_num
):
print
(
"Expected value:
\n
{}"
.
format
(
gold_outputs
[
i
]))
print
(
"Actual value:
\n
{}"
.
format
(
outputs
[
i
]))
if
not
np
.
allclose
(
gold_outputs
[
i
],
outputs
[
i
],
rtol
,
atol
):
if
not
np
.
allclose
(
gold_outputs
[
i
],
outputs
[
i
],
rtol
,
atol
):
print
(
"Output {} is incorrect ..."
.
format
(
i
))
print
(
"
\n
Output {} is incorrect ..."
.
format
(
i
))
print
(
"Expected value:
\n
{}"
.
format
(
gold_outputs
[
i
]))
print
(
"Expected value:
\n
{}"
.
format
(
gold_outputs
[
i
]))
print
(
"Actual value:
\n
{}"
.
format
(
outputs
[
i
]))
print
(
"......"
)
print
(
"Actual value:
\n
{}
\n
"
.
format
(
outputs
[
i
]))
ret
=
False
ret
=
False
return
ret
return
ret
...
@@ -142,21 +217,32 @@ def main():
...
@@ -142,21 +217,32 @@ def main():
# get model full path
# get model full path
model_name
=
get_model_name
(
test_loc
)
model_name
=
get_model_name
(
test_loc
)
model_path_name
=
test_loc
+
'/'
+
model_name
model_path_name
=
test_loc
+
'/'
+
model_name
# read and compile model
model
=
migraphx
.
parse_onnx
(
model_path_name
)
param_names
=
model
.
get_parameter_names
()
output_shapes
=
model
.
get_output_shapes
()
model
.
compile
(
migraphx
.
get_target
(
target
))
# get param names
param_names
=
model_parameter_names
(
model_path_name
)
# get output names
output_names
=
model_output_names
(
model_path_name
)
# get test cases
# get test cases
cases
=
get_test_cases
(
test_loc
)
cases
=
get_test_cases
(
test_loc
)
sample_case
=
test_loc
+
'/'
+
cases
[
0
]
param_shapes
=
get_input_shapes
(
sample_case
,
param_names
)
for
name
,
dims
in
param_shapes
.
items
():
print
(
"Input: {}, shape: {}"
.
format
(
name
,
dims
))
print
()
# read and compile model
model
=
migraphx
.
parse_onnx
(
model_path_name
,
map_input_dims
=
param_shapes
)
model
.
compile
(
migraphx
.
get_target
(
target
))
# get test cases
case_num
=
len
(
cases
)
case_num
=
len
(
cases
)
correct_num
=
0
correct_num
=
0
for
case_name
in
cases
:
for
case_name
in
cases
:
io_folder
=
test_loc
+
'/'
+
case_name
io_folder
=
test_loc
+
'/'
+
case_name
input_data
=
wrapup_inputs
(
io_folder
,
param_names
)
input_data
=
wrapup_inputs
(
io_folder
,
param_names
)
gold_output
_data
=
read_outputs
(
io_folder
,
len
(
output_
shap
es
)
)
gold_output
s
=
read_outputs
(
io_folder
,
output_
nam
es
)
# if input shape is different from model shape, reload and recompile
# if input shape is different from model shape, reload and recompile
# model
# model
...
@@ -170,7 +256,7 @@ def main():
...
@@ -170,7 +256,7 @@ def main():
output_data
=
run_one_case
(
model
,
input_data
)
output_data
=
run_one_case
(
model
,
input_data
)
# check output correctness
# check output correctness
ret
=
check_correctness
(
gold_output
_data
,
output_data
)
ret
=
check_correctness
(
gold_output
s
,
output_data
)
if
ret
:
if
ret
:
correct_num
+=
1
correct_num
+=
1
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment