Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Pytorch-Encoding
Commits
8dd870b1
Commit
8dd870b1
authored
Sep 18, 2017
by
Hang Zhang
Browse files
indent
parent
fa0e478a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
122 additions
and
122 deletions
+122
-122
.editorconfig
.editorconfig
+2
-2
build.py
build.py
+26
-26
encoding/__init__.py
encoding/__init__.py
+73
-73
setup.py
setup.py
+21
-21
No files found.
.editorconfig
View file @
8dd870b1
root = true
root = true
[*]
[*]
indent_style =
tab
indent_style =
space
indent_size =
2
indent_size =
4
build.py
View file @
8dd870b1
...
@@ -20,17 +20,17 @@ this_file = os.path.dirname(os.path.realpath(__file__))
...
@@ -20,17 +20,17 @@ this_file = os.path.dirname(os.path.realpath(__file__))
# Build the native ENCODING kernel library before compiling the extension.
# Point the Torch build at the library directory, then select the
# platform-specific shared-library names (dylib on macOS, .so elsewhere).
os.environ['TORCH_BUILD_DIR'] = lib_path
if platform.system() == 'Darwin':
    os.environ['TH_LIBRARIES'] = os.path.join(lib_path, 'libTH.1.dylib')
    os.environ['THC_LIBRARIES'] = os.path.join(lib_path, 'libTHC.1.dylib')
    ENCODING_LIB = os.path.join(lib_path, 'libENCODING.dylib')
else:
    os.environ['TH_LIBRARIES'] = os.path.join(lib_path, 'libTH.so.1')
    os.environ['THC_LIBRARIES'] = os.path.join(lib_path, 'libTHC.so.1')
    ENCODING_LIB = os.path.join(lib_path, 'libENCODING.so')

# Run the shell build script with the environment prepared above;
# abort the whole build if it fails.
build_all_cmd = ['bash', 'encoding/make.sh']
if subprocess.call(build_all_cmd, env=dict(os.environ)) != 0:
    sys.exit(1)

# C++ sources/headers for the cffi extension wrapper.
sources = ['encoding/src/encoding_lib.cpp']
headers = ['encoding/src/encoding_lib.h']
...
@@ -38,29 +38,29 @@ defines = [('WITH_CUDA', None)]
...
@@ -38,29 +38,29 @@ defines = [('WITH_CUDA', None)]
# The custom kernels are CUDA-only.
with_cuda = True

# Header search paths for compiling the extension: the Torch build's
# include dir, the THC headers from a source checkout under $HOME,
# the ENCODING headers, and this package's own sources.
# NOTE(review): lib_path and this_file are defined earlier in this file.
include_path = [
    os.path.join(lib_path, 'include'),
    os.path.join(os.environ['HOME'], 'pytorch/torch/lib/THC'),
    os.path.join(lib_path, 'include/ENCODING'),
    os.path.join(this_file, 'encoding/src/'),
]
def make_relative_rpath(path):
    """Return the linker flag that embeds *path* as a runtime library path.

    The original code branched on ``platform.system() == 'Darwin'`` but both
    branches returned the identical ``'-Wl,-rpath,' + path`` string, so the
    platform check was dead code and has been removed. Behavior is unchanged.
    """
    return '-Wl,-rpath,' + path
# Describe the cffi extension module. `defines`, `headers`, `sources`,
# `with_cuda`, `include_path`, and ENCODING_LIB are all set up earlier
# in this file; the ENCODING shared library is linked in directly and
# found at runtime via the embedded rpath.
ffi = create_extension(
    'encoding._ext.encoding_lib',
    package=True,
    headers=headers,
    sources=sources,
    define_macros=defines,
    relative_to=__file__,
    with_cuda=with_cuda,
    include_dirs=include_path,
    extra_link_args=[
        make_relative_rpath(lib_path),
        ENCODING_LIB,
    ],
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
encoding/__init__.py
View file @
8dd870b1
...
@@ -17,85 +17,85 @@ from torch.nn.parameter import Parameter
...
@@ -17,85 +17,85 @@ from torch.nn.parameter import Parameter
from
._ext
import
encoding_lib
from
._ext
import
encoding_lib
class aggregate(Function):
    """Autograd function aggregating residuals R with assignment weights A.

    Shapes: A is (B, N, K), R is (B, N, K, D); the result E is (B, K, D)
    (the leading comment in the original read "E \\in(BxNxD)", but the code
    allocates (B, K, D)). Only CUDA float/double tensors are supported; the
    heavy lifting is delegated to the compiled `encoding_lib` kernels.
    """

    def forward(self, A, R):
        # A \in(BxNxK) R \in(BxNxKxD) => E \in(BxNxD)
        self.save_for_backward(A, R)
        B, N, K, D = R.size()
        E = A.new(B, K, D)
        # TODO support cpu backend
        if isinstance(A, torch.cuda.FloatTensor):
            encoding_lib.Encoding_Float_aggregate_forward(E, A, R)
        elif isinstance(A, torch.cuda.DoubleTensor):
            encoding_lib.Encoding_Double_aggregate_forward(E, A, R)
        else:
            raise RuntimeError('unimplemented')
        return E

    def backward(self, gradE):
        # Recover the saved inputs and allocate gradients of matching shape.
        A, R = self.saved_tensors
        gradA = A.new().resize_as_(A)
        gradR = R.new().resize_as_(R)
        if isinstance(A, torch.cuda.FloatTensor):
            encoding_lib.Encoding_Float_aggregate_backward(
                gradA, gradR, gradE, A, R)
        elif isinstance(A, torch.cuda.DoubleTensor):
            encoding_lib.Encoding_Double_aggregate_backward(
                gradA, gradR, gradE, A, R)
        else:
            raise RuntimeError('unimplemented')
        return gradA, gradR
class Aggregate(nn.Module):
    """Module wrapper around the `aggregate` autograd function."""

    def forward(self, A, R):
        # A fresh Function instance per call (legacy autograd style).
        return aggregate()(A, R)
class Encoding(nn.Module):
    """Encoding Layer: soft-assigns N input features of dimension D to K
    learned codewords and aggregates the scaled residuals.

    Fixes over the original:
    * ``assert (cond, msg)`` asserted on a non-empty tuple, which is always
      truthy, so the input-channel check could never fire. It is now a real
      ``assert cond, msg``.
    * The init bounds used ``** (1/2)``; under Python 2 integer division the
      exponent is 0 and both stds collapse to 1.0. ``** 0.5`` is used
      instead (identical on Python 3).
    """

    def __init__(self, D, K):
        super(Encoding, self).__init__()
        # init codewords and smoothing factor
        self.D, self.K = D, K
        self.codewords = nn.Parameter(torch.Tensor(K, D), requires_grad=True)
        self.scale = nn.Parameter(torch.Tensor(K), requires_grad=True)
        self.softmax = nn.Softmax()
        self.reset_params()

    def reset_params(self):
        # Uniform init scaled by the fan-in of each parameter.
        std1 = 1. / ((self.K * self.D) ** 0.5)
        std2 = 1. / (self.K ** 0.5)
        self.codewords.data.uniform_(-std1, std1)
        self.scale.data.uniform_(-std2, std2)

    def forward(self, X):
        # input X is a 4D tensor (B, D, H, W); a 3D (D, H, W) input is
        # temporarily promoted to a batch of one.
        assert X.size(1) == self.D, \
            "Encoding Layer incompatible input channels!"
        unpacked = False
        if X.dim() == 3:
            unpacked = True
            X = X.unsqueeze(0)
        B, N, K, D = X.size(0), X.size(2) * X.size(3), self.K, self.D
        # reshape input to (B, N, D)
        X = X.view(B, D, -1).transpose(1, 2)
        # calculate residuals to every codeword: (B, N, K, D)
        R = X.contiguous().view(B, N, 1, D).expand(B, N, K, D) \
            - self.codewords.view(1, 1, K, D).expand(B, N, K, D)
        # assignment weights: softmax over scaled squared distances
        A = R
        A = A.pow(2).sum(3).view(B, N, K)
        A = A * self.scale.view(1, 1, K).expand_as(A)
        A = self.softmax(A.view(B * N, K)).view(B, N, K)
        # aggregate
        E = aggregate()(A, R)
        if unpacked:
            E = E.squeeze(0)
        return E

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'N x ' + str(self.D) + '=>' + str(self.K) + 'x' \
            + str(self.D) + ')'
class
sum_square
(
Function
):
class
sum_square
(
Function
):
def
forward
(
ctx
,
input
):
def
forward
(
ctx
,
input
):
...
...
setup.py
View file @
8dd870b1
...
@@ -19,27 +19,27 @@ this_file = os.path.dirname(__file__)
...
@@ -19,27 +19,27 @@ this_file = os.path.dirname(__file__)
# Compiler flags for the C++ extension.
extra_compile_args = ['-std=c++11', '-Wno-write-strings']

# For official binary builds on Linux, statically link libstdc++ so the
# wheel does not depend on the system C++ runtime.
# NOTE(review): extra_link_args is not defined in this hunk — presumably it
# is initialized in the elided lines above; verify against the full file.
if os.getenv('PYTORCH_BINARY_BUILD') and platform.system() == 'Linux':
    print('PYTORCH_BINARY_BUILD found. Static linking libstdc++ on Linux')
    extra_compile_args += ['-static-libstdc++']
    extra_link_args += ['-static-libstdc++']
# Package metadata and build configuration; the cffi extension itself is
# declared in build.py and referenced through cffi_modules.
setup(
    name="encoding",
    version="0.0.1",
    description="PyTorch Encoding Layer",
    url="https://github.com/zhanghang1989/PyTorch-Encoding-Layer",
    author="Hang Zhang",
    author_email="zhang.hang@rutgers.edu",
    # Require cffi.
    install_requires=["cffi>=1.0.0"],
    setup_requires=["cffi>=1.0.0"],
    # Exclude the build files.
    packages=find_packages(exclude=["build"]),
    extra_compile_args=extra_compile_args,
    # Package where to put the extensions. Has to be a prefix of build.py.
    ext_package="",
    # Extensions to compile.
    cffi_modules=[
        os.path.join(this_file, "build.py:ffi"),
    ],
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment