Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
20943acb
Commit
20943acb
authored
Oct 30, 2009
by
Peter Eastman
Browse files
Added test case for Parser
parent
df51e651
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
226 additions
and
0 deletions
+226
-0
tests/TestParser.cpp
tests/TestParser.cpp
+226
-0
No files found.
tests/TestParser.cpp
0 → 100644
View file @
20943acb
#include "../libraries/lepton/include/Lepton.h"
#include <iostream>
#include <limits>
#include <map>
using
namespace
Lepton
;
using
namespace
std
;
#define ASSERT_EQUAL_TOL(expected, found, tol) {double _scale_ = std::fabs(expected) > 1.0 ? std::fabs(expected) : 1.0; if (!(std::fabs((expected)-(found))/_scale_ <= (tol))) throw exception();};
/**
* This is a custom function equal to f(x,y) = 2*x*y.
*/
class
ExampleFunction
:
public
CustomFunction
{
int
getNumArguments
()
const
{
return
2
;
}
double
evaluate
(
const
double
*
arguments
)
const
{
return
2.0
*
arguments
[
0
]
*
arguments
[
1
];
}
double
evaluateDerivative
(
const
double
*
arguments
,
const
int
*
derivOrder
)
const
{
if
(
derivOrder
[
0
]
==
1
)
{
if
(
derivOrder
[
1
]
==
0
)
return
2.0
*
arguments
[
1
];
else
if
(
derivOrder
[
1
]
==
1
)
return
2.0
;
else
return
0.0
;
}
if
(
derivOrder
[
1
]
==
1
)
{
if
(
derivOrder
[
0
]
==
0
)
return
2.0
*
arguments
[
0
];
else
return
0.0
;
}
return
0.0
;
}
CustomFunction
*
clone
()
const
{
return
new
ExampleFunction
();
}
};
/**
* Verify that an expression gives the correct value.
*/
void
verifyEvaluation
(
const
string
&
expression
,
double
expectedValue
)
{
map
<
string
,
CustomFunction
*>
customFunctions
;
ParsedExpression
parsed
=
Parser
::
parse
(
expression
,
customFunctions
);
double
value
=
parsed
.
evaluate
();
ASSERT_EQUAL_TOL
(
expectedValue
,
value
,
1e-10
);
// Try optimizing it and make sure the result is still correct.
value
=
parsed
.
optimize
().
evaluate
();
ASSERT_EQUAL_TOL
(
expectedValue
,
value
,
1e-10
);
// Create an ExpressionProgram and see if that also gives the same result.
ExpressionProgram
program
=
parsed
.
createProgram
();
value
=
program
.
evaluate
();
ASSERT_EQUAL_TOL
(
expectedValue
,
value
,
1e-10
);
}
/**
* Verify that an expression with variables gives the correct value.
*/
void
verifyEvaluation
(
const
string
&
expression
,
double
x
,
double
y
,
double
expectedValue
)
{
map
<
string
,
double
>
variables
;
variables
[
"x"
]
=
x
;
variables
[
"y"
]
=
y
;
ParsedExpression
parsed
=
Parser
::
parse
(
expression
);
double
value
=
parsed
.
evaluate
(
variables
);
ASSERT_EQUAL_TOL
(
expectedValue
,
value
,
1e-10
);
// Try optimizing it and make sure the result is still correct.
value
=
parsed
.
optimize
().
evaluate
(
variables
);
ASSERT_EQUAL_TOL
(
expectedValue
,
value
,
1e-10
);
// Try optimizing with predefined values for the variables.
value
=
parsed
.
optimize
(
variables
).
evaluate
();
ASSERT_EQUAL_TOL
(
expectedValue
,
value
,
1e-10
);
// Create an ExpressionProgram and see if that also gives the same result.
ExpressionProgram
program
=
parsed
.
createProgram
();
value
=
program
.
evaluate
(
variables
);
ASSERT_EQUAL_TOL
(
expectedValue
,
value
,
1e-10
);
}
/**
* Confirm that a parse error gets thrown.
*/
void
verifyInvalidExpression
(
const
string
&
expression
)
{
try
{
Parser
::
parse
(
expression
);
}
catch
(
const
exception
&
ex
)
{
return
;
}
throw
exception
();
}
/**
* Verify that two expressions give the same value.
*/
void
verifySameValue
(
const
ParsedExpression
&
exp1
,
const
ParsedExpression
&
exp2
,
double
x
,
double
y
)
{
map
<
string
,
double
>
variables
;
variables
[
"x"
]
=
x
;
variables
[
"y"
]
=
y
;
double
val1
=
exp1
.
evaluate
(
variables
);
double
val2
=
exp2
.
evaluate
(
variables
);
const
double
inf
=
numeric_limits
<
double
>::
infinity
();
if
(
val1
==
val1
||
val2
==
val2
)
// If both are NaN, that's fine.
if
(
val1
!=
inf
||
val2
!=
inf
)
// Both infinity is also fine.
if
(
val1
!=
-
inf
||
val2
!=
-
inf
)
// Same for -infinity.
ASSERT_EQUAL_TOL
(
val1
,
val2
,
1e-10
);
}
/**
* Verify that the derivative of an expression is calculated correctly.
*/
void
verifyDerivative
(
const
string
&
expression
,
const
string
&
expectedDeriv
)
{
ParsedExpression
computed
=
Parser
::
parse
(
expression
).
differentiate
(
"x"
).
optimize
();
ParsedExpression
expected
=
Parser
::
parse
(
expectedDeriv
);
verifySameValue
(
computed
,
expected
,
1.0
,
2.0
);
verifySameValue
(
computed
,
expected
,
2.0
,
3.0
);
verifySameValue
(
computed
,
expected
,
-
2.0
,
3.0
);
verifySameValue
(
computed
,
expected
,
2.0
,
-
3.0
);
}
/**
* Test the use of a custom function.
*/
void
testCustomFunction
(
const
string
&
expression
,
const
string
&
equivalent
)
{
map
<
string
,
CustomFunction
*>
functions
;
functions
[
"custom"
]
=
new
ExampleFunction
();
ParsedExpression
exp1
=
Parser
::
parse
(
expression
,
functions
);
ParsedExpression
exp2
=
Parser
::
parse
(
equivalent
);
verifySameValue
(
exp1
,
exp2
,
1.0
,
2.0
);
verifySameValue
(
exp1
,
exp2
,
2.0
,
3.0
);
verifySameValue
(
exp1
,
exp2
,
-
2.0
,
3.0
);
verifySameValue
(
exp1
,
exp2
,
2.0
,
-
3.0
);
ParsedExpression
deriv1
=
exp1
.
differentiate
(
"x"
);
ParsedExpression
deriv2
=
exp2
.
differentiate
(
"x"
);
verifySameValue
(
deriv1
,
deriv2
,
1.0
,
2.0
);
verifySameValue
(
deriv1
,
deriv2
,
2.0
,
3.0
);
verifySameValue
(
deriv1
,
deriv2
,
-
2.0
,
3.0
);
verifySameValue
(
deriv1
,
deriv2
,
2.0
,
-
3.0
);
ParsedExpression
deriv3
=
deriv1
.
differentiate
(
"y"
);
ParsedExpression
deriv4
=
deriv2
.
differentiate
(
"y"
);
verifySameValue
(
deriv1
,
deriv2
,
1.0
,
2.0
);
verifySameValue
(
deriv1
,
deriv2
,
2.0
,
3.0
);
verifySameValue
(
deriv1
,
deriv2
,
-
2.0
,
3.0
);
verifySameValue
(
deriv1
,
deriv2
,
2.0
,
-
3.0
);
delete
functions
[
"custom"
];
}
int
main
()
{
try
{
verifyEvaluation
(
"5"
,
5.0
);
verifyEvaluation
(
"5*2"
,
10.0
);
verifyEvaluation
(
"2*3+4*5"
,
26.0
);
verifyEvaluation
(
"2^-3"
,
0.125
);
verifyEvaluation
(
"-x"
,
2.0
,
3.0
,
-
2.0
);
verifyEvaluation
(
"y^-x"
,
3.0
,
2.0
,
0.125
);
verifyEvaluation
(
"1/-x"
,
3.0
,
2.0
,
-
1.0
/
3.0
);
verifyEvaluation
(
"2.1e-4*x*(y+1)"
,
3.0
,
1.0
,
1.26e-3
);
verifyEvaluation
(
"sin(2.5)"
,
std
::
sin
(
2.5
));
verifyEvaluation
(
"cot(x)"
,
3.0
,
1.0
,
1.0
/
std
::
tan
(
3.0
));
verifyEvaluation
(
"x^2+y^3+x^-1+y^(1/2)"
,
1.0
,
1.0
,
4.0
);
verifyEvaluation
(
"(2*x)*3"
,
4.0
,
4.0
,
24.0
);
verifyEvaluation
(
"(x*2)*3"
,
4.0
,
4.0
,
24.0
);
verifyEvaluation
(
"2*(x*3)"
,
4.0
,
4.0
,
24.0
);
verifyEvaluation
(
"2*(3*x)"
,
4.0
,
4.0
,
24.0
);
verifyEvaluation
(
"2*x/3"
,
1.0
,
4.0
,
2.0
/
3.0
);
verifyEvaluation
(
"x*2/3"
,
1.0
,
4.0
,
2.0
/
3.0
);
verifyInvalidExpression
(
"1..2"
);
verifyInvalidExpression
(
"1*(2+3"
);
verifyInvalidExpression
(
"5++4"
);
verifyInvalidExpression
(
"1+2)"
);
verifyInvalidExpression
(
"cos(2,3)"
);
verifyDerivative
(
"x"
,
"1"
);
verifyDerivative
(
"x^2+x"
,
"2*x+1"
);
verifyDerivative
(
"y^x-x"
,
"log(y)*(y^x)-1"
);
verifyDerivative
(
"sin(x)"
,
"cos(x)"
);
verifyDerivative
(
"cos(x)"
,
"-sin(x)"
);
verifyDerivative
(
"tan(x)"
,
"square(sec(x))"
);
verifyDerivative
(
"cot(x)"
,
"-square(csc(x))"
);
verifyDerivative
(
"sec(x)"
,
"sec(x)*tan(x)"
);
verifyDerivative
(
"csc(x)"
,
"-csc(x)*cot(x)"
);
verifyDerivative
(
"exp(2*x)"
,
"2*exp(2*x)"
);
verifyDerivative
(
"log(x)"
,
"1/x"
);
verifyDerivative
(
"sqrt(x)"
,
"0.5/sqrt(x)"
);
verifyDerivative
(
"asin(x)"
,
"1/sqrt(1-x^2)"
);
verifyDerivative
(
"acos(x)"
,
"-1/sqrt(1-x^2)"
);
verifyDerivative
(
"atan(x)"
,
"1/(1+x^2)"
);
verifyDerivative
(
"recip(x)"
,
"-1/x^2"
);
verifyDerivative
(
"square(x)"
,
"2*x"
);
verifyDerivative
(
"cube(x)"
,
"3*x^2"
);
testCustomFunction
(
"custom(x, y)/2"
,
"x*y"
);
testCustomFunction
(
"custom(x^2, 1)+custom(2, y-1)"
,
"2*x^2+4*(y-1)"
);
map
<
string
,
double
>
variables
;
variables
[
"x"
]
=
2.0
;
variables
[
"y"
]
=
10.0
;
cout
<<
Parser
::
parse
(
"2*3*x"
).
optimize
()
<<
endl
;
cout
<<
Parser
::
parse
(
"1/(1+x)"
).
optimize
()
<<
endl
;
cout
<<
Parser
::
parse
(
"x^(1/2)"
).
optimize
()
<<
endl
;
cout
<<
Parser
::
parse
(
"log(3*cos(x))^(sqrt(4)-2)"
).
optimize
()
<<
endl
;
}
catch
(
const
exception
&
e
)
{
cout
<<
"exception: "
<<
e
.
what
()
<<
endl
;
return
1
;
}
cout
<<
"Done"
<<
endl
;
return
0
;
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment