Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Pytorch-Encoding
Commits
ad3ca3e7
"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "b2140a895b6530ffb5aac96a9a4ea5e06e848a83"
Commit
ad3ca3e7
authored
May 15, 2017
by
Hang Zhang
Browse files
encoding
parent
a3c3d942
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
76 additions
and
24 deletions
+76
-24
build.py
build.py
+18
-12
encoding/CMakeLists.txt
encoding/CMakeLists.txt
+1
-4
encoding/__init__.py
encoding/__init__.py
+39
-2
encoding/cmake/FindTorch.cmake
encoding/cmake/FindTorch.cmake
+9
-4
setup.py
setup.py
+1
-1
test/test.py
test/test.py
+8
-1
No files found.
build.py
View file @
ad3ca3e7
...
@@ -14,9 +14,22 @@ import platform
...
@@ -14,9 +14,22 @@ import platform
import
subprocess
import
subprocess
from
torch.utils.ffi
import
create_extension
from
torch.utils.ffi
import
create_extension
lib_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
torch
.
__file__
),
'lib'
)
this_file
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
# build kernel library
# build kernel library
os
.
environ
[
'TORCH_BUILD_DIR'
]
=
lib_path
if
platform
.
system
()
==
'Darwin'
:
os
.
environ
[
'TH_LIBRARIES'
]
=
os
.
path
.
join
(
lib_path
,
'libTH.1.dylib'
)
os
.
environ
[
'THC_LIBRARIES'
]
=
os
.
path
.
join
(
lib_path
,
'libTHC.1.dylib'
)
ENCODING_LIB
=
os
.
path
.
join
(
lib_path
,
'libENCODING.dylib'
)
else
:
os
.
environ
[
'TH_LIBRARIES'
]
=
os
.
path
.
join
(
lib_path
,
'libTH.so.1'
)
os
.
environ
[
'THC_LIBRARIES'
]
=
os
.
path
.
join
(
lib_path
,
'libTHC.so.1'
)
ENCODING_LIB
=
os
.
path
.
join
(
lib_path
,
'libENCODING.so'
)
build_all_cmd
=
[
'bash'
,
'encoding/make.sh'
]
build_all_cmd
=
[
'bash'
,
'encoding/make.sh'
]
if
subprocess
.
call
(
build_all_cmd
)
!=
0
:
if
subprocess
.
call
(
build_all_cmd
,
env
=
dict
(
os
.
environ
)
)
!=
0
:
sys
.
exit
(
1
)
sys
.
exit
(
1
)
sources
=
[
'encoding/src/encoding_lib.cpp'
]
sources
=
[
'encoding/src/encoding_lib.cpp'
]
...
@@ -24,18 +37,11 @@ headers = ['encoding/src/encoding_lib.h']
...
@@ -24,18 +37,11 @@ headers = ['encoding/src/encoding_lib.h']
defines
=
[(
'WITH_CUDA'
,
None
)]
defines
=
[(
'WITH_CUDA'
,
None
)]
with_cuda
=
True
with_cuda
=
True
package_base
=
os
.
path
.
dirname
(
torch
.
__file__
)
include_path
=
[
os
.
path
.
join
(
lib_path
,
'include'
),
this_file
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
os
.
path
.
join
(
os
.
environ
[
'HOME'
],
'pytorch/torch/lib/THC'
),
os
.
path
.
join
(
lib_path
,
'include/ENCODING'
),
include_path
=
[
os
.
path
.
join
(
os
.
environ
[
'HOME'
],
'pytorch/torch/lib/THC'
),
os
.
path
.
join
(
package_base
,
'lib/include/ENCODING'
),
os
.
path
.
join
(
this_file
,
'encoding/src/'
)]
os
.
path
.
join
(
this_file
,
'encoding/src/'
)]
if
platform
.
system
()
==
'Darwin'
:
ENCODING_LIB
=
os
.
path
.
join
(
package_base
,
'lib/libENCODING.dylib'
)
else
:
ENCODING_LIB
=
os
.
path
.
join
(
package_base
,
'lib/libENCODING.so'
)
def
make_relative_rpath
(
path
):
def
make_relative_rpath
(
path
):
if
platform
.
system
()
==
'Darwin'
:
if
platform
.
system
()
==
'Darwin'
:
return
'-Wl,-rpath,'
+
path
return
'-Wl,-rpath,'
+
path
...
@@ -52,7 +58,7 @@ ffi = create_extension(
...
@@ -52,7 +58,7 @@ ffi = create_extension(
with_cuda
=
with_cuda
,
with_cuda
=
with_cuda
,
include_dirs
=
include_path
,
include_dirs
=
include_path
,
extra_link_args
=
[
extra_link_args
=
[
make_relative_rpath
(
os
.
path
.
join
(
package_base
,
'lib'
)
),
make_relative_rpath
(
lib_path
),
ENCODING_LIB
,
ENCODING_LIB
,
],
],
)
)
...
...
encoding/CMakeLists.txt
View file @
ad3ca3e7
...
@@ -13,10 +13,6 @@ CMAKE_POLICY(VERSION 2.8)
...
@@ -13,10 +13,6 @@ CMAKE_POLICY(VERSION 2.8)
INCLUDE
(
${
CMAKE_CURRENT_SOURCE_DIR
}
/cmake/FindTorch.cmake
)
INCLUDE
(
${
CMAKE_CURRENT_SOURCE_DIR
}
/cmake/FindTorch.cmake
)
#IF(NOT Torch_FOUND)
# FIND_PACKAGE(Torch REQUIRED)
#ENDIF()
IF
(
NOT CUDA_FOUND
)
IF
(
NOT CUDA_FOUND
)
FIND_PACKAGE
(
CUDA 6.5 REQUIRED
)
FIND_PACKAGE
(
CUDA 6.5 REQUIRED
)
ENDIF
()
ENDIF
()
...
@@ -54,6 +50,7 @@ SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
...
@@ -54,6 +50,7 @@ SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
FILE
(
GLOB src-cuda kernel/*.cu
)
FILE
(
GLOB src-cuda kernel/*.cu
)
MESSAGE
(
STATUS
"Torch_INSTALL_INCLUDE:"
${
Torch_INSTALL_INCLUDE
}
)
CUDA_INCLUDE_DIRECTORIES
(
CUDA_INCLUDE_DIRECTORIES
(
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernel
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernel
${
Torch_INSTALL_INCLUDE
}
${
Torch_INSTALL_INCLUDE
}
...
...
encoding/__init__.py
View file @
ad3ca3e7
...
@@ -9,7 +9,7 @@
...
@@ -9,7 +9,7 @@
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import
torch
import
torch
from
torch.nn
.modules.module
import
Module
import
torch.nn
as
nn
from
torch.autograd
import
Function
from
torch.autograd
import
Function
from
._ext
import
encoding_lib
from
._ext
import
encoding_lib
...
@@ -32,6 +32,43 @@ class aggregate(Function):
...
@@ -32,6 +32,43 @@ class aggregate(Function):
return
gradA
,
gradR
return
gradA
,
gradR
class
Aggregate
(
Module
):
class
Aggregate
(
nn
.
Module
):
def
forward
(
self
,
A
,
R
):
def
forward
(
self
,
A
,
R
):
return
aggregate
()(
A
,
R
)
return
aggregate
()(
A
,
R
)
class
Encoding
(
nn
.
Module
):
def
__init__
(
self
,
D
,
K
):
super
(
Encoding
,
self
).
__init__
()
# init codewords and smoothing factor
self
.
D
,
self
.
K
=
D
,
K
self
.
codewords
=
nn
.
Parameter
(
torch
.
Tensor
(
K
,
D
),
requires_grad
=
True
)
self
.
scale
=
nn
.
Parameter
(
torch
.
Tensor
(
K
),
requires_grad
=
True
)
self
.
softmax
=
nn
.
Softmax
()
self
.
reset_params
()
def
reset_params
(
self
):
self
.
codewords
.
data
.
uniform_
(
0.0
,
0.02
)
self
.
scale
.
data
.
uniform_
(
0.0
,
0.02
)
def
forward
(
self
,
X
):
# input X is a 4D tensor
assert
(
X
.
dim
()
==
4
,
"Encoding Layer requries 4D featuremaps!"
)
assert
(
X
.
size
(
1
)
==
self
.
D
,
"Encoding Layer incompatible input channels!"
)
B
,
N
,
K
,
D
=
X
.
size
(
0
),
X
.
size
(
2
)
*
X
.
size
(
3
),
self
.
K
,
self
.
D
# reshape input
X
=
X
.
view
(
B
,
D
,
-
1
).
transpose
(
1
,
2
)
# calculate residuals
R
=
X
.
contiguous
().
view
(
B
,
N
,
1
,
D
).
expand
(
B
,
N
,
K
,
D
)
-
self
.
codewords
.
view
(
1
,
1
,
K
,
D
).
expand
(
B
,
N
,
K
,
D
)
# assignment weights
A
=
R
A
=
A
.
pow
(
2
).
sum
(
3
).
view
(
B
,
N
,
K
)
A
=
A
*
self
.
scale
.
view
(
1
,
1
,
K
).
expand_as
(
A
)
A
=
self
.
softmax
(
A
.
view
(
B
*
N
,
K
)).
view
(
B
,
N
,
K
)
# aggregate
E
=
aggregate
()(
A
,
R
)
return
E
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
'('
\
+
'N x '
+
str
(
self
.
D
)
+
'=>'
+
str
(
self
.
K
)
+
'x'
+
str
(
self
.
D
)
+
')'
encoding/cmake/FindTorch.cmake
View file @
ad3ca3e7
...
@@ -8,7 +8,8 @@
...
@@ -8,7 +8,8 @@
## LICENSE file in the root directory of this source tree
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Custom CMake rules for PyTorch (a hacky way)
# No longer using manual way to find the library.
if
(
FALSE
)
FILE
(
GLOB TORCH_LIB_HINTS
FILE
(
GLOB TORCH_LIB_HINTS
"/anaconda/lib/python3.6/site-packages/torch/lib"
"/anaconda/lib/python3.6/site-packages/torch/lib"
"/anaconda2/lib/python3.6/site-packages/torch/lib"
"/anaconda2/lib/python3.6/site-packages/torch/lib"
...
@@ -19,7 +20,12 @@ FIND_PATH(TORCH_BUILD_DIR
...
@@ -19,7 +20,12 @@ FIND_PATH(TORCH_BUILD_DIR
NAMES
"THNN.h"
NAMES
"THNN.h"
PATHS
"
${
TORCH_LIB_HINTS
}
"
PATHS
"
${
TORCH_LIB_HINTS
}
"
)
)
FIND_LIBRARY
(
THC_LIBRARIES NAMES THC THC.1 PATHS
${
TORCH_BUILD_DIR
}
PATH_SUFFIXES lib
)
FIND_LIBRARY
(
TH_LIBRARIES NAMES TH TH.1 PATHS
${
TORCH_BUILD_DIR
}
PATH_SUFFIXES lib
)
endif
()
# Set the envrionment variable via python
SET
(
TORCH_BUILD_DIR
"$ENV{TORCH_BUILD_DIR}"
)
MESSAGE
(
STATUS
"TORCH_BUILD_DIR: "
${
TORCH_BUILD_DIR
}
)
MESSAGE
(
STATUS
"TORCH_BUILD_DIR: "
${
TORCH_BUILD_DIR
}
)
# Find the include files
# Find the include files
...
@@ -30,6 +36,5 @@ SET(TORCH_THC_UTILS_INCLUDE_DIR "$ENV{HOME}/pytorch/torch/lib/THC")
...
@@ -30,6 +36,5 @@ SET(TORCH_THC_UTILS_INCLUDE_DIR "$ENV{HOME}/pytorch/torch/lib/THC")
SET
(
Torch_INSTALL_INCLUDE
"
${
TORCH_BUILD_DIR
}
/include"
${
TORCH_TH_INCLUDE_DIR
}
${
TORCH_THC_INCLUDE_DIR
}
${
TORCH_THC_UTILS_INCLUDE_DIR
}
)
SET
(
Torch_INSTALL_INCLUDE
"
${
TORCH_BUILD_DIR
}
/include"
${
TORCH_TH_INCLUDE_DIR
}
${
TORCH_THC_INCLUDE_DIR
}
${
TORCH_THC_UTILS_INCLUDE_DIR
}
)
# Find the libs. We need to find libraries one by one.
# Find the libs. We need to find libraries one by one.
FIND_LIBRARY
(
THC_LIBRARIES NAMES THC THC.1 PATHS
${
TORCH_BUILD_DIR
}
PATH_SUFFIXES lib
)
SET
(
TH_LIBRARIES
"$ENV{TH_LIBRARIES}"
)
FIND_LIBRARY
(
TH_LIBRARIES NAMES TH TH.1 PATHS
${
TORCH_BUILD_DIR
}
PATH_SUFFIXES lib
)
SET
(
THC_LIBRARIES
"$ENV{THC_LIBRARIES}"
)
setup.py
View file @
ad3ca3e7
...
@@ -35,7 +35,7 @@ setup(
...
@@ -35,7 +35,7 @@ setup(
setup_requires
=
[
"cffi>=1.0.0"
],
setup_requires
=
[
"cffi>=1.0.0"
],
# Exclude the build files.
# Exclude the build files.
packages
=
find_packages
(
exclude
=
[
"build"
]),
packages
=
find_packages
(
exclude
=
[
"build"
]),
extra_compile_args
=
extra_compile_args
,
extra_compile_args
=
extra_compile_args
,
# Package where to put the extensions. Has to be a prefix of build.py.
# Package where to put the extensions. Has to be a prefix of build.py.
ext_package
=
""
,
ext_package
=
""
,
# Extensions to compile.
# Extensions to compile.
...
...
test/test.py
View file @
ad3ca3e7
...
@@ -12,15 +12,22 @@ import torch
...
@@ -12,15 +12,22 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
torch.autograd
import
Variable
from
torch.autograd
import
Variable
from
encoding
import
Aggregate
from
encoding
import
Aggregate
from
encoding
import
Encoding
from
torch.autograd
import
gradcheck
from
torch.autograd
import
gradcheck
# declare dims and variables
# declare dims and variables
B
,
N
,
K
,
D
=
1
,
2
,
3
,
4
B
,
N
,
K
,
D
=
1
,
2
,
3
,
4
A
=
Variable
(
torch
.
randn
(
B
,
N
,
K
).
cuda
(),
requires_grad
=
True
)
A
=
Variable
(
torch
.
randn
(
B
,
N
,
K
).
cuda
(),
requires_grad
=
True
)
R
=
Variable
(
torch
.
randn
(
B
,
N
,
K
,
D
).
cuda
(),
requires_grad
=
True
)
R
=
Variable
(
torch
.
randn
(
B
,
N
,
K
,
D
).
cuda
(),
requires_grad
=
True
)
X
=
Variable
(
torch
.
randn
(
B
,
D
,
3
,
3
).
cuda
(),
requires_grad
=
True
)
# check Aggregate operation
# check Aggregate operation
test
=
gradcheck
(
Aggregate
(),(
A
,
R
),
eps
=
1e-4
,
atol
=
1e-3
)
test
=
gradcheck
(
Aggregate
(),(
A
,
R
),
eps
=
1e-4
,
atol
=
1e-3
)
print
(
'Gradcheck of Aggreate() returns '
,
test
)
print
(
'Gradcheck of Aggreate() returns '
,
test
)
# check Encoding operation
encoding
=
Encoding
(
D
=
D
,
K
=
K
).
cuda
()
print
(
encoding
)
E
=
encoding
(
X
)
loss
=
E
.
view
(
B
,
-
1
).
pow
(
2
).
sum
()
loss
.
backward
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment