Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
d89cd171
Commit
d89cd171
authored
Jul 12, 2014
by
kyleabeauchamp
Browse files
Merge remote-tracking branch 'upstream/master' into vagrant
parents
6b99ed69
a7466174
Changes
28
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
984 additions
and
542 deletions
+984
-542
CMakeLists.txt
CMakeLists.txt
+35
-15
libraries/lepton/src/CompiledExpression.cpp
libraries/lepton/src/CompiledExpression.cpp
+3
-3
libraries/lepton/src/ExpressionProgram.cpp
libraries/lepton/src/ExpressionProgram.cpp
+2
-2
libraries/lepton/src/ParsedExpression.cpp
libraries/lepton/src/ParsedExpression.cpp
+1
-1
libraries/lepton/src/Parser.cpp
libraries/lepton/src/Parser.cpp
+3
-3
openmmapi/include/openmm/internal/hardware.h
openmmapi/include/openmm/internal/hardware.h
+41
-31
openmmapi/include/openmm/internal/vectorize.h
openmmapi/include/openmm/internal/vectorize.h
+7
-252
openmmapi/include/openmm/internal/vectorize_neon.h
openmmapi/include/openmm/internal/vectorize_neon.h
+334
-0
openmmapi/include/openmm/internal/vectorize_sse.h
openmmapi/include/openmm/internal/vectorize_sse.h
+286
-0
openmmapi/src/OSRngSeed.cpp
openmmapi/src/OSRngSeed.cpp
+8
-6
platforms/cpu/sharedTarget/CMakeLists.txt
platforms/cpu/sharedTarget/CMakeLists.txt
+11
-7
platforms/cpu/src/CpuNeighborList.cpp
platforms/cpu/src/CpuNeighborList.cpp
+0
-1
platforms/cpu/src/CpuNonbondedForceVec4.cpp
platforms/cpu/src/CpuNonbondedForceVec4.cpp
+2
-2
platforms/cpu/src/CpuPlatform.cpp
platforms/cpu/src/CpuPlatform.cpp
+15
-9
platforms/opencl/include/OpenCLNonbondedUtilities.h
platforms/opencl/include/OpenCLNonbondedUtilities.h
+2
-1
platforms/opencl/src/OpenCLFFT3D.cpp
platforms/opencl/src/OpenCLFFT3D.cpp
+176
-165
platforms/opencl/src/OpenCLKernels.cpp
platforms/opencl/src/OpenCLKernels.cpp
+2
-2
platforms/opencl/src/OpenCLNonbondedUtilities.cpp
platforms/opencl/src/OpenCLNonbondedUtilities.cpp
+48
-35
platforms/opencl/src/OpenCLPlatform.cpp
platforms/opencl/src/OpenCLPlatform.cpp
+2
-4
platforms/opencl/src/OpenCLSort.cpp
platforms/opencl/src/OpenCLSort.cpp
+6
-3
No files found.
CMakeLists.txt
View file @
d89cd171
...
...
@@ -61,6 +61,11 @@ ELSE(WIN32)
ENDIF
(
NOT OPENMM_INSTALL_PREFIX
)
ENDIF
(
WIN32
)
# Include CPU-Features for Android
IF
(
ANDROID
)
INCLUDE_DIRECTORIES
(
${
ANDROID_NDK
}
/sources/cpufeatures
)
ENDIF
(
ANDROID
)
# It seems that on linux and mac, everything is trying to be installed in /usr/local/openmm
# But if every install target is prefixed with /openmm/, on Windows the install files
# end up in C:/Program Files/OpenMM/openmm/ which is ugly.
...
...
@@ -87,9 +92,11 @@ IF(WIN32)
SET
(
PTHREADS_LIB pthreadVC2
)
SET
(
PTHREADS_LIB_STATIC pthreadVC2_static_mt
)
ELSE
(
WIN32
)
SET
(
PTHREADS_LIB pthread
)
# in linux, even in static builds we link against the dynamic object (since its tied to libc versions)
SET
(
PTHREADS_LIB_STATIC pthread
)
IF
(
NOT ANDROID
)
SET
(
PTHREADS_LIB pthread
)
# in linux, even in static builds we link against the dynamic object (since its tied to libc versions)
SET
(
PTHREADS_LIB_STATIC pthread
)
ENDIF
(
NOT ANDROID
)
ENDIF
(
WIN32
)
# The build system will set ARCH64 for 64 bit builds, which require
...
...
@@ -121,11 +128,11 @@ IF (APPLE)
SET
(
CMAKE_INSTALL_NAME_DIR
"@rpath"
)
SET
(
EXTRA_COMPILE_FLAGS
"-msse2 -stdlib=libc++"
)
ELSE
(
APPLE
)
IF
(
MSVC
)
IF
(
MSVC
OR ANDROID
)
SET
(
EXTRA_COMPILE_FLAGS
)
ELSE
(
MSVC
)
ELSE
(
MSVC
OR ANDROID
)
SET
(
EXTRA_COMPILE_FLAGS
"-msse2"
)
ENDIF
(
MSVC
)
ENDIF
(
MSVC
OR ANDROID
)
ENDIF
(
APPLE
)
IF
(
UNIX AND NOT CMAKE_BUILD_TYPE
)
...
...
@@ -137,8 +144,13 @@ IF (NOT CMAKE_CXX_FLAGS_DEBUG)
ENDIF
(
NOT CMAKE_CXX_FLAGS_DEBUG
)
IF
(
NOT CMAKE_CXX_FLAGS_RELEASE
)
SET
(
CMAKE_CXX_FLAGS_RELEASE
"-O3 -DNDEBUG"
CACHE STRING
"To use when CMAKE_BUILD_TYPE=Release"
FORCE
)
IF
(
ANDROID
)
SET
(
CMAKE_CXX_FLAGS_RELEASE
"-mfloat-abi=softfp -march=armv7-a -mfpu=neon -funsafe-math-optimizations -O3 -DNDEBUG"
CACHE STRING
"To use when CMAKE_BUILD_TYPE=Release"
FORCE
)
ELSE
(
ANDROID
)
SET
(
CMAKE_CXX_FLAGS_RELEASE
"-O3 -DNDEBUG"
CACHE STRING
"To use when CMAKE_BUILD_TYPE=Release"
FORCE
)
ENDIF
(
ANDROID
)
ENDIF
(
NOT CMAKE_CXX_FLAGS_RELEASE
)
...
...
@@ -252,7 +264,11 @@ FOREACH(subdir ${OPENMM_SOURCE_SUBDIRS})
## OpenMM was previously installed there.
INCLUDE_DIRECTORIES
(
BEFORE
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
subdir
}
/include
)
ENDFOREACH
(
subdir
)
SET_SOURCE_FILES_PROPERTIES
(
${
CMAKE_SOURCE_DIR
}
/libraries/sfmt/src/SFMT.cpp PROPERTIES COMPILE_FLAGS
"-DHAVE_SSE2=1"
)
IF
(
ANDROID
)
SET_SOURCE_FILES_PROPERTIES
(
${
CMAKE_SOURCE_DIR
}
/libraries/sfmt/src/SFMT.cpp PROPERTIES COMPILE_FLAGS
"-UHAVE_SSE2"
)
ELSE
(
ANDROID
)
SET_SOURCE_FILES_PROPERTIES
(
${
CMAKE_SOURCE_DIR
}
/libraries/sfmt/src/SFMT.cpp PROPERTIES COMPILE_FLAGS
"-DHAVE_SSE2=1"
)
ENDIF
(
ANDROID
)
# If API wrappers are being generated, and add them to the build.
SET
(
OPENMM_BUILD_C_AND_FORTRAN_WRAPPERS ON CACHE BOOL
"Build wrappers for C and Fortran"
)
...
...
@@ -287,13 +303,17 @@ ENDIF(OPENMM_BUILD_C_AND_FORTRAN_WRAPPERS)
# On Linux need to link to libdl
FIND_LIBRARY
(
DL_LIBRARY dl
)
IF
(
DL_LIBRARY
)
TARGET_LINK_LIBRARIES
(
${
SHARED_TARGET
}
${
DL_LIBRARY
}
${
PTHREADS_LIB
}
)
IF
(
OPENMM_BUILD_STATIC_LIB
)
TARGET_LINK_LIBRARIES
(
${
STATIC_TARGET
}
${
DL_LIBRARY
}
${
PTHREADS_LIB
}
)
ENDIF
(
OPENMM_BUILD_STATIC_LIB
)
MARK_AS_ADVANCED
(
DL_LIBRARY
)
TARGET_LINK_LIBRARIES
(
${
SHARED_TARGET
}
${
DL_LIBRARY
}
${
PTHREADS_LIB
}
)
IF
(
OPENMM_BUILD_STATIC_LIB
)
TARGET_LINK_LIBRARIES
(
${
STATIC_TARGET
}
${
DL_LIBRARY
}
${
PTHREADS_LIB
}
)
ENDIF
(
OPENMM_BUILD_STATIC_LIB
)
MARK_AS_ADVANCED
(
DL_LIBRARY
)
ELSE
(
DL_LIBRARY
)
TARGET_LINK_LIBRARIES
(
${
SHARED_TARGET
}
${
PTHREADS_LIB
}
)
IF
(
ANDROID
)
TARGET_LINK_LIBRARIES
(
${
SHARED_TARGET
}
${
PTHREADS_LIB
}
cpufeatures
)
ELSE
(
ANDROID
)
TARGET_LINK_LIBRARIES
(
${
SHARED_TARGET
}
${
PTHREADS_LIB
}
)
ENDIF
(
ANDROID
)
ENDIF
(
DL_LIBRARY
)
ADD_SUBDIRECTORY
(
platforms/reference/tests
)
...
...
libraries/lepton/src/CompiledExpression.cpp
View file @
d89cd171
...
...
@@ -84,13 +84,13 @@ void CompiledExpression::compileExpression(const ExpressionTreeNode& node, vecto
// Process this node.
if
(
node
.
getOperation
().
getId
()
==
Operation
::
VARIABLE
)
{
variableIndices
[
node
.
getOperation
().
getName
()]
=
workspace
.
size
();
variableIndices
[
node
.
getOperation
().
getName
()]
=
(
int
)
workspace
.
size
();
variableNames
.
insert
(
node
.
getOperation
().
getName
());
}
else
{
int
stepIndex
=
arguments
.
size
();
int
stepIndex
=
(
int
)
arguments
.
size
();
arguments
.
push_back
(
vector
<
int
>
());
target
.
push_back
(
workspace
.
size
());
target
.
push_back
(
(
int
)
workspace
.
size
());
operation
.
push_back
(
node
.
getOperation
().
clone
());
if
(
args
.
size
()
==
0
)
arguments
[
stepIndex
].
push_back
(
0
);
// The value won't actually be used. We just need something there.
...
...
libraries/lepton/src/ExpressionProgram.cpp
View file @
d89cd171
...
...
@@ -71,13 +71,13 @@ ExpressionProgram& ExpressionProgram::operator=(const ExpressionProgram& program
}
void
ExpressionProgram
::
buildProgram
(
const
ExpressionTreeNode
&
node
)
{
for
(
int
i
=
node
.
getChildren
().
size
()
-
1
;
i
>=
0
;
i
--
)
for
(
int
i
=
(
int
)
node
.
getChildren
().
size
()
-
1
;
i
>=
0
;
i
--
)
buildProgram
(
node
.
getChildren
()[
i
]);
operations
.
push_back
(
node
.
getOperation
().
clone
());
}
int
ExpressionProgram
::
getNumOperations
()
const
{
return
operations
.
size
();
return
(
int
)
operations
.
size
();
}
const
Operation
&
ExpressionProgram
::
getOperation
(
int
index
)
const
{
...
...
libraries/lepton/src/ParsedExpression.cpp
View file @
d89cd171
...
...
@@ -60,7 +60,7 @@ double ParsedExpression::evaluate(const map<string, double>& variables) const {
}
double
ParsedExpression
::
evaluate
(
const
ExpressionTreeNode
&
node
,
const
map
<
string
,
double
>&
variables
)
{
int
numArgs
=
node
.
getChildren
().
size
();
int
numArgs
=
(
int
)
node
.
getChildren
().
size
();
vector
<
double
>
args
(
max
(
numArgs
,
1
));
for
(
int
i
=
0
;
i
<
numArgs
;
i
++
)
args
[
i
]
=
evaluate
(
node
.
getChildren
()[
i
],
variables
);
...
...
libraries/lepton/src/Parser.cpp
View file @
d89cd171
...
...
@@ -70,7 +70,7 @@ string Parser::trim(const string& expression) {
int
start
,
end
;
for
(
start
=
0
;
start
<
(
int
)
expression
.
size
()
&&
isspace
(
expression
[
start
]);
start
++
)
;
for
(
end
=
expression
.
size
()
-
1
;
end
>
start
&&
isspace
(
expression
[
end
]);
end
--
)
for
(
end
=
(
int
)
expression
.
size
()
-
1
;
end
>
start
&&
isspace
(
expression
[
end
]);
end
--
)
;
if
(
start
==
end
&&
isspace
(
expression
[
end
]))
return
""
;
...
...
@@ -140,7 +140,7 @@ vector<ParseToken> Parser::tokenize(const string& expression) {
ParseToken
token
=
getNextToken
(
expression
,
pos
);
if
(
token
.
getType
()
!=
ParseToken
::
Whitespace
)
tokens
.
push_back
(
token
);
pos
+=
token
.
getText
().
size
();
pos
+=
(
int
)
token
.
getText
().
size
();
}
return
tokens
;
}
...
...
@@ -257,7 +257,7 @@ ExpressionTreeNode Parser::parsePrecedence(const vector<ParseToken>& tokens, int
while
(
pos
<
(
int
)
tokens
.
size
()
&&
tokens
[
pos
].
getType
()
==
ParseToken
::
Operator
)
{
token
=
tokens
[
pos
];
int
opIndex
=
Operators
.
find
(
token
.
getText
());
int
opIndex
=
(
int
)
Operators
.
find
(
token
.
getText
());
int
opPrecedence
=
Precedence
[
opIndex
];
if
(
opPrecedence
<
precedence
)
return
result
;
...
...
openmmapi/include/openmm/internal/hardware.h
View file @
d89cd171
...
...
@@ -47,8 +47,12 @@
#define NOMINMAX
#include <windows.h>
#else
#include <dlfcn.h>
#include <unistd.h>
#ifdef __ANDROID__
#include <cpu-features.h>
#else
#include <dlfcn.h>
#include <unistd.h>
#endif
#endif
#endif
...
...
@@ -70,11 +74,15 @@ static int getNumProcessors() {
ncpu
=
1
;
return
ncpu
;
#else
long
nProcessorsOnline
=
sysconf
(
_SC_NPROCESSORS_ONLN
);
if
(
nProcessorsOnline
==
-
1
)
return
1
;
else
return
(
int
)
nProcessorsOnline
;
#ifdef __ANDROID__
return
android_getCpuCount
();
#else
long
nProcessorsOnline
=
sysconf
(
_SC_NPROCESSORS_ONLN
);
if
(
nProcessorsOnline
==
-
1
)
return
1
;
else
return
(
int
)
nProcessorsOnline
;
#endif
#endif
#endif
}
...
...
@@ -85,30 +93,32 @@ static int getNumProcessors() {
#ifdef _WIN32
#define cpuid __cpuid
#else
static
void
cpuid
(
int
cpuInfo
[
4
],
int
infoType
){
#ifdef __LP64__
__asm__
__volatile__
(
"cpuid"
:
"=a"
(
cpuInfo
[
0
]),
"=b"
(
cpuInfo
[
1
]),
"=c"
(
cpuInfo
[
2
]),
"=d"
(
cpuInfo
[
3
])
:
"a"
(
infoType
)
);
#else
__asm__
__volatile__
(
"pushl %%ebx
\n
"
"cpuid
\n
"
"movl %%ebx, %1
\n
"
"popl %%ebx
\n
"
:
"=a"
(
cpuInfo
[
0
]),
"=r"
(
cpuInfo
[
1
]),
"=c"
(
cpuInfo
[
2
]),
"=d"
(
cpuInfo
[
3
])
:
"a"
(
infoType
)
);
#endif
}
#ifndef __ANDROID__
static
void
cpuid
(
int
cpuInfo
[
4
],
int
infoType
){
#ifdef __LP64__
__asm__
__volatile__
(
"cpuid"
:
"=a"
(
cpuInfo
[
0
]),
"=b"
(
cpuInfo
[
1
]),
"=c"
(
cpuInfo
[
2
]),
"=d"
(
cpuInfo
[
3
])
:
"a"
(
infoType
)
);
#else
__asm__
__volatile__
(
"pushl %%ebx
\n
"
"cpuid
\n
"
"movl %%ebx, %1
\n
"
"popl %%ebx
\n
"
:
"=a"
(
cpuInfo
[
0
]),
"=r"
(
cpuInfo
[
1
]),
"=c"
(
cpuInfo
[
2
]),
"=d"
(
cpuInfo
[
3
])
:
"a"
(
infoType
)
);
#endif
}
#endif
#endif
#endif // OPENMM_HARDWARE_H_
openmmapi/include/openmm/internal/vectorize.h
View file @
d89cd171
...
...
@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 201
3
Stanford University and the Authors. *
* Portions copyright (c) 201
4
Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
...
...
@@ -31,256 +31,11 @@
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include <smmintrin.h>
// This file defines classes and functions to simplify vectorizing code with SSE.
class
ivec4
;
/**
* A four element vector of floats.
*/
class
fvec4
{
public:
__m128
val
;
fvec4
()
{}
fvec4
(
float
v
)
:
val
(
_mm_set1_ps
(
v
))
{}
fvec4
(
float
v1
,
float
v2
,
float
v3
,
float
v4
)
:
val
(
_mm_set_ps
(
v4
,
v3
,
v2
,
v1
))
{}
fvec4
(
__m128
v
)
:
val
(
v
)
{}
fvec4
(
const
float
*
v
)
:
val
(
_mm_loadu_ps
(
v
))
{}
operator
__m128
()
const
{
return
val
;
}
float
operator
[](
int
i
)
const
{
float
result
[
4
];
store
(
result
);
return
result
[
i
];
}
void
store
(
float
*
v
)
const
{
_mm_storeu_ps
(
v
,
val
);
}
fvec4
operator
+
(
const
fvec4
&
other
)
const
{
return
_mm_add_ps
(
val
,
other
);
}
fvec4
operator
-
(
const
fvec4
&
other
)
const
{
return
_mm_sub_ps
(
val
,
other
);
}
fvec4
operator
*
(
const
fvec4
&
other
)
const
{
return
_mm_mul_ps
(
val
,
other
);
}
fvec4
operator
/
(
const
fvec4
&
other
)
const
{
return
_mm_div_ps
(
val
,
other
);
}
void
operator
+=
(
const
fvec4
&
other
)
{
val
=
_mm_add_ps
(
val
,
other
);
}
void
operator
-=
(
const
fvec4
&
other
)
{
val
=
_mm_sub_ps
(
val
,
other
);
}
void
operator
*=
(
const
fvec4
&
other
)
{
val
=
_mm_mul_ps
(
val
,
other
);
}
void
operator
/=
(
const
fvec4
&
other
)
{
val
=
_mm_div_ps
(
val
,
other
);
}
fvec4
operator
-
()
const
{
return
_mm_sub_ps
(
_mm_set1_ps
(
0.0
f
),
val
);
}
fvec4
operator
&
(
const
fvec4
&
other
)
const
{
return
_mm_and_ps
(
val
,
other
);
}
fvec4
operator
|
(
const
fvec4
&
other
)
const
{
return
_mm_or_ps
(
val
,
other
);
}
fvec4
operator
==
(
const
fvec4
&
other
)
const
{
return
_mm_cmpeq_ps
(
val
,
other
);
}
fvec4
operator
!=
(
const
fvec4
&
other
)
const
{
return
_mm_cmpneq_ps
(
val
,
other
);
}
fvec4
operator
>
(
const
fvec4
&
other
)
const
{
return
_mm_cmpgt_ps
(
val
,
other
);
}
fvec4
operator
<
(
const
fvec4
&
other
)
const
{
return
_mm_cmplt_ps
(
val
,
other
);
}
fvec4
operator
>=
(
const
fvec4
&
other
)
const
{
return
_mm_cmpge_ps
(
val
,
other
);
}
fvec4
operator
<=
(
const
fvec4
&
other
)
const
{
return
_mm_cmple_ps
(
val
,
other
);
}
operator
ivec4
()
const
;
};
/**
* A four element vector of ints.
*/
class
ivec4
{
public:
__m128i
val
;
ivec4
()
{}
ivec4
(
int
v
)
:
val
(
_mm_set1_epi32
(
v
))
{}
ivec4
(
int
v1
,
int
v2
,
int
v3
,
int
v4
)
:
val
(
_mm_set_epi32
(
v4
,
v3
,
v2
,
v1
))
{}
ivec4
(
__m128i
v
)
:
val
(
v
)
{}
ivec4
(
const
int
*
v
)
:
val
(
_mm_loadu_si128
((
const
__m128i
*
)
v
))
{}
operator
__m128i
()
const
{
return
val
;
}
int
operator
[](
int
i
)
const
{
int
result
[
4
];
store
(
result
);
return
result
[
i
];
}
void
store
(
int
*
v
)
const
{
_mm_storeu_si128
((
__m128i
*
)
v
,
val
);
}
ivec4
operator
+
(
const
ivec4
&
other
)
const
{
return
_mm_add_epi32
(
val
,
other
);
}
ivec4
operator
-
(
const
ivec4
&
other
)
const
{
return
_mm_sub_epi32
(
val
,
other
);
}
ivec4
operator
*
(
const
ivec4
&
other
)
const
{
return
_mm_mul_epi32
(
val
,
other
);
}
void
operator
+=
(
const
ivec4
&
other
)
{
val
=
_mm_add_epi32
(
val
,
other
);
}
void
operator
-=
(
const
ivec4
&
other
)
{
val
=
_mm_sub_epi32
(
val
,
other
);
}
void
operator
*=
(
const
ivec4
&
other
)
{
val
=
_mm_mul_epi32
(
val
,
other
);
}
ivec4
operator
-
()
const
{
return
_mm_sub_epi32
(
_mm_set1_epi32
(
0
),
val
);
}
ivec4
operator
&
(
const
ivec4
&
other
)
const
{
return
_mm_and_si128
(
val
,
other
);
}
ivec4
operator
|
(
const
ivec4
&
other
)
const
{
return
_mm_or_si128
(
val
,
other
);
}
ivec4
operator
==
(
const
ivec4
&
other
)
const
{
return
_mm_cmpeq_epi32
(
val
,
other
);
}
ivec4
operator
!=
(
const
ivec4
&
other
)
const
{
return
_mm_xor_si128
(
*
this
==
other
,
_mm_set1_epi32
(
0xFFFFFFFF
));
}
ivec4
operator
>
(
const
ivec4
&
other
)
const
{
return
_mm_cmpgt_epi32
(
val
,
other
);
}
ivec4
operator
<
(
const
ivec4
&
other
)
const
{
return
_mm_cmplt_epi32
(
val
,
other
);
}
ivec4
operator
>=
(
const
ivec4
&
other
)
const
{
return
_mm_xor_si128
(
_mm_cmplt_epi32
(
val
,
other
),
_mm_set1_epi32
(
0xFFFFFFFF
));
}
ivec4
operator
<=
(
const
ivec4
&
other
)
const
{
return
_mm_xor_si128
(
_mm_cmpgt_epi32
(
val
,
other
),
_mm_set1_epi32
(
0xFFFFFFFF
));
}
operator
fvec4
()
const
;
};
// Conversion operators.
inline
fvec4
::
operator
ivec4
()
const
{
return
_mm_cvttps_epi32
(
val
);
}
inline
ivec4
::
operator
fvec4
()
const
{
return
_mm_cvtepi32_ps
(
val
);
}
// Functions that operate on fvec4s.
static
inline
fvec4
floor
(
const
fvec4
&
v
)
{
return
fvec4
(
_mm_floor_ps
(
v
.
val
));
}
static
inline
fvec4
ceil
(
const
fvec4
&
v
)
{
return
fvec4
(
_mm_ceil_ps
(
v
.
val
));
}
static
inline
fvec4
round
(
const
fvec4
&
v
)
{
return
fvec4
(
_mm_round_ps
(
v
.
val
,
_MM_FROUND_TO_NEAREST_INT
));
}
static
inline
fvec4
min
(
const
fvec4
&
v1
,
const
fvec4
&
v2
)
{
return
fvec4
(
_mm_min_ps
(
v1
.
val
,
v2
.
val
));
}
static
inline
fvec4
max
(
const
fvec4
&
v1
,
const
fvec4
&
v2
)
{
return
fvec4
(
_mm_max_ps
(
v1
.
val
,
v2
.
val
));
}
static
inline
fvec4
abs
(
const
fvec4
&
v
)
{
static
const
__m128
mask
=
_mm_castsi128_ps
(
_mm_set1_epi32
(
0x7FFFFFFF
));
return
fvec4
(
_mm_and_ps
(
v
.
val
,
mask
));
}
static
inline
fvec4
sqrt
(
const
fvec4
&
v
)
{
return
fvec4
(
_mm_sqrt_ps
(
v
.
val
));
}
static
inline
float
dot3
(
const
fvec4
&
v1
,
const
fvec4
&
v2
)
{
return
_mm_cvtss_f32
(
_mm_dp_ps
(
v1
,
v2
,
0x71
));
}
static
inline
float
dot4
(
const
fvec4
&
v1
,
const
fvec4
&
v2
)
{
return
_mm_cvtss_f32
(
_mm_dp_ps
(
v1
,
v2
,
0xF1
));
}
static
inline
void
transpose
(
fvec4
&
v1
,
fvec4
&
v2
,
fvec4
&
v3
,
fvec4
&
v4
)
{
_MM_TRANSPOSE4_PS
(
v1
,
v2
,
v3
,
v4
);
}
// Functions that operate on ivec4s.
static
inline
ivec4
min
(
const
ivec4
&
v1
,
const
ivec4
&
v2
)
{
return
ivec4
(
_mm_min_epi32
(
v1
.
val
,
v2
.
val
));
}
static
inline
ivec4
max
(
const
ivec4
&
v1
,
const
ivec4
&
v2
)
{
return
ivec4
(
_mm_max_epi32
(
v1
.
val
,
v2
.
val
));
}
static
inline
ivec4
abs
(
const
ivec4
&
v
)
{
return
ivec4
(
_mm_abs_epi32
(
v
.
val
));
}
static
inline
bool
any
(
const
ivec4
&
v
)
{
return
!
_mm_test_all_zeros
(
v
,
_mm_set1_epi32
(
0xFFFFFFFF
));
}
// Mathematical operators involving a scalar and a vector.
static
inline
fvec4
operator
+
(
float
v1
,
const
fvec4
&
v2
)
{
return
fvec4
(
v1
)
+
v2
;
}
static
inline
fvec4
operator
-
(
float
v1
,
const
fvec4
&
v2
)
{
return
fvec4
(
v1
)
-
v2
;
}
static
inline
fvec4
operator
*
(
float
v1
,
const
fvec4
&
v2
)
{
return
fvec4
(
v1
)
*
v2
;
}
static
inline
fvec4
operator
/
(
float
v1
,
const
fvec4
&
v2
)
{
return
fvec4
(
v1
)
/
v2
;
}
// Operations for blending fvec4s based on an ivec4.
static
inline
fvec4
blend
(
const
fvec4
&
v1
,
const
fvec4
&
v2
,
const
ivec4
&
mask
)
{
return
fvec4
(
_mm_blendv_ps
(
v1
.
val
,
v2
.
val
,
_mm_castsi128_ps
(
mask
.
val
)));
}
#if defined(__ANDROID__)
#include "vectorize_neon.h"
#else
#include "vectorize_sse.h"
#endif
#endif
/*OPENMM_VECTORIZE_H_*/
openmmapi/include/openmm/internal/vectorize_neon.h
0 → 100644
View file @
d89cd171
#ifndef OPENMM_VECTORIZE_NEON_H_
#define OPENMM_VECTORIZE_NEON_H_
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2013-2014 Stanford University and the Authors. *
* Authors: Mateus Lima, Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include <cpu-features.h>
#include <arm_neon.h>
#include <cmath>
typedef
int
int32_t
;
// This file defines classes and functions to simplify vectorizing code with NEON.
class
ivec4
;
/**
* A four element vector of floats.
*/
class
fvec4
{
public:
float32x4_t
val
;
fvec4
()
{}
fvec4
(
float
v
)
:
val
(
vdupq_n_f32
(
v
))
{}
fvec4
(
float
v1
,
float
v2
,
float
v3
,
float
v4
)
{
float
v
[]
=
{
v1
,
v2
,
v3
,
v4
};
val
=
vld1q_f32
(
v
);
}
fvec4
(
float32x4_t
v
)
:
val
(
v
)
{}
fvec4
(
const
float
*
v
)
:
val
(
vld1q_f32
(
v
))
{}
operator
float32x4_t
()
const
{
return
val
;
}
float
operator
[](
int
i
)
const
{
switch
(
i
)
{
case
0
:
return
vgetq_lane_f32
(
val
,
0
);
case
1
:
return
vgetq_lane_f32
(
val
,
1
);
case
2
:
return
vgetq_lane_f32
(
val
,
2
);
case
3
:
return
vgetq_lane_f32
(
val
,
3
);
}
return
0.0
f
;
}
void
store
(
float
*
v
)
const
{
vst1q_f32
(
v
,
val
);
}
fvec4
operator
+
(
const
fvec4
&
other
)
const
{
return
vaddq_f32
(
val
,
other
);
}
fvec4
operator
-
(
const
fvec4
&
other
)
const
{
return
vsubq_f32
(
val
,
other
);
}
fvec4
operator
*
(
const
fvec4
&
other
)
const
{
return
vmulq_f32
(
val
,
other
);
}
fvec4
operator
/
(
const
fvec4
&
other
)
const
{
// NEON does not have a divide float-point operator, so we get the reciprocal and multiply.
float32x4_t
reciprocal
=
vrecpeq_f32
(
other
);
reciprocal
=
vmulq_f32
(
vrecpsq_f32
(
other
,
reciprocal
),
reciprocal
);
reciprocal
=
vmulq_f32
(
vrecpsq_f32
(
other
,
reciprocal
),
reciprocal
);
fvec4
result
=
vmulq_f32
(
val
,
reciprocal
);
return
result
;
}
void
operator
+=
(
const
fvec4
&
other
)
{
val
=
vaddq_f32
(
val
,
other
);
}
void
operator
-=
(
const
fvec4
&
other
)
{
val
=
vsubq_f32
(
val
,
other
);
}
void
operator
*=
(
const
fvec4
&
other
)
{
val
=
vmulq_f32
(
val
,
other
);
}
void
operator
/=
(
const
fvec4
&
other
)
{
val
=
*
this
/
other
;
}
fvec4
operator
-
()
const
{
return
vnegq_f32
(
val
);
}
fvec4
operator
&
(
const
fvec4
&
other
)
const
{
return
vreinterpretq_f32_u32
(
vandq_u32
(
vreinterpretq_u32_f32
(
val
),
vreinterpretq_u32_f32
(
other
)));
}
fvec4
operator
|
(
const
fvec4
&
other
)
const
{
return
vreinterpretq_f32_u32
(
vorrq_u32
(
vreinterpretq_u32_f32
(
val
),
vreinterpretq_u32_f32
(
other
)));
}
fvec4
operator
==
(
const
fvec4
&
other
)
const
{
return
vcvtq_f32_s32
(
vreinterpretq_s32_u32
(
vceqq_f32
(
val
,
other
)));
}
fvec4
operator
!=
(
const
fvec4
&
other
)
const
{
return
vcvtq_f32_s32
(
vreinterpretq_s32_u32
(
vmvnq_u32
(
vceqq_f32
(
val
,
other
))));
// not(equals(val, other))
}
fvec4
operator
>
(
const
fvec4
&
other
)
const
{
return
vcvtq_f32_s32
(
vreinterpretq_s32_u32
(
vcgtq_f32
(
val
,
other
)));
}
fvec4
operator
<
(
const
fvec4
&
other
)
const
{
return
vcvtq_f32_s32
(
vreinterpretq_s32_u32
(
vcltq_f32
(
val
,
other
)));
}
fvec4
operator
>=
(
const
fvec4
&
other
)
const
{
return
vcvtq_f32_s32
(
vreinterpretq_s32_u32
(
vcgeq_f32
(
val
,
other
)));
}
fvec4
operator
<=
(
const
fvec4
&
other
)
const
{
return
vcvtq_f32_s32
(
vreinterpretq_s32_u32
(
vcleq_f32
(
val
,
other
)));
}
operator
ivec4
()
const
;
};
/**
* A four element vector of ints.
*/
class
ivec4
{
public:
int32x4_t
val
;
ivec4
()
{}
ivec4
(
int
v
)
:
val
(
vdupq_n_s32
(
v
))
{}
ivec4
(
int
v1
,
int
v2
,
int
v3
,
int
v4
)
{
int
v
[]
=
{
v1
,
v2
,
v3
,
v4
};
val
=
vld1q_s32
(
v
);
}
ivec4
(
int32x4_t
v
)
:
val
(
v
)
{}
ivec4
(
const
int
*
v
)
:
val
(
vld1q_s32
(
v
))
{}
operator
int32x4_t
()
const
{
return
val
;
}
int
operator
[](
int
i
)
const
{
switch
(
i
)
{
case
0
:
return
vgetq_lane_s32
(
val
,
0
);
case
1
:
return
vgetq_lane_s32
(
val
,
1
);
case
2
:
return
vgetq_lane_s32
(
val
,
2
);
case
3
:
return
vgetq_lane_s32
(
val
,
3
);
}
return
0
;
}
void
store
(
int
*
v
)
const
{
vst1q_s32
(
v
,
val
);
}
ivec4
operator
+
(
const
ivec4
&
other
)
const
{
return
vaddq_s32
(
val
,
other
);
}
ivec4
operator
-
(
const
ivec4
&
other
)
const
{
return
vsubq_s32
(
val
,
other
);
}
ivec4
operator
*
(
const
ivec4
&
other
)
const
{
return
vmulq_s32
(
val
,
other
);
}
void
operator
+=
(
const
ivec4
&
other
)
{
val
=
vaddq_s32
(
val
,
other
);
}
void
operator
-=
(
const
ivec4
&
other
)
{
val
=
vsubq_s32
(
val
,
other
);
}
void
operator
*=
(
const
ivec4
&
other
)
{
val
=
vmulq_s32
(
val
,
other
);
}
ivec4
operator
-
()
const
{
return
vnegq_s32
(
val
);
}
ivec4
operator
&
(
const
ivec4
&
other
)
const
{
return
vandq_s32
(
val
,
other
);
}
ivec4
operator
|
(
const
ivec4
&
other
)
const
{
return
vorrq_s32
(
val
,
other
);
}
ivec4
operator
==
(
const
ivec4
&
other
)
const
{
return
vreinterpretq_s32_u32
(
vceqq_s32
(
val
,
other
));
}
ivec4
operator
!=
(
const
ivec4
&
other
)
const
{
return
vreinterpretq_s32_u32
(
vmvnq_u32
(
vceqq_s32
(
val
,
other
)));
// not(equal(val, other))
}
ivec4
operator
>
(
const
ivec4
&
other
)
const
{
return
vreinterpretq_s32_u32
(
vcgtq_s32
(
val
,
other
));
}
ivec4
operator
<
(
const
ivec4
&
other
)
const
{
return
vreinterpretq_s32_u32
(
vcltq_s32
(
val
,
other
));
}
ivec4
operator
>=
(
const
ivec4
&
other
)
const
{
return
vreinterpretq_s32_u32
(
vcgeq_s32
(
val
,
other
));
}
ivec4
operator
<=
(
const
ivec4
&
other
)
const
{
return
vreinterpretq_s32_u32
(
vcleq_s32
(
val
,
other
));
}
operator
fvec4
()
const
;
};
// Conversion operators.
inline
fvec4
::
operator
ivec4
()
const
{
return
ivec4
(
vcvtq_s32_f32
(
val
));
}
inline
ivec4
::
operator
fvec4
()
const
{
return
fvec4
(
vcvtq_f32_s32
(
val
));
}
// Functions that operate on fvec4s.
static
inline
fvec4
min
(
const
fvec4
&
v1
,
const
fvec4
&
v2
)
{
return
vminq_f32
(
v1
,
v2
);
}
static
inline
fvec4
max
(
const
fvec4
&
v1
,
const
fvec4
&
v2
)
{
return
vmaxq_f32
(
v1
,
v2
);
}
static
inline
fvec4
abs
(
const
fvec4
&
v
)
{
return
vabsq_f32
(
v
);
}
static
inline
fvec4
sqrt
(
const
fvec4
&
v
)
{
float32x4_t
recipSqrt
=
vrsqrteq_f32
(
v
);
recipSqrt
=
vmulq_f32
(
recipSqrt
,
vrsqrtsq_f32
(
vmulq_f32
(
recipSqrt
,
v
),
recipSqrt
));
recipSqrt
=
vmulq_f32
(
recipSqrt
,
vrsqrtsq_f32
(
vmulq_f32
(
recipSqrt
,
v
),
recipSqrt
));
return
vmulq_f32
(
v
,
recipSqrt
);
}
static
inline
float
dot3
(
const
fvec4
&
v1
,
const
fvec4
&
v2
)
{
fvec4
result
=
v1
*
v2
;
return
vgetq_lane_f32
(
result
,
0
)
+
vgetq_lane_f32
(
result
,
1
)
+
vgetq_lane_f32
(
result
,
2
);
}
static
inline
float
dot4
(
const
fvec4
&
v1
,
const
fvec4
&
v2
)
{
fvec4
result
=
v1
*
v2
;
return
vgetq_lane_f32
(
result
,
0
)
+
vgetq_lane_f32
(
result
,
1
)
+
vgetq_lane_f32
(
result
,
2
)
+
vgetq_lane_f32
(
result
,
3
);
}
static
inline
void
transpose
(
fvec4
&
v1
,
fvec4
&
v2
,
fvec4
&
v3
,
fvec4
&
v4
)
{
float32x4x2_t
t1
=
vuzpq_f32
(
v1
,
v3
);
float32x4x2_t
t2
=
vuzpq_f32
(
v2
,
v4
);
float32x4x2_t
t3
=
vtrnq_f32
(
t1
.
val
[
0
],
t2
.
val
[
0
]);
float32x4x2_t
t4
=
vtrnq_f32
(
t1
.
val
[
1
],
t2
.
val
[
1
]);
v1
=
t3
.
val
[
0
];
v2
=
t4
.
val
[
0
];
v3
=
t3
.
val
[
1
];
v4
=
t4
.
val
[
1
];
}
// Functions that operate on ivec4s.
static
inline
ivec4
min
(
const
ivec4
&
v1
,
const
ivec4
&
v2
)
{
return
vminq_s32
(
v1
,
v2
);
}
static
inline
ivec4
max
(
const
ivec4
&
v1
,
const
ivec4
&
v2
)
{
return
vmaxq_s32
(
v1
,
v2
);
}
static
inline
ivec4
abs
(
const
ivec4
&
v
)
{
return
vabdq_s32
(
v
,
ivec4
(
0
));
}
static
inline
bool
any
(
const
ivec4
&
v
)
{
return
(
vgetq_lane_s32
(
v
,
0
)
!=
0
||
vgetq_lane_s32
(
v
,
1
)
!=
0
||
vgetq_lane_s32
(
v
,
2
)
!=
0
||
vgetq_lane_s32
(
v
,
3
)
!=
0
);
}
// Mathematical operators involving a scalar and a vector.
static
inline
fvec4
operator
+
(
float
v1
,
const
fvec4
&
v2
)
{
return
fvec4
(
v1
)
+
v2
;
}
static
inline
fvec4
operator
-
(
float
v1
,
const
fvec4
&
v2
)
{
return
fvec4
(
v1
)
-
v2
;
}
static
inline
fvec4
operator
*
(
float
v1
,
const
fvec4
&
v2
)
{
return
fvec4
(
v1
)
*
v2
;
}
static
inline
fvec4
operator
/
(
float
v1
,
const
fvec4
&
v2
)
{
return
fvec4
(
v1
)
/
v2
;
}
// Operations for blending fvec4s based on an ivec4.
static
inline
fvec4
blend
(
const
fvec4
&
v1
,
const
fvec4
&
v2
,
const
ivec4
&
mask
)
{
return
vbslq_f32
(
vreinterpretq_u32_s32
(
mask
),
v2
,
v1
);
}
// These are at the end since they involve other functions defined above.
static
inline
fvec4
round
(
const
fvec4
&
v
)
{
fvec4
shift
(
0x1
.0
p23f
);
fvec4
absResult
=
(
abs
(
v
)
+
shift
)
-
shift
;
return
blend
(
v
,
absResult
,
ivec4
(
0x7FFFFFFF
));
}
static
inline
fvec4
floor
(
const
fvec4
&
v
)
{
fvec4
rounded
=
round
(
v
);
return
rounded
+
blend
(
0.0
f
,
-
1.0
f
,
rounded
>
v
);
}
static
inline
fvec4
ceil
(
const
fvec4
&
v
)
{
fvec4
rounded
=
round
(
v
);
return
rounded
+
blend
(
0.0
f
,
1.0
f
,
rounded
<
v
);
}
#endif
/*OPENMM_VECTORIZE_NEON_H_*/
openmmapi/include/openmm/internal/vectorize_sse.h
0 → 100644
View file @
d89cd171
#ifndef OPENMM_VECTORIZE_SSE_H_
#define OPENMM_VECTORIZE_SSE_H_
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2013 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include <smmintrin.h>
// This file defines classes and functions to simplify vectorizing code with SSE.
class
ivec4
;
/**
* A four element vector of floats.
*/
class
fvec4
{
public:
__m128
val
;
fvec4
()
{}
fvec4
(
float
v
)
:
val
(
_mm_set1_ps
(
v
))
{}
fvec4
(
float
v1
,
float
v2
,
float
v3
,
float
v4
)
:
val
(
_mm_set_ps
(
v4
,
v3
,
v2
,
v1
))
{}
fvec4
(
__m128
v
)
:
val
(
v
)
{}
fvec4
(
const
float
*
v
)
:
val
(
_mm_loadu_ps
(
v
))
{}
operator
__m128
()
const
{
return
val
;
}
float
operator
[](
int
i
)
const
{
float
result
[
4
];
store
(
result
);
return
result
[
i
];
}
void
store
(
float
*
v
)
const
{
_mm_storeu_ps
(
v
,
val
);
}
fvec4
operator
+
(
const
fvec4
&
other
)
const
{
return
_mm_add_ps
(
val
,
other
);
}
fvec4
operator
-
(
const
fvec4
&
other
)
const
{
return
_mm_sub_ps
(
val
,
other
);
}
fvec4
operator
*
(
const
fvec4
&
other
)
const
{
return
_mm_mul_ps
(
val
,
other
);
}
fvec4
operator
/
(
const
fvec4
&
other
)
const
{
return
_mm_div_ps
(
val
,
other
);
}
void
operator
+=
(
const
fvec4
&
other
)
{
val
=
_mm_add_ps
(
val
,
other
);
}
void
operator
-=
(
const
fvec4
&
other
)
{
val
=
_mm_sub_ps
(
val
,
other
);
}
void
operator
*=
(
const
fvec4
&
other
)
{
val
=
_mm_mul_ps
(
val
,
other
);
}
void
operator
/=
(
const
fvec4
&
other
)
{
val
=
_mm_div_ps
(
val
,
other
);
}
fvec4
operator
-
()
const
{
return
_mm_sub_ps
(
_mm_set1_ps
(
0.0
f
),
val
);
}
fvec4
operator
&
(
const
fvec4
&
other
)
const
{
return
_mm_and_ps
(
val
,
other
);
}
fvec4
operator
|
(
const
fvec4
&
other
)
const
{
return
_mm_or_ps
(
val
,
other
);
}
fvec4
operator
==
(
const
fvec4
&
other
)
const
{
return
_mm_cmpeq_ps
(
val
,
other
);
}
fvec4
operator
!=
(
const
fvec4
&
other
)
const
{
return
_mm_cmpneq_ps
(
val
,
other
);
}
fvec4
operator
>
(
const
fvec4
&
other
)
const
{
return
_mm_cmpgt_ps
(
val
,
other
);
}
fvec4
operator
<
(
const
fvec4
&
other
)
const
{
return
_mm_cmplt_ps
(
val
,
other
);
}
fvec4
operator
>=
(
const
fvec4
&
other
)
const
{
return
_mm_cmpge_ps
(
val
,
other
);
}
fvec4
operator
<=
(
const
fvec4
&
other
)
const
{
return
_mm_cmple_ps
(
val
,
other
);
}
operator
ivec4
()
const
;
};
/**
* A four element vector of ints.
*/
class
ivec4
{
public:
__m128i
val
;
ivec4
()
{}
ivec4
(
int
v
)
:
val
(
_mm_set1_epi32
(
v
))
{}
ivec4
(
int
v1
,
int
v2
,
int
v3
,
int
v4
)
:
val
(
_mm_set_epi32
(
v4
,
v3
,
v2
,
v1
))
{}
ivec4
(
__m128i
v
)
:
val
(
v
)
{}
ivec4
(
const
int
*
v
)
:
val
(
_mm_loadu_si128
((
const
__m128i
*
)
v
))
{}
operator
__m128i
()
const
{
return
val
;
}
int
operator
[](
int
i
)
const
{
int
result
[
4
];
store
(
result
);
return
result
[
i
];
}
void
store
(
int
*
v
)
const
{
_mm_storeu_si128
((
__m128i
*
)
v
,
val
);
}
ivec4
operator
+
(
const
ivec4
&
other
)
const
{
return
_mm_add_epi32
(
val
,
other
);
}
ivec4
operator
-
(
const
ivec4
&
other
)
const
{
return
_mm_sub_epi32
(
val
,
other
);
}
ivec4
operator
*
(
const
ivec4
&
other
)
const
{
return
_mm_mul_epi32
(
val
,
other
);
}
void
operator
+=
(
const
ivec4
&
other
)
{
val
=
_mm_add_epi32
(
val
,
other
);
}
void
operator
-=
(
const
ivec4
&
other
)
{
val
=
_mm_sub_epi32
(
val
,
other
);
}
void
operator
*=
(
const
ivec4
&
other
)
{
val
=
_mm_mul_epi32
(
val
,
other
);
}
ivec4
operator
-
()
const
{
return
_mm_sub_epi32
(
_mm_set1_epi32
(
0
),
val
);
}
ivec4
operator
&
(
const
ivec4
&
other
)
const
{
return
_mm_and_si128
(
val
,
other
);
}
ivec4
operator
|
(
const
ivec4
&
other
)
const
{
return
_mm_or_si128
(
val
,
other
);
}
ivec4
operator
==
(
const
ivec4
&
other
)
const
{
return
_mm_cmpeq_epi32
(
val
,
other
);
}
ivec4
operator
!=
(
const
ivec4
&
other
)
const
{
return
_mm_xor_si128
(
*
this
==
other
,
_mm_set1_epi32
(
0xFFFFFFFF
));
}
ivec4
operator
>
(
const
ivec4
&
other
)
const
{
return
_mm_cmpgt_epi32
(
val
,
other
);
}
ivec4
operator
<
(
const
ivec4
&
other
)
const
{
return
_mm_cmplt_epi32
(
val
,
other
);
}
ivec4
operator
>=
(
const
ivec4
&
other
)
const
{
return
_mm_xor_si128
(
_mm_cmplt_epi32
(
val
,
other
),
_mm_set1_epi32
(
0xFFFFFFFF
));
}
ivec4
operator
<=
(
const
ivec4
&
other
)
const
{
return
_mm_xor_si128
(
_mm_cmpgt_epi32
(
val
,
other
),
_mm_set1_epi32
(
0xFFFFFFFF
));
}
operator
fvec4
()
const
;
};
// Conversion operators.
inline
fvec4
::
operator
ivec4
()
const
{
return
_mm_cvttps_epi32
(
val
);
}
inline
ivec4
::
operator
fvec4
()
const
{
return
_mm_cvtepi32_ps
(
val
);
}
// Functions that operate on fvec4s.
static
inline
fvec4
floor
(
const
fvec4
&
v
)
{
return
fvec4
(
_mm_floor_ps
(
v
.
val
));
}
static
inline
fvec4
ceil
(
const
fvec4
&
v
)
{
return
fvec4
(
_mm_ceil_ps
(
v
.
val
));
}
static
inline
fvec4
round
(
const
fvec4
&
v
)
{
return
fvec4
(
_mm_round_ps
(
v
.
val
,
_MM_FROUND_TO_NEAREST_INT
));
}
static
inline
fvec4
min
(
const
fvec4
&
v1
,
const
fvec4
&
v2
)
{
return
fvec4
(
_mm_min_ps
(
v1
.
val
,
v2
.
val
));
}
static
inline
fvec4
max
(
const
fvec4
&
v1
,
const
fvec4
&
v2
)
{
return
fvec4
(
_mm_max_ps
(
v1
.
val
,
v2
.
val
));
}
static
inline
fvec4
abs
(
const
fvec4
&
v
)
{
static
const
__m128
mask
=
_mm_castsi128_ps
(
_mm_set1_epi32
(
0x7FFFFFFF
));
return
fvec4
(
_mm_and_ps
(
v
.
val
,
mask
));
}
static
inline
fvec4
sqrt
(
const
fvec4
&
v
)
{
return
fvec4
(
_mm_sqrt_ps
(
v
.
val
));
}
static
inline
float
dot3
(
const
fvec4
&
v1
,
const
fvec4
&
v2
)
{
return
_mm_cvtss_f32
(
_mm_dp_ps
(
v1
,
v2
,
0x71
));
}
static
inline
float
dot4
(
const
fvec4
&
v1
,
const
fvec4
&
v2
)
{
return
_mm_cvtss_f32
(
_mm_dp_ps
(
v1
,
v2
,
0xF1
));
}
static
inline
void
transpose
(
fvec4
&
v1
,
fvec4
&
v2
,
fvec4
&
v3
,
fvec4
&
v4
)
{
_MM_TRANSPOSE4_PS
(
v1
,
v2
,
v3
,
v4
);
}
// Functions that operate on ivec4s.
static
inline
ivec4
min
(
const
ivec4
&
v1
,
const
ivec4
&
v2
)
{
return
ivec4
(
_mm_min_epi32
(
v1
.
val
,
v2
.
val
));
}
static
inline
ivec4
max
(
const
ivec4
&
v1
,
const
ivec4
&
v2
)
{
return
ivec4
(
_mm_max_epi32
(
v1
.
val
,
v2
.
val
));
}
static
inline
ivec4
abs
(
const
ivec4
&
v
)
{
return
ivec4
(
_mm_abs_epi32
(
v
.
val
));
}
static
inline
bool
any
(
const
ivec4
&
v
)
{
return
!
_mm_test_all_zeros
(
v
,
_mm_set1_epi32
(
0xFFFFFFFF
));
}
// Mathematical operators involving a scalar and a vector.
static
inline
fvec4
operator
+
(
float
v1
,
const
fvec4
&
v2
)
{
return
fvec4
(
v1
)
+
v2
;
}
static
inline
fvec4
operator
-
(
float
v1
,
const
fvec4
&
v2
)
{
return
fvec4
(
v1
)
-
v2
;
}
static
inline
fvec4
operator
*
(
float
v1
,
const
fvec4
&
v2
)
{
return
fvec4
(
v1
)
*
v2
;
}
static
inline
fvec4
operator
/
(
float
v1
,
const
fvec4
&
v2
)
{
return
fvec4
(
v1
)
/
v2
;
}
// Operations for blending fvec4s based on an ivec4.
static
inline
fvec4
blend
(
const
fvec4
&
v1
,
const
fvec4
&
v2
,
const
ivec4
&
mask
)
{
return
fvec4
(
_mm_blendv_ps
(
v1
.
val
,
v2
.
val
,
_mm_castsi128_ps
(
mask
.
val
)));
}
#endif
/*OPENMM_VECTORIZE_SSE_H_*/
openmmapi/src/OSRngSeed.cpp
View file @
d89cd171
...
...
@@ -29,7 +29,6 @@
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include <stdexcept>
#if defined(_WIN32) || defined(__CYGWIN__)
#include <windows.h>
static
HCRYPTPROV
hCryptProv
=
0
;
...
...
@@ -38,28 +37,31 @@ static HCRYPTPROV hCryptProv = 0;
#include <fcntl.h>
#include <unistd.h>
#endif
#include "openmm/OpenMMException.h"
#include "openmm/internal/OSRngSeed.h"
using
OpenMM
::
OpenMMException
;
int
osrngseed
(
void
)
{
int
value
;
#if defined(_WIN32) || defined(__CYGWIN__)
if
(
!::
CryptAcquireContextW
(
&
hCryptProv
,
0
,
0
,
PROV_RSA_FULL
,
CRYPT_VERIFYCONTEXT
|
CRYPT_SILENT
))
{
throw
std
::
runtime_error
(
"Failed to initialize Windows random API (CryptoGen)"
);
throw
OpenMMException
(
"Failed to initialize Windows random API (CryptoGen)"
);
}
if
(
!
CryptGenRandom
(
hCryptProv
,
sizeof
(
int
),
(
BYTE
*
)
&
value
))
{
::
CryptReleaseContext
(
hCryptProv
,
0
);
throw
std
::
runtime_error
(
"Failed to get random numbers"
);
throw
OpenMMException
(
"Failed to get random numbers"
);
}
if
(
!::
CryptReleaseContext
(
hCryptProv
,
0
))
{
throw
std
::
runtime_error
(
"Failed to release Windows random API context"
);
throw
OpenMMException
(
"Failed to release Windows random API context"
);
}
#else
int
m_fd
=
open
(
"/dev/urandom"
,
O_RDONLY
);
if
(
m_fd
==
-
1
)
{
throw
std
::
runtime_error
(
"Failed to open /dev/urandom"
);
throw
OpenMMException
(
"Failed to open /dev/urandom"
);
}
if
(
read
(
m_fd
,
&
value
,
sizeof
(
int
))
!=
sizeof
(
int
))
{
throw
std
::
runtime_error
(
"Failed to read bytes from /dev/urandom"
);
throw
OpenMMException
(
"Failed to read bytes from /dev/urandom"
);
}
close
(
m_fd
);
#endif
...
...
platforms/cpu/sharedTarget/CMakeLists.txt
View file @
d89cd171
FOREACH
(
file
${
SOURCE_FILES
}
)
IF
(
file MATCHES
".*Vec8.*"
)
IF
(
MSVC
)
IF
(
MSVC
)
SET_SOURCE_FILES_PROPERTIES
(
${
file
}
PROPERTIES COMPILE_FLAGS
"
${
EXTRA_COMPILE_FLAGS
}
/arch:AVX /D__AVX__"
)
ELSE
(
MSVC
)
SET_SOURCE_FILES_PROPERTIES
(
${
file
}
PROPERTIES COMPILE_FLAGS
"
${
EXTRA_COMPILE_FLAGS
}
-msse4.1 -mavx"
)
ENDIF
(
MSVC
)
ELSE
(
MSVC
)
IF
(
NOT ANDROID
)
SET_SOURCE_FILES_PROPERTIES
(
${
file
}
PROPERTIES COMPILE_FLAGS
"
${
EXTRA_COMPILE_FLAGS
}
-msse4.1 -mavx"
)
ENDIF
(
NOT ANDROID
)
ENDIF
(
MSVC
)
ELSE
(
file MATCHES
".*Vec8.*"
)
IF
(
NOT MSVC
)
SET_SOURCE_FILES_PROPERTIES
(
${
file
}
PROPERTIES COMPILE_FLAGS
"
${
EXTRA_COMPILE_FLAGS
}
-msse4.1"
)
ENDIF
(
NOT MSVC
)
IF
(
NOT MSVC
)
IF
(
NOT ANDROID
)
SET_SOURCE_FILES_PROPERTIES
(
${
file
}
PROPERTIES COMPILE_FLAGS
"
${
EXTRA_COMPILE_FLAGS
}
-msse4.1"
)
ENDIF
(
NOT ANDROID
)
ENDIF
(
NOT MSVC
)
ENDIF
(
file MATCHES
".*Vec8.*"
)
ENDFOREACH
(
file
)
ADD_LIBRARY
(
${
SHARED_TARGET
}
SHARED
${
SOURCE_FILES
}
${
SOURCE_INCLUDE_FILES
}
${
API_ABS_INCLUDE_FILES
}
)
...
...
platforms/cpu/src/CpuNeighborList.cpp
View file @
d89cd171
...
...
@@ -37,7 +37,6 @@
#include <set>
#include <map>
#include <cmath>
#include <smmintrin.h>
using
namespace
std
;
...
...
platforms/cpu/src/CpuNonbondedForceVec4.cpp
View file @
d89cd171
...
...
@@ -103,7 +103,7 @@ void CpuNonbondedForceVec4::calculateBlockIxn(int blockIndex, float* forces, dou
dEdR
=
epsSig6
*
(
12.0
f
*
sig6
-
6.0
f
);
energy
=
epsSig6
*
(
sig6
-
1.0
f
);
if
(
useSwitch
)
{
fvec4
t
=
(
r
>
switchingDistance
)
&
(
(
r
-
switchingDistance
)
*
invSwitchingInterval
);
fvec4
t
=
blend
(
0.0
f
,
(
r
-
switchingDistance
)
*
invSwitchingInterval
,
r
>
switchingDistance
);
fvec4
switchValue
=
1
+
t
*
t
*
t
*
(
-
10.0
f
+
t
*
(
15.0
f
-
t
*
6.0
f
));
fvec4
switchDeriv
=
t
*
t
*
(
-
30.0
f
+
t
*
(
60.0
f
-
t
*
30.0
f
))
*
invSwitchingInterval
;
dEdR
=
switchValue
*
dEdR
-
energy
*
switchDeriv
*
r
;
...
...
@@ -214,7 +214,7 @@ void CpuNonbondedForceVec4::calculateBlockEwaldIxn(int blockIndex, float* forces
dEdR
=
epsSig6
*
(
12.0
f
*
sig6
-
6.0
f
);
energy
=
epsSig6
*
(
sig6
-
1.0
f
);
if
(
useSwitch
)
{
fvec4
t
=
(
r
>
switchingDistance
)
&
(
(
r
-
switchingDistance
)
*
invSwitchingInterval
);
fvec4
t
=
blend
(
0.0
f
,
(
r
-
switchingDistance
)
*
invSwitchingInterval
,
r
>
switchingDistance
);
fvec4
switchValue
=
1
+
t
*
t
*
t
*
(
-
10.0
f
+
t
*
(
15.0
f
-
t
*
6.0
f
));
fvec4
switchDeriv
=
t
*
t
*
(
-
30.0
f
+
t
*
(
60.0
f
-
t
*
30.0
f
))
*
invSwitchingInterval
;
dEdR
=
switchValue
*
dEdR
-
energy
*
switchDeriv
*
r
;
...
...
platforms/cpu/src/CpuPlatform.cpp
View file @
d89cd171
...
...
@@ -36,6 +36,7 @@
#include "ReferenceConstraints.h"
#include "openmm/internal/hardware.h"
#include <sstream>
#include <stdlib.h>
using
namespace
OpenMM
;
using
namespace
std
;
...
...
@@ -93,15 +94,20 @@ bool CpuPlatform::supportsDoublePrecision() const {
}
bool
CpuPlatform
::
isProcessorSupported
()
{
// Make sure the CPU supports SSE 4.1.
int
cpuInfo
[
4
];
cpuid
(
cpuInfo
,
0
);
if
(
cpuInfo
[
0
]
>=
1
)
{
cpuid
(
cpuInfo
,
1
);
return
((
cpuInfo
[
2
]
&
((
int
)
1
<<
19
))
!=
0
);
}
return
false
;
// Make sure the CPU supports SSE 4.1 or NEON.
#ifdef __ANDROID__
uint64_t
features
=
android_getCpuFeatures
();
return
(
features
&
ANDROID_CPU_ARM_FEATURE_NEON
)
!=
0
;
#else
int
cpuInfo
[
4
];
cpuid
(
cpuInfo
,
0
);
if
(
cpuInfo
[
0
]
>=
1
)
{
cpuid
(
cpuInfo
,
1
);
return
((
cpuInfo
[
2
]
&
((
int
)
1
<<
19
))
!=
0
);
}
return
false
;
#endif
}
void
CpuPlatform
::
contextCreated
(
ContextImpl
&
context
,
const
map
<
string
,
string
>&
properties
)
const
{
...
...
platforms/opencl/include/OpenCLNonbondedUtilities.h
View file @
d89cd171
...
...
@@ -284,7 +284,8 @@ private:
std
::
map
<
std
::
string
,
std
::
string
>
kernelDefines
;
double
cutoff
;
bool
useCutoff
,
usePeriodic
,
deviceIsCpu
,
anyExclusions
,
usePadding
;
int
numForceBuffers
,
startTileIndex
,
numTiles
,
startBlockIndex
,
numBlocks
,
numForceThreadBlocks
,
forceThreadBlockSize
,
nonbondedForceGroup
;
int
numForceBuffers
,
startTileIndex
,
numTiles
,
startBlockIndex
,
numBlocks
,
numForceThreadBlocks
;
int
forceThreadBlockSize
,
interactingBlocksThreadBlockSize
,
nonbondedForceGroup
;
};
/**
...
...
platforms/opencl/src/OpenCLFFT3D.cpp
View file @
d89cd171
...
...
@@ -74,177 +74,188 @@ int OpenCLFFT3D::findLegalDimension(int minimum) {
}
cl
::
Kernel
OpenCLFFT3D
::
createKernel
(
int
xsize
,
int
ysize
,
int
zsize
,
int
&
threads
)
{
bool
loopRequired
=
(
context
.
getDevice
().
getInfo
<
CL_DEVICE_TYPE
>
()
==
CL_DEVICE_TYPE_CPU
);
stringstream
source
;
int
blocksPerGroup
=
(
loopRequired
?
1
:
max
(
1
,
256
/
zsize
));
int
stage
=
0
;
int
L
=
zsize
;
int
m
=
1
;
int
maxThreads
=
std
::
min
(
256
,
(
int
)
context
.
getDevice
().
getInfo
<
CL_DEVICE_MAX_WORK_GROUP_SIZE
>
());
bool
isCPU
=
context
.
getDevice
().
getInfo
<
CL_DEVICE_TYPE
>
()
==
CL_DEVICE_TYPE_CPU
;
while
(
true
)
{
bool
loopRequired
=
(
zsize
>
maxThreads
||
isCPU
);
stringstream
source
;
int
blocksPerGroup
=
(
loopRequired
?
1
:
max
(
1
,
maxThreads
/
zsize
));
int
stage
=
0
;
int
L
=
zsize
;
int
m
=
1
;
// Factor zsize, generating an appropriate block of code for each factor.
while
(
L
>
1
)
{
int
input
=
stage
%
2
;
int
output
=
1
-
input
;
int
radix
;
if
(
L
%
7
==
0
)
radix
=
7
;
else
if
(
L
%
5
==
0
)
radix
=
5
;
else
if
(
L
%
4
==
0
)
radix
=
4
;
else
if
(
L
%
3
==
0
)
radix
=
3
;
else
if
(
L
%
2
==
0
)
radix
=
2
;
else
throw
OpenMMException
(
"Illegal size for FFT: "
+
context
.
intToString
(
zsize
));
source
<<
"{
\n
"
;
L
=
L
/
radix
;
source
<<
"// Pass "
<<
(
stage
+
1
)
<<
" (radix "
<<
radix
<<
")
\n
"
;
if
(
loopRequired
)
{
source
<<
"for (int i = get_local_id(0); i < "
<<
(
L
*
m
)
<<
"; i += get_local_size(0)) {
\n
"
;
source
<<
"int base = i;
\n
"
;
}
else
{
source
<<
"if (get_local_id(0) < "
<<
(
blocksPerGroup
*
L
*
m
)
<<
") {
\n
"
;
source
<<
"int block = get_local_id(0)/"
<<
(
L
*
m
)
<<
";
\n
"
;
source
<<
"int i = get_local_id(0)-block*"
<<
(
L
*
m
)
<<
";
\n
"
;
source
<<
"int base = i+block*"
<<
zsize
<<
";
\n
"
;
}
source
<<
"int j = i/"
<<
m
<<
";
\n
"
;
if
(
radix
==
7
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c2 = data"
<<
input
<<
"[base+"
<<
(
2
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c3 = data"
<<
input
<<
"[base+"
<<
(
3
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c4 = data"
<<
input
<<
"[base+"
<<
(
4
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c5 = data"
<<
input
<<
"[base+"
<<
(
5
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c6 = data"
<<
input
<<
"[base+"
<<
(
6
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 d0 = c1+c6;
\n
"
;
source
<<
"real2 d1 = c1-c6;
\n
"
;
source
<<
"real2 d2 = c2+c5;
\n
"
;
source
<<
"real2 d3 = c2-c5;
\n
"
;
source
<<
"real2 d4 = c4+c3;
\n
"
;
source
<<
"real2 d5 = c4-c3;
\n
"
;
source
<<
"real2 d6 = d2+d0;
\n
"
;
source
<<
"real2 d7 = d5+d3;
\n
"
;
source
<<
"real2 b0 = c0+d6+d4;
\n
"
;
source
<<
"real2 b1 = "
<<
context
.
doubleToString
((
cos
(
2
*
M_PI
/
7
)
+
cos
(
4
*
M_PI
/
7
)
+
cos
(
6
*
M_PI
/
7
))
/
3
-
1
)
<<
"*(d6+d4);
\n
"
;
source
<<
"real2 b2 = "
<<
context
.
doubleToString
((
2
*
cos
(
2
*
M_PI
/
7
)
-
cos
(
4
*
M_PI
/
7
)
-
cos
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d0-d4);
\n
"
;
source
<<
"real2 b3 = "
<<
context
.
doubleToString
((
cos
(
2
*
M_PI
/
7
)
-
2
*
cos
(
4
*
M_PI
/
7
)
+
cos
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d4-d2);
\n
"
;
source
<<
"real2 b4 = "
<<
context
.
doubleToString
((
cos
(
2
*
M_PI
/
7
)
+
cos
(
4
*
M_PI
/
7
)
-
2
*
cos
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d2-d0);
\n
"
;
source
<<
"real2 b5 = -sign*"
<<
context
.
doubleToString
((
sin
(
2
*
M_PI
/
7
)
+
sin
(
4
*
M_PI
/
7
)
-
sin
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d7+d1);
\n
"
;
source
<<
"real2 b6 = -sign*"
<<
context
.
doubleToString
((
2
*
sin
(
2
*
M_PI
/
7
)
-
sin
(
4
*
M_PI
/
7
)
+
sin
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d1-d5);
\n
"
;
source
<<
"real2 b7 = -sign*"
<<
context
.
doubleToString
((
sin
(
2
*
M_PI
/
7
)
-
2
*
sin
(
4
*
M_PI
/
7
)
-
sin
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d5-d3);
\n
"
;
source
<<
"real2 b8 = -sign*"
<<
context
.
doubleToString
((
sin
(
2
*
M_PI
/
7
)
+
sin
(
4
*
M_PI
/
7
)
+
2
*
sin
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d3-d1);
\n
"
;
source
<<
"real2 t0 = b0+b1;
\n
"
;
source
<<
"real2 t1 = b2+b3;
\n
"
;
source
<<
"real2 t2 = b4-b3;
\n
"
;
source
<<
"real2 t3 = -b2-b4;
\n
"
;
source
<<
"real2 t4 = b6+b7;
\n
"
;
source
<<
"real2 t5 = b8-b7;
\n
"
;
source
<<
"real2 t6 = -b8-b6;
\n
"
;
source
<<
"real2 t7 = t0+t1;
\n
"
;
source
<<
"real2 t8 = t0+t2;
\n
"
;
source
<<
"real2 t9 = t0+t3;
\n
"
;
source
<<
"real2 t10 = (real2) (t4.y+b5.y, -(t4.x+b5.x));
\n
"
;
source
<<
"real2 t11 = (real2) (t5.y+b5.y, -(t5.x+b5.x));
\n
"
;
source
<<
"real2 t12 = (real2) (t6.y+b5.y, -(t6.x+b5.x));
\n
"
;
source
<<
"data"
<<
output
<<
"[base+6*j*"
<<
m
<<
"] = b0;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
7
*
L
)
<<
"], t7-t10);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+2)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
2
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t9-t12);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+3)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
3
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t8+t11);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+4)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
4
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t8-t11);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+5)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
5
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t9+t12);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+6)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
6
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t7+t10);
\n
"
;
}
else
if
(
radix
==
5
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c2 = data"
<<
input
<<
"[base+"
<<
(
2
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c3 = data"
<<
input
<<
"[base+"
<<
(
3
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c4 = data"
<<
input
<<
"[base+"
<<
(
4
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 d0 = c1+c4;
\n
"
;
source
<<
"real2 d1 = c2+c3;
\n
"
;
source
<<
"real2 d2 = "
<<
context
.
doubleToString
(
sin
(
0.4
*
M_PI
))
<<
"*(c1-c4);
\n
"
;
source
<<
"real2 d3 = "
<<
context
.
doubleToString
(
sin
(
0.4
*
M_PI
))
<<
"*(c2-c3);
\n
"
;
source
<<
"real2 d4 = d0+d1;
\n
"
;
source
<<
"real2 d5 = "
<<
context
.
doubleToString
(
0.25
*
sqrt
(
5.0
))
<<
"*(d0-d1);
\n
"
;
source
<<
"real2 d6 = c0-0.25f*d4;
\n
"
;
source
<<
"real2 d7 = d6+d5;
\n
"
;
source
<<
"real2 d8 = d6-d5;
\n
"
;
string
coeff
=
context
.
doubleToString
(
sin
(
0.2
*
M_PI
)
/
sin
(
0.4
*
M_PI
));
source
<<
"real2 d9 = sign*(real2) (d2.y+"
<<
coeff
<<
"*d3.y, -d2.x-"
<<
coeff
<<
"*d3.x);
\n
"
;
source
<<
"real2 d10 = sign*(real2) ("
<<
coeff
<<
"*d2.y-d3.y, d3.x-"
<<
coeff
<<
"*d2.x);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+4*j*"
<<
m
<<
"] = c0+d4;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(4*j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
5
*
L
)
<<
"], d7+d9);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(4*j+2)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
2
*
zsize
)
<<
"/"
<<
(
5
*
L
)
<<
"], d8+d10);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(4*j+3)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
3
*
zsize
)
<<
"/"
<<
(
5
*
L
)
<<
"], d8-d10);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(4*j+4)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
4
*
zsize
)
<<
"/"
<<
(
5
*
L
)
<<
"], d7-d9);
\n
"
;
}
else
if
(
radix
==
4
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c2 = data"
<<
input
<<
"[base+"
<<
(
2
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c3 = data"
<<
input
<<
"[base+"
<<
(
3
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 d0 = c0+c2;
\n
"
;
source
<<
"real2 d1 = c0-c2;
\n
"
;
source
<<
"real2 d2 = c1+c3;
\n
"
;
source
<<
"real2 d3 = sign*(real2) (c1.y-c3.y, c3.x-c1.x);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+3*j*"
<<
m
<<
"] = d0+d2;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(3*j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
4
*
L
)
<<
"], d1+d3);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(3*j+2)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
2
*
zsize
)
<<
"/"
<<
(
4
*
L
)
<<
"], d0-d2);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(3*j+3)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
3
*
zsize
)
<<
"/"
<<
(
4
*
L
)
<<
"], d1-d3);
\n
"
;
}
else
if
(
radix
==
3
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c2 = data"
<<
input
<<
"[base+"
<<
(
2
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 d0 = c1+c2;
\n
"
;
source
<<
"real2 d1 = c0-0.5f*d0;
\n
"
;
source
<<
"real2 d2 = sign*"
<<
context
.
doubleToString
(
sin
(
M_PI
/
3.0
))
<<
"*(real2) (c1.y-c2.y, c2.x-c1.x);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+2*j*"
<<
m
<<
"] = c0+d0;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(2*j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
3
*
L
)
<<
"], d1+d2);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(2*j+2)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
2
*
zsize
)
<<
"/"
<<
(
3
*
L
)
<<
"], d1-d2);
\n
"
;
}
else
if
(
radix
==
2
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"data"
<<
output
<<
"[base+j*"
<<
m
<<
"] = c0+c1;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
2
*
L
)
<<
"], c0-c1);
\n
"
;
}
source
<<
"}
\n
"
;
m
=
m
*
radix
;
source
<<
"barrier(CLK_LOCAL_MEM_FENCE);
\n
"
;
source
<<
"}
\n
"
;
++
stage
;
}
//
Factor zsize, generating an appropriate block of code for each factor
.
//
Create the kernel
.
while
(
L
>
1
)
{
int
input
=
stage
%
2
;
int
output
=
1
-
input
;
int
radix
;
if
(
L
%
7
==
0
)
radix
=
7
;
else
if
(
L
%
5
==
0
)
radix
=
5
;
else
if
(
L
%
4
==
0
)
radix
=
4
;
else
if
(
L
%
3
==
0
)
radix
=
3
;
else
if
(
L
%
2
==
0
)
radix
=
2
;
else
throw
OpenMMException
(
"Illegal size for FFT: "
+
context
.
intToString
(
zsize
));
source
<<
"{
\n
"
;
L
=
L
/
radix
;
source
<<
"// Pass "
<<
(
stage
+
1
)
<<
" (radix "
<<
radix
<<
")
\n
"
;
if
(
loopRequired
)
{
source
<<
"for (int
i
= get_local_id(0);
i
<
"
<<
(
L
*
m
)
<<
"
;
i
+= get_local_size(0))
{
\n
"
;
source
<<
"
int base = i
;
\n
"
;
source
<<
"for (int
z
= get_local_id(0);
z
<
ZSIZE
;
z
+= get_local_size(0))
\n
"
;
source
<<
"
out[y*(ZSIZE*XSIZE)+z*XSIZE+x] = data"
<<
(
stage
%
2
)
<<
"[z]
;
\n
"
;
}
else
{
source
<<
"if (get_local_id(0) < "
<<
(
blocksPerGroup
*
L
*
m
)
<<
") {
\n
"
;
source
<<
"int block = get_local_id(0)/"
<<
(
L
*
m
)
<<
";
\n
"
;
source
<<
"int i = get_local_id(0)-block*"
<<
(
L
*
m
)
<<
";
\n
"
;
source
<<
"int base = i+block*"
<<
zsize
<<
";
\n
"
;
}
source
<<
"int j = i/"
<<
m
<<
";
\n
"
;
if
(
radix
==
7
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c2 = data"
<<
input
<<
"[base+"
<<
(
2
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c3 = data"
<<
input
<<
"[base+"
<<
(
3
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c4 = data"
<<
input
<<
"[base+"
<<
(
4
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c5 = data"
<<
input
<<
"[base+"
<<
(
5
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c6 = data"
<<
input
<<
"[base+"
<<
(
6
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 d0 = c1+c6;
\n
"
;
source
<<
"real2 d1 = c1-c6;
\n
"
;
source
<<
"real2 d2 = c2+c5;
\n
"
;
source
<<
"real2 d3 = c2-c5;
\n
"
;
source
<<
"real2 d4 = c4+c3;
\n
"
;
source
<<
"real2 d5 = c4-c3;
\n
"
;
source
<<
"real2 d6 = d2+d0;
\n
"
;
source
<<
"real2 d7 = d5+d3;
\n
"
;
source
<<
"real2 b0 = c0+d6+d4;
\n
"
;
source
<<
"real2 b1 = "
<<
context
.
doubleToString
((
cos
(
2
*
M_PI
/
7
)
+
cos
(
4
*
M_PI
/
7
)
+
cos
(
6
*
M_PI
/
7
))
/
3
-
1
)
<<
"*(d6+d4);
\n
"
;
source
<<
"real2 b2 = "
<<
context
.
doubleToString
((
2
*
cos
(
2
*
M_PI
/
7
)
-
cos
(
4
*
M_PI
/
7
)
-
cos
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d0-d4);
\n
"
;
source
<<
"real2 b3 = "
<<
context
.
doubleToString
((
cos
(
2
*
M_PI
/
7
)
-
2
*
cos
(
4
*
M_PI
/
7
)
+
cos
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d4-d2);
\n
"
;
source
<<
"real2 b4 = "
<<
context
.
doubleToString
((
cos
(
2
*
M_PI
/
7
)
+
cos
(
4
*
M_PI
/
7
)
-
2
*
cos
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d2-d0);
\n
"
;
source
<<
"real2 b5 = -sign*"
<<
context
.
doubleToString
((
sin
(
2
*
M_PI
/
7
)
+
sin
(
4
*
M_PI
/
7
)
-
sin
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d7+d1);
\n
"
;
source
<<
"real2 b6 = -sign*"
<<
context
.
doubleToString
((
2
*
sin
(
2
*
M_PI
/
7
)
-
sin
(
4
*
M_PI
/
7
)
+
sin
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d1-d5);
\n
"
;
source
<<
"real2 b7 = -sign*"
<<
context
.
doubleToString
((
sin
(
2
*
M_PI
/
7
)
-
2
*
sin
(
4
*
M_PI
/
7
)
-
sin
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d5-d3);
\n
"
;
source
<<
"real2 b8 = -sign*"
<<
context
.
doubleToString
((
sin
(
2
*
M_PI
/
7
)
+
sin
(
4
*
M_PI
/
7
)
+
2
*
sin
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d3-d1);
\n
"
;
source
<<
"real2 t0 = b0+b1;
\n
"
;
source
<<
"real2 t1 = b2+b3;
\n
"
;
source
<<
"real2 t2 = b4-b3;
\n
"
;
source
<<
"real2 t3 = -b2-b4;
\n
"
;
source
<<
"real2 t4 = b6+b7;
\n
"
;
source
<<
"real2 t5 = b8-b7;
\n
"
;
source
<<
"real2 t6 = -b8-b6;
\n
"
;
source
<<
"real2 t7 = t0+t1;
\n
"
;
source
<<
"real2 t8 = t0+t2;
\n
"
;
source
<<
"real2 t9 = t0+t3;
\n
"
;
source
<<
"real2 t10 = (real2) (t4.y+b5.y, -(t4.x+b5.x));
\n
"
;
source
<<
"real2 t11 = (real2) (t5.y+b5.y, -(t5.x+b5.x));
\n
"
;
source
<<
"real2 t12 = (real2) (t6.y+b5.y, -(t6.x+b5.x));
\n
"
;
source
<<
"data"
<<
output
<<
"[base+6*j*"
<<
m
<<
"] = b0;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
7
*
L
)
<<
"], t7-t10);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+2)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
2
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t9-t12);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+3)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
3
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t8+t11);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+4)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
4
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t8-t11);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+5)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
5
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t9+t12);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+6)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
6
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t7+t10);
\n
"
;
}
else
if
(
radix
==
5
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c2 = data"
<<
input
<<
"[base+"
<<
(
2
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c3 = data"
<<
input
<<
"[base+"
<<
(
3
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c4 = data"
<<
input
<<
"[base+"
<<
(
4
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 d0 = c1+c4;
\n
"
;
source
<<
"real2 d1 = c2+c3;
\n
"
;
source
<<
"real2 d2 = "
<<
context
.
doubleToString
(
sin
(
0.4
*
M_PI
))
<<
"*(c1-c4);
\n
"
;
source
<<
"real2 d3 = "
<<
context
.
doubleToString
(
sin
(
0.4
*
M_PI
))
<<
"*(c2-c3);
\n
"
;
source
<<
"real2 d4 = d0+d1;
\n
"
;
source
<<
"real2 d5 = "
<<
context
.
doubleToString
(
0.25
*
sqrt
(
5.0
))
<<
"*(d0-d1);
\n
"
;
source
<<
"real2 d6 = c0-0.25f*d4;
\n
"
;
source
<<
"real2 d7 = d6+d5;
\n
"
;
source
<<
"real2 d8 = d6-d5;
\n
"
;
string
coeff
=
context
.
doubleToString
(
sin
(
0.2
*
M_PI
)
/
sin
(
0.4
*
M_PI
));
source
<<
"real2 d9 = sign*(real2) (d2.y+"
<<
coeff
<<
"*d3.y, -d2.x-"
<<
coeff
<<
"*d3.x);
\n
"
;
source
<<
"real2 d10 = sign*(real2) ("
<<
coeff
<<
"*d2.y-d3.y, d3.x-"
<<
coeff
<<
"*d2.x);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+4*j*"
<<
m
<<
"] = c0+d4;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(4*j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
5
*
L
)
<<
"], d7+d9);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(4*j+2)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
2
*
zsize
)
<<
"/"
<<
(
5
*
L
)
<<
"], d8+d10);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(4*j+3)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
3
*
zsize
)
<<
"/"
<<
(
5
*
L
)
<<
"], d8-d10);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(4*j+4)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
4
*
zsize
)
<<
"/"
<<
(
5
*
L
)
<<
"], d7-d9);
\n
"
;
source
<<
"if (index < XSIZE*YSIZE)
\n
"
;
source
<<
"out[y*(ZSIZE*XSIZE)+(get_local_id(0)%ZSIZE)*XSIZE+x] = data"
<<
(
stage
%
2
)
<<
"[get_local_id(0)];
\n
"
;
}
else
if
(
radix
==
4
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c2 = data"
<<
input
<<
"[base+"
<<
(
2
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c3 = data"
<<
input
<<
"[base+"
<<
(
3
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 d0 = c0+c2;
\n
"
;
source
<<
"real2 d1 = c0-c2;
\n
"
;
source
<<
"real2 d2 = c1+c3;
\n
"
;
source
<<
"real2 d3 = sign*(real2) (c1.y-c3.y, c3.x-c1.x);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+3*j*"
<<
m
<<
"] = d0+d2;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(3*j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
4
*
L
)
<<
"], d1+d3);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(3*j+2)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
2
*
zsize
)
<<
"/"
<<
(
4
*
L
)
<<
"], d0-d2);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(3*j+3)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
3
*
zsize
)
<<
"/"
<<
(
4
*
L
)
<<
"], d1-d3);
\n
"
;
map
<
string
,
string
>
replacements
;
replacements
[
"XSIZE"
]
=
context
.
intToString
(
xsize
);
replacements
[
"YSIZE"
]
=
context
.
intToString
(
ysize
);
replacements
[
"ZSIZE"
]
=
context
.
intToString
(
zsize
);
replacements
[
"BLOCKS_PER_GROUP"
]
=
context
.
intToString
(
blocksPerGroup
);
replacements
[
"M_PI"
]
=
context
.
doubleToString
(
M_PI
);
replacements
[
"COMPUTE_FFT"
]
=
source
.
str
();
replacements
[
"LOOP_REQUIRED"
]
=
(
loopRequired
?
"1"
:
"0"
);
cl
::
Program
program
=
context
.
createProgram
(
context
.
replaceStrings
(
OpenCLKernelSources
::
fft
,
replacements
));
cl
::
Kernel
kernel
(
program
,
"execFFT"
);
threads
=
(
isCPU
?
1
:
blocksPerGroup
*
zsize
);
int
kernelMaxThreads
=
kernel
.
getWorkGroupInfo
<
CL_KERNEL_WORK_GROUP_SIZE
>
(
context
.
getDevice
());
if
(
threads
>
kernelMaxThreads
)
{
// The device can't handle this block size, so reduce it.
maxThreads
=
kernelMaxThreads
;
continue
;
}
else
if
(
radix
==
3
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c2 = data"
<<
input
<<
"[base+"
<<
(
2
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 d0 = c1+c2;
\n
"
;
source
<<
"real2 d1 = c0-0.5f*d0;
\n
"
;
source
<<
"real2 d2 = sign*"
<<
context
.
doubleToString
(
sin
(
M_PI
/
3.0
))
<<
"*(real2) (c1.y-c2.y, c2.x-c1.x);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+2*j*"
<<
m
<<
"] = c0+d0;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(2*j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
3
*
L
)
<<
"], d1+d2);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(2*j+2)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
2
*
zsize
)
<<
"/"
<<
(
3
*
L
)
<<
"], d1-d2);
\n
"
;
}
else
if
(
radix
==
2
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"data"
<<
output
<<
"[base+j*"
<<
m
<<
"] = c0+c1;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
2
*
L
)
<<
"], c0-c1);
\n
"
;
}
source
<<
"}
\n
"
;
m
=
m
*
radix
;
source
<<
"barrier(CLK_LOCAL_MEM_FENCE);
\n
"
;
source
<<
"}
\n
"
;
++
stage
;
}
// Create the kernel.
if
(
loopRequired
)
{
source
<<
"for (int z = get_local_id(0); z < ZSIZE; z += get_local_size(0))
\n
"
;
source
<<
"out[y*(ZSIZE*XSIZE)+z*XSIZE+x] = data"
<<
(
stage
%
2
)
<<
"[z];
\n
"
;
}
else
{
source
<<
"if (index < XSIZE*YSIZE)
\n
"
;
source
<<
"out[y*(ZSIZE*XSIZE)+(get_local_id(0)%ZSIZE)*XSIZE+x] = data"
<<
(
stage
%
2
)
<<
"[get_local_id(0)];
\n
"
;
int
bufferSize
=
blocksPerGroup
*
zsize
*
(
context
.
getUseDoublePrecision
()
?
sizeof
(
mm_double2
)
:
sizeof
(
mm_float2
));
kernel
.
setArg
(
3
,
bufferSize
,
NULL
);
kernel
.
setArg
(
4
,
bufferSize
,
NULL
);
kernel
.
setArg
(
5
,
bufferSize
,
NULL
);
return
kernel
;
}
map
<
string
,
string
>
replacements
;
replacements
[
"XSIZE"
]
=
context
.
intToString
(
xsize
);
replacements
[
"YSIZE"
]
=
context
.
intToString
(
ysize
);
replacements
[
"ZSIZE"
]
=
context
.
intToString
(
zsize
);
replacements
[
"BLOCKS_PER_GROUP"
]
=
context
.
intToString
(
blocksPerGroup
);
replacements
[
"M_PI"
]
=
context
.
doubleToString
(
M_PI
);
replacements
[
"COMPUTE_FFT"
]
=
source
.
str
();
replacements
[
"LOOP_REQUIRED"
]
=
(
loopRequired
?
"1"
:
"0"
);
cl
::
Program
program
=
context
.
createProgram
(
context
.
replaceStrings
(
OpenCLKernelSources
::
fft
,
replacements
));
cl
::
Kernel
kernel
(
program
,
"execFFT"
);
int
bufferSize
=
blocksPerGroup
*
zsize
*
(
context
.
getUseDoublePrecision
()
?
sizeof
(
mm_double2
)
:
sizeof
(
mm_float2
));
kernel
.
setArg
(
3
,
bufferSize
,
NULL
);
kernel
.
setArg
(
4
,
bufferSize
,
NULL
);
kernel
.
setArg
(
5
,
bufferSize
,
NULL
);
threads
=
(
loopRequired
?
1
:
blocksPerGroup
*
zsize
);
return
kernel
;
}
platforms/opencl/src/OpenCLKernels.cpp
View file @
d89cd171
...
...
@@ -4820,7 +4820,7 @@ void OpenCLIntegrateVariableVerletStepKernel::initialize(const System& system, c
kernel1
=
cl
::
Kernel
(
program
,
"integrateVerletPart1"
);
kernel2
=
cl
::
Kernel
(
program
,
"integrateVerletPart2"
);
selectSizeKernel
=
cl
::
Kernel
(
program
,
"selectVerletStepSize"
);
blockSize
=
min
(
min
(
256
,
system
.
getNumParticles
()),
(
int
)
cl
.
getDevice
().
getInfo
<
CL_DEVICE_MAX
_WORK_GROUP_SIZE
>
());
blockSize
=
min
(
min
(
256
,
system
.
getNumParticles
()),
(
int
)
selectSizeKernel
.
getWorkGroupInfo
<
CL_KERNEL
_WORK_GROUP_SIZE
>
(
cl
.
getDevice
()
));
}
double
OpenCLIntegrateVariableVerletStepKernel
::
execute
(
ContextImpl
&
context
,
const
VariableVerletIntegrator
&
integrator
,
double
maxTime
)
{
...
...
@@ -4930,7 +4930,7 @@ void OpenCLIntegrateVariableLangevinStepKernel::initialize(const System& system,
params
=
new
OpenCLArray
(
cl
,
3
,
cl
.
getUseDoublePrecision
()
||
cl
.
getUseMixedPrecision
()
?
sizeof
(
cl_double
)
:
sizeof
(
cl_float
),
"langevinParams"
);
blockSize
=
min
(
256
,
system
.
getNumParticles
());
blockSize
=
max
(
blockSize
,
params
->
getSize
());
blockSize
=
min
(
blockSize
,
(
int
)
cl
.
getDevice
().
getInfo
<
CL_DEVICE_MAX
_WORK_GROUP_SIZE
>
());
blockSize
=
min
(
blockSize
,
(
int
)
selectSizeKernel
.
getWorkGroupInfo
<
CL_KERNEL
_WORK_GROUP_SIZE
>
(
cl
.
getDevice
()
));
}
double
OpenCLIntegrateVariableLangevinStepKernel
::
execute
(
ContextImpl
&
context
,
const
VariableLangevinIntegrator
&
integrator
,
double
maxTime
)
{
...
...
platforms/opencl/src/OpenCLNonbondedUtilities.cpp
View file @
d89cd171
...
...
@@ -317,42 +317,55 @@ void OpenCLNonbondedUtilities::initialize(const System& system) {
for
(
int
i
=
0
;
i
<
(
int
)
exclusionBlocksForBlock
.
size
();
i
++
)
maxExclusions
=
(
maxExclusions
>
exclusionBlocksForBlock
[
i
].
size
()
?
maxExclusions
:
exclusionBlocksForBlock
[
i
].
size
());
defines
[
"MAX_EXCLUSIONS"
]
=
context
.
intToString
(
maxExclusions
);
defines
[
"GROUP_SIZE"
]
=
(
deviceIsCpu
?
"32"
:
"128"
);
defines
[
"BUFFER_GROUPS"
]
=
(
deviceIsCpu
?
"4"
:
"2"
);
string
file
=
(
deviceIsCpu
?
OpenCLKernelSources
::
findInteractingBlocks_cpu
:
OpenCLKernelSources
::
findInteractingBlocks
);
cl
::
Program
interactingBlocksProgram
=
context
.
createProgram
(
file
,
defines
);
findBlockBoundsKernel
=
cl
::
Kernel
(
interactingBlocksProgram
,
"findBlockBounds"
);
findBlockBoundsKernel
.
setArg
<
cl_int
>
(
0
,
context
.
getNumAtoms
());
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
3
,
context
.
getPosq
().
getDeviceBuffer
());
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
4
,
blockCenter
->
getDeviceBuffer
());
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
5
,
blockBoundingBox
->
getDeviceBuffer
());
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
6
,
rebuildNeighborList
->
getDeviceBuffer
());
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
7
,
sortedBlocks
->
getDeviceBuffer
());
sortBoxDataKernel
=
cl
::
Kernel
(
interactingBlocksProgram
,
"sortBoxData"
);
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
0
,
sortedBlocks
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
1
,
blockCenter
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
2
,
blockBoundingBox
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
3
,
sortedBlockCenter
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
4
,
sortedBlockBoundingBox
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
5
,
context
.
getPosq
().
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
6
,
oldPositions
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
7
,
interactionCount
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
8
,
rebuildNeighborList
->
getDeviceBuffer
());
findInteractingBlocksKernel
=
cl
::
Kernel
(
interactingBlocksProgram
,
"findBlocksWithInteractions"
);
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
2
,
interactionCount
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
3
,
interactingTiles
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
4
,
interactingAtoms
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
5
,
context
.
getPosq
().
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl_uint
>
(
6
,
interactingTiles
->
getSize
());
findInteractingBlocksKernel
.
setArg
<
cl_uint
>
(
7
,
startBlockIndex
);
findInteractingBlocksKernel
.
setArg
<
cl_uint
>
(
8
,
numBlocks
);
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
9
,
sortedBlocks
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
10
,
sortedBlockCenter
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
11
,
sortedBlockBoundingBox
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
12
,
exclusionIndices
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
13
,
exclusionRowIndices
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
14
,
oldPositions
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
15
,
rebuildNeighborList
->
getDeviceBuffer
());
int
groupSize
=
(
deviceIsCpu
?
32
:
128
);
while
(
true
)
{
defines
[
"GROUP_SIZE"
]
=
context
.
intToString
(
groupSize
);
cl
::
Program
interactingBlocksProgram
=
context
.
createProgram
(
file
,
defines
);
findBlockBoundsKernel
=
cl
::
Kernel
(
interactingBlocksProgram
,
"findBlockBounds"
);
findBlockBoundsKernel
.
setArg
<
cl_int
>
(
0
,
context
.
getNumAtoms
());
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
3
,
context
.
getPosq
().
getDeviceBuffer
());
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
4
,
blockCenter
->
getDeviceBuffer
());
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
5
,
blockBoundingBox
->
getDeviceBuffer
());
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
6
,
rebuildNeighborList
->
getDeviceBuffer
());
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
7
,
sortedBlocks
->
getDeviceBuffer
());
sortBoxDataKernel
=
cl
::
Kernel
(
interactingBlocksProgram
,
"sortBoxData"
);
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
0
,
sortedBlocks
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
1
,
blockCenter
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
2
,
blockBoundingBox
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
3
,
sortedBlockCenter
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
4
,
sortedBlockBoundingBox
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
5
,
context
.
getPosq
().
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
6
,
oldPositions
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
7
,
interactionCount
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
8
,
rebuildNeighborList
->
getDeviceBuffer
());
findInteractingBlocksKernel
=
cl
::
Kernel
(
interactingBlocksProgram
,
"findBlocksWithInteractions"
);
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
2
,
interactionCount
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
3
,
interactingTiles
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
4
,
interactingAtoms
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
5
,
context
.
getPosq
().
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl_uint
>
(
6
,
interactingTiles
->
getSize
());
findInteractingBlocksKernel
.
setArg
<
cl_uint
>
(
7
,
startBlockIndex
);
findInteractingBlocksKernel
.
setArg
<
cl_uint
>
(
8
,
numBlocks
);
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
9
,
sortedBlocks
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
10
,
sortedBlockCenter
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
11
,
sortedBlockBoundingBox
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
12
,
exclusionIndices
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
13
,
exclusionRowIndices
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
14
,
oldPositions
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
15
,
rebuildNeighborList
->
getDeviceBuffer
());
if
(
findInteractingBlocksKernel
.
getWorkGroupInfo
<
CL_KERNEL_WORK_GROUP_SIZE
>
(
context
.
getDevice
())
<
groupSize
)
{
// The device can't handle this block size, so reduce it.
groupSize
-=
32
;
if
(
groupSize
<
32
)
throw
OpenMMException
(
"Failed to create findInteractingBlocks kernel"
);
continue
;
}
break
;
}
interactingBlocksThreadBlockSize
=
(
deviceIsCpu
?
1
:
groupSize
);
}
}
...
...
@@ -389,7 +402,7 @@ void OpenCLNonbondedUtilities::prepareInteractions() {
context
.
executeKernel
(
sortBoxDataKernel
,
context
.
getNumAtoms
());
setPeriodicBoxSizeArg
(
context
,
findInteractingBlocksKernel
,
0
);
setInvPeriodicBoxSizeArg
(
context
,
findInteractingBlocksKernel
,
1
);
context
.
executeKernel
(
findInteractingBlocksKernel
,
context
.
getNumAtoms
(),
deviceIsCpu
?
1
:
128
);
context
.
executeKernel
(
findInteractingBlocksKernel
,
context
.
getNumAtoms
(),
interactingBlocksThreadBlockSize
);
}
void
OpenCLNonbondedUtilities
::
computeInteractions
()
{
...
...
platforms/opencl/src/OpenCLPlatform.cpp
View file @
d89cd171
...
...
@@ -32,6 +32,7 @@
#include "openmm/Context.h"
#include "openmm/System.h"
#include <algorithm>
#include <cctype>
#include <sstream>
#ifdef __APPLE__
#include "sys/sysctl.h"
...
...
@@ -39,10 +40,7 @@
using
namespace
OpenMM
;
using
std
::
map
;
using
std
::
string
;
using
std
::
stringstream
;
using
std
::
vector
;
using
namespace
std
;
#ifdef OPENMM_OPENCL_BUILDING_STATIC_LIBRARY
extern
"C"
void
registerOpenCLPlatform
()
{
...
...
platforms/opencl/src/OpenCLSort.cpp
View file @
d89cd171
...
...
@@ -56,10 +56,13 @@ OpenCLSort::OpenCLSort(OpenCLContext& context, SortTrait* trait, unsigned int le
unsigned
int
maxGroupSize
=
std
::
min
(
256
,
(
int
)
context
.
getDevice
().
getInfo
<
CL_DEVICE_MAX_WORK_GROUP_SIZE
>
());
int
maxSharedMem
=
context
.
getDevice
().
getInfo
<
CL_DEVICE_LOCAL_MEM_SIZE
>
();
unsigned
int
maxLocalBuffer
=
(
unsigned
int
)
((
maxSharedMem
/
trait
->
getDataSize
())
/
2
);
isShortList
=
(
length
<=
maxLocalBuffer
);
for
(
rangeKernelSize
=
1
;
rangeKernelSize
*
2
<=
maxGroupSize
;
rangeKernelSize
*=
2
)
unsigned
int
maxRangeSize
=
std
::
min
(
maxGroupSize
,
(
unsigned
int
)
computeRangeKernel
.
getWorkGroupInfo
<
CL_KERNEL_WORK_GROUP_SIZE
>
(
context
.
getDevice
()));
unsigned
int
maxPositionsSize
=
std
::
min
(
maxGroupSize
,
(
unsigned
int
)
computeBucketPositionsKernel
.
getWorkGroupInfo
<
CL_KERNEL_WORK_GROUP_SIZE
>
(
context
.
getDevice
()));
unsigned
int
maxShortListSize
=
shortListKernel
.
getWorkGroupInfo
<
CL_KERNEL_WORK_GROUP_SIZE
>
(
context
.
getDevice
());
isShortList
=
(
length
<=
maxLocalBuffer
&&
length
<
maxShortListSize
);
for
(
rangeKernelSize
=
1
;
rangeKernelSize
*
2
<=
maxRangeSize
;
rangeKernelSize
*=
2
)
;
positionsKernelSize
=
rangeKernelSize
;
positionsKernelSize
=
std
::
min
(
rangeKernelSize
,
maxPositionsSize
)
;
sortKernelSize
=
(
isShortList
?
rangeKernelSize
:
rangeKernelSize
/
2
);
if
(
rangeKernelSize
>
length
)
rangeKernelSize
=
length
;
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment