"tests/vscode:/vscode.git/clone" did not exist on "9cfd4ef0746076febb589788ec47df9e2db43d65"
Commit ad3ca3e7 authored by Hang Zhang

encoding

parent a3c3d942
@@ -14,9 +14,22 @@ import platform
 import subprocess
 from torch.utils.ffi import create_extension
+lib_path = os.path.join(os.path.dirname(torch.__file__), 'lib')
 this_file = os.path.dirname(os.path.realpath(__file__))

+# build kernel library
+os.environ['TORCH_BUILD_DIR'] = lib_path
+if platform.system() == 'Darwin':
+    os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libTH.1.dylib')
+    os.environ['THC_LIBRARIES'] = os.path.join(lib_path,'libTHC.1.dylib')
+    ENCODING_LIB = os.path.join(lib_path, 'libENCODING.dylib')
+else:
+    os.environ['TH_LIBRARIES'] = os.path.join(lib_path,'libTH.so.1')
+    os.environ['THC_LIBRARIES'] = os.path.join(lib_path,'libTHC.so.1')
+    ENCODING_LIB = os.path.join(lib_path, 'libENCODING.so')
+
 build_all_cmd = ['bash', 'encoding/make.sh']
-if subprocess.call(build_all_cmd) != 0:
+if subprocess.call(build_all_cmd, env=dict(os.environ)) != 0:
     sys.exit(1)

 sources = ['encoding/src/encoding_lib.cpp']
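Note on the `env=dict(os.environ)` change: a child process inherits the parent's environment by default, so the explicit `env=` mainly pins down exactly what `encoding/make.sh` sees; the variables exported above (`TORCH_BUILD_DIR`, `TH_LIBRARIES`, `THC_LIBRARIES`) are what the CMake run inside that script reads back via `$ENV{...}` (see the FindTorch.cmake hunks below). A minimal sketch of the round trip, with a Python one-liner standing in for make.sh and a hypothetical path:

import os
import subprocess
import sys

# export a variable in the parent, as build.py does before calling make.sh
os.environ['TORCH_BUILD_DIR'] = '/tmp/torch-lib'   # hypothetical path

# hand the child a copy of the current environment; the child sees the
# variable exactly as CMake would through $ENV{TORCH_BUILD_DIR}
probe = "import os; print(os.environ.get('TORCH_BUILD_DIR'))"
subprocess.call([sys.executable, '-c', probe], env=dict(os.environ))
# prints: /tmp/torch-lib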
@@ -24,18 +37,11 @@ headers = ['encoding/src/encoding_lib.h']
 defines = [('WITH_CUDA', None)]
 with_cuda = True

 package_base = os.path.dirname(torch.__file__)
-this_file = os.path.dirname(os.path.realpath(__file__))
-include_path = [os.path.join(os.environ['HOME'],'pytorch/torch/lib/THC'),
-                os.path.join(package_base,'lib/include/ENCODING'),
+include_path = [os.path.join(lib_path, 'include'),
+                os.path.join(os.environ['HOME'],'pytorch/torch/lib/THC'),
+                os.path.join(lib_path,'include/ENCODING'),
                 os.path.join(this_file,'encoding/src/')]
-if platform.system() == 'Darwin':
-    ENCODING_LIB = os.path.join(package_base, 'lib/libENCODING.dylib')
-else:
-    ENCODING_LIB = os.path.join(package_base, 'lib/libENCODING.so')

 def make_relative_rpath(path):
     if platform.system() == 'Darwin':
         return '-Wl,-rpath,' + path
@@ -52,7 +58,7 @@ ffi = create_extension(
     with_cuda=with_cuda,
     include_dirs = include_path,
     extra_link_args = [
-        make_relative_rpath(os.path.join(package_base, 'lib')),
+        make_relative_rpath(lib_path),
         ENCODING_LIB,
     ],
 )
@@ -13,10 +13,6 @@ CMAKE_POLICY(VERSION 2.8)
 INCLUDE(${CMAKE_CURRENT_SOURCE_DIR}/cmake/FindTorch.cmake)
-#IF(NOT Torch_FOUND)
-#    FIND_PACKAGE(Torch REQUIRED)
-#ENDIF()

 IF(NOT CUDA_FOUND)
     FIND_PACKAGE(CUDA 6.5 REQUIRED)
 ENDIF()
@@ -54,6 +50,7 @@ SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
 FILE(GLOB src-cuda kernel/*.cu)
+MESSAGE(STATUS "Torch_INSTALL_INCLUDE:" ${Torch_INSTALL_INCLUDE})
 CUDA_INCLUDE_DIRECTORIES(
     ${CMAKE_CURRENT_SOURCE_DIR}/kernel
     ${Torch_INSTALL_INCLUDE}
@@ -9,7 +9,7 @@
 ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 import torch
-from torch.nn.modules.module import Module
+import torch.nn as nn
 from torch.autograd import Function
 from ._ext import encoding_lib
@@ -32,6 +32,43 @@ class aggregate(Function):
         return gradA, gradR

-class Aggregate(Module):
+class Aggregate(nn.Module):
     def forward(self, A, R):
         return aggregate()(A, R)
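For reference, the CUDA-backed `aggregate` computes a weighted sum of the residuals over the N input positions: E[b,k] = sum_n A[b,n,k] * R[b,n,k]. A dense pure-PyTorch equivalent, useful only as a cross-check against the kernel (a sketch, not the author's implementation; `expand_as` is used so it also runs on PyTorch versions without broadcasting):

def aggregate_reference(A, R):
    # A: (B, N, K) assignment weights; R: (B, N, K, D) residuals
    # returns E of shape (B, K, D): E[b, k] = sum_n A[b, n, k] * R[b, n, k]
    B, K, D = A.size(0), A.size(2), R.size(3)
    return (A.unsqueeze(3).expand_as(R) * R).sum(1).view(B, K, D)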
+class Encoding(nn.Module):
+    def __init__(self, D, K):
+        super(Encoding, self).__init__()
+        # init codewords and smoothing factor
+        self.D, self.K = D, K
+        self.codewords = nn.Parameter(torch.Tensor(K, D), requires_grad=True)
+        self.scale = nn.Parameter(torch.Tensor(K), requires_grad=True)
+        self.softmax = nn.Softmax()
+        self.reset_params()
+
+    def reset_params(self):
+        self.codewords.data.uniform_(0.0, 0.02)
+        self.scale.data.uniform_(0.0, 0.02)
+
+    def forward(self, X):
+        # input X is a 4D tensor
+        assert X.dim() == 4, "Encoding Layer requires 4D featuremaps!"
+        assert X.size(1) == self.D, "Encoding Layer incompatible input channels!"
+        B, N, K, D = X.size(0), X.size(2)*X.size(3), self.K, self.D
+        # reshape input
+        X = X.view(B,D,-1).transpose(1,2)
+        # calculate residuals
+        R = X.contiguous().view(B,N,1,D).expand(B,N,K,D) - self.codewords.view(
+            1,1,K,D).expand(B,N,K,D)
+        # assignment weights
+        A = R.pow(2).sum(3).view(B,N,K)
+        A = A*self.scale.view(1,1,K).expand_as(A)
+        A = self.softmax(A.view(B*N,K)).view(B,N,K)
+        # aggregate
+        E = aggregate()(A, R)
+        return E
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(' \
+            + 'N x ' + str(self.D) + '=>' + str(self.K) + 'x' + str(self.D) + ')'
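Shape walkthrough of `Encoding.forward`, assuming a B x D x H x W feature map: the input is flattened to N = H*W positions; residuals R (B x N x K x D) to the K codewords are formed; soft-assignment weights A (B x N x K) come from a softmax over the scaled squared distances; aggregation returns a B x K x D encoding. A hedged usage sketch with arbitrary example sizes (CUDA assumed, since the backing kernels are CUDA-only):

import torch
from torch.autograd import Variable
from encoding import Encoding

B, D, H, W, K = 2, 4, 8, 8, 16    # example sizes, not from the commit
layer = Encoding(D=D, K=K).cuda()
X = Variable(torch.randn(B, D, H, W).cuda())
E = layer(X)                       # internally N = H*W = 64 positions
print(E.size())                    # expected: (B, K, D) = (2, 16, 4)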
@@ -8,7 +8,8 @@
 ## LICENSE file in the root directory of this source tree
 ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
-# Custom CMake rules for PyTorch (a hacky way)
+# No longer using the manual way to find the library.
+if(FALSE)
 FILE(GLOB TORCH_LIB_HINTS
     "/anaconda/lib/python3.6/site-packages/torch/lib"
     "/anaconda2/lib/python3.6/site-packages/torch/lib"
@@ -19,7 +20,12 @@ FIND_PATH(TORCH_BUILD_DIR
     NAMES "THNN.h"
     PATHS "${TORCH_LIB_HINTS}"
 )
 FIND_LIBRARY(THC_LIBRARIES NAMES THC THC.1 PATHS ${TORCH_BUILD_DIR} PATH_SUFFIXES lib)
 FIND_LIBRARY(TH_LIBRARIES NAMES TH TH.1 PATHS ${TORCH_BUILD_DIR} PATH_SUFFIXES lib)
+endif()
+
+# Set the environment variable via python
+SET(TORCH_BUILD_DIR "$ENV{TORCH_BUILD_DIR}")
+MESSAGE(STATUS "TORCH_BUILD_DIR: " ${TORCH_BUILD_DIR})

 # Find the include files
@@ -30,6 +36,5 @@ SET(TORCH_THC_UTILS_INCLUDE_DIR "$ENV{HOME}/pytorch/torch/lib/THC")
 SET(Torch_INSTALL_INCLUDE "${TORCH_BUILD_DIR}/include" ${TORCH_TH_INCLUDE_DIR} ${TORCH_THC_INCLUDE_DIR} ${TORCH_THC_UTILS_INCLUDE_DIR})

 # Find the libs. We need to find libraries one by one.
-FIND_LIBRARY(THC_LIBRARIES NAMES THC THC.1 PATHS ${TORCH_BUILD_DIR} PATH_SUFFIXES lib)
-FIND_LIBRARY(TH_LIBRARIES NAMES TH TH.1 PATHS ${TORCH_BUILD_DIR} PATH_SUFFIXES lib)
+SET(TH_LIBRARIES "$ENV{TH_LIBRARIES}")
+SET(THC_LIBRARIES "$ENV{THC_LIBRARIES}")
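The net effect of these FindTorch.cmake hunks: the TH/THC library paths are no longer discovered with FIND_LIBRARY but are handed over from build.py through the environment. A small diagnostic sketch (not part of the commit) to inspect what CMake will receive:

import os
# the three variables build.py exports before invoking encoding/make.sh;
# cmake/FindTorch.cmake reads each one back via $ENV{...}
for name in ('TORCH_BUILD_DIR', 'TH_LIBRARIES', 'THC_LIBRARIES'):
    print(name, '=', os.environ.get(name, '<unset>'))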
@@ -12,15 +12,22 @@ import torch
 import torch.nn as nn
 from torch.autograd import Variable
 from encoding import Aggregate
+from encoding import Encoding
 from torch.autograd import gradcheck

 # declare dims and variables
 B, N, K, D = 1, 2, 3, 4
 A = Variable(torch.randn(B,N,K).cuda(), requires_grad=True)
 R = Variable(torch.randn(B,N,K,D).cuda(), requires_grad=True)
+X = Variable(torch.randn(B,D,3,3).cuda(), requires_grad=True)

+# check Aggregate operation
 test = gradcheck(Aggregate(),(A, R), eps=1e-4, atol=1e-3)
 print('Gradcheck of Aggregate() returns ', test)

+# check Encoding operation
+encoding = Encoding(D=D, K=K).cuda()
+print(encoding)
+E = encoding(X)
+loss = E.view(B,-1).pow(2).sum()
+loss.backward()
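A further sanity check one could append to this test (not in the commit): compare the kernel-backed aggregation against the dense formula, which should agree up to floating-point tolerance:

# dense reference: E[b, k] = sum_n A[b, n, k] * R[b, n, k]
E_kernel = Aggregate()(A, R)
E_dense = (A.unsqueeze(3).expand_as(R) * R).sum(1).view(B, K, D)
print('max abs difference:', (E_kernel - E_dense).data.abs().max())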