"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "b2140a895b6530ffb5aac96a9a4ea5e06e848a83"
Commit ad3ca3e7 authored by Hang Zhang

encoding

parent a3c3d942
@@ -14,9 +14,22 @@ import platform
 import subprocess
 from torch.utils.ffi import create_extension
+lib_path = os.path.join(os.path.dirname(torch.__file__), 'lib')
+this_file = os.path.dirname(os.path.realpath(__file__))
 # build kernel library
+os.environ['TORCH_BUILD_DIR'] = lib_path
+if platform.system() == 'Darwin':
+    os.environ['TH_LIBRARIES'] = os.path.join(lib_path, 'libTH.1.dylib')
+    os.environ['THC_LIBRARIES'] = os.path.join(lib_path, 'libTHC.1.dylib')
+    ENCODING_LIB = os.path.join(lib_path, 'libENCODING.dylib')
+else:
+    os.environ['TH_LIBRARIES'] = os.path.join(lib_path, 'libTH.so.1')
+    os.environ['THC_LIBRARIES'] = os.path.join(lib_path, 'libTHC.so.1')
+    ENCODING_LIB = os.path.join(lib_path, 'libENCODING.so')
 build_all_cmd = ['bash', 'encoding/make.sh']
-if subprocess.call(build_all_cmd) != 0:
+if subprocess.call(build_all_cmd, env=dict(os.environ)) != 0:
     sys.exit(1)
 sources = ['encoding/src/encoding_lib.cpp']
@@ -24,18 +37,11 @@ headers = ['encoding/src/encoding_lib.h']
 defines = [('WITH_CUDA', None)]
 with_cuda = True
-package_base = os.path.dirname(torch.__file__)
-this_file = os.path.dirname(os.path.realpath(__file__))
-include_path = [os.path.join(os.environ['HOME'],'pytorch/torch/lib/THC'),
-                os.path.join(package_base,'lib/include/ENCODING'),
+include_path = [os.path.join(lib_path, 'include'),
+                os.path.join(os.environ['HOME'],'pytorch/torch/lib/THC'),
+                os.path.join(lib_path,'include/ENCODING'),
                 os.path.join(this_file,'encoding/src/')]
-if platform.system() == 'Darwin':
-    ENCODING_LIB = os.path.join(package_base, 'lib/libENCODING.dylib')
-else:
-    ENCODING_LIB = os.path.join(package_base, 'lib/libENCODING.so')
 def make_relative_rpath(path):
     if platform.system() == 'Darwin':
         return '-Wl,-rpath,' + path
@@ -52,7 +58,7 @@ ffi = create_extension(
     with_cuda=with_cuda,
     include_dirs = include_path,
     extra_link_args = [
-        make_relative_rpath(os.path.join(package_base, 'lib')),
+        make_relative_rpath(lib_path),
         ENCODING_LIB,
     ],
 )
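Note: the build script now passes the Torch library locations to CMake through environment variables instead of letting CMake search for them. Below is a minimal sanity-check sketch (not part of the commit) that the exported paths exist before make.sh is invoked; it uses the same variable names as above.

    import os
    # Verify the handoff the build script sets up: every variable that
    # FindTorch.cmake reads via $ENV{...} must point at an existing path.
    for var in ('TORCH_BUILD_DIR', 'TH_LIBRARIES', 'THC_LIBRARIES'):
        path = os.environ.get(var)
        if path is None or not os.path.exists(path):
            raise RuntimeError('%s is unset or missing: %r' % (var, path))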
@@ -13,10 +13,6 @@ CMAKE_POLICY(VERSION 2.8)
 INCLUDE(${CMAKE_CURRENT_SOURCE_DIR}/cmake/FindTorch.cmake)
-#IF(NOT Torch_FOUND)
-#    FIND_PACKAGE(Torch REQUIRED)
-#ENDIF()
 IF(NOT CUDA_FOUND)
     FIND_PACKAGE(CUDA 6.5 REQUIRED)
 ENDIF()
@@ -54,6 +50,7 @@ SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
 FILE(GLOB src-cuda kernel/*.cu)
+MESSAGE(STATUS "Torch_INSTALL_INCLUDE:" ${Torch_INSTALL_INCLUDE})
 CUDA_INCLUDE_DIRECTORIES(
     ${CMAKE_CURRENT_SOURCE_DIR}/kernel
     ${Torch_INSTALL_INCLUDE}
@@ -9,7 +9,7 @@
 ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 import torch
-from torch.nn.modules.module import Module
+import torch.nn as nn
 from torch.autograd import Function
 from ._ext import encoding_lib
@@ -32,6 +32,43 @@ class aggregate(Function):
         return gradA, gradR
-class Aggregate(Module):
+class Aggregate(nn.Module):
     def forward(self, A, R):
         return aggregate()(A, R)
+class Encoding(nn.Module):
+    def __init__(self, D, K):
+        super(Encoding, self).__init__()
+        # init codewords and smoothing factor
+        self.D, self.K = D, K
+        self.codewords = nn.Parameter(torch.Tensor(K, D), requires_grad=True)
+        self.scale = nn.Parameter(torch.Tensor(K), requires_grad=True)
+        self.softmax = nn.Softmax()
+        self.reset_params()
+
+    def reset_params(self):
+        self.codewords.data.uniform_(0.0, 0.02)
+        self.scale.data.uniform_(0.0, 0.02)
+
+    def forward(self, X):
+        # input X is a 4D tensor: B x D x H x W
+        assert X.dim() == 4, "Encoding Layer requires 4D featuremaps!"
+        assert X.size(1) == self.D, "Encoding Layer incompatible input channels!"
+        B, N, K, D = X.size(0), X.size(2)*X.size(3), self.K, self.D
+        # reshape input to B x N x D
+        X = X.view(B, D, -1).transpose(1, 2)
+        # calculate residuals to the codewords: B x N x K x D
+        R = X.contiguous().view(B, N, 1, D).expand(B, N, K, D) - \
+            self.codewords.view(1, 1, K, D).expand(B, N, K, D)
+        # assignment weights: softmax over scaled squared distances
+        A = R.pow(2).sum(3).view(B, N, K)
+        A = A * self.scale.view(1, 1, K).expand_as(A)
+        A = self.softmax(A.view(B*N, K)).view(B, N, K)
+        # aggregate the weighted residuals: B x K x D
+        E = aggregate()(A, R)
+        return E
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(' \
+            + 'N x ' + str(self.D) + '=>' + str(self.K) + 'x' + str(self.D) + ')'
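The new Encoding layer computes residuals R[b,i,k] = x_i - c_k between each of the N = H*W descriptors and the K codewords, turns the scaled squared distances into soft-assignment weights A with a softmax, and sums the weighted residuals via the custom CUDA aggregate(). Below is a pure-PyTorch reference sketch of that last step (an assumed equivalent for checking, not the commit's kernel):

    def aggregate_reference(A, R):
        # A: B x N x K assignment weights, R: B x N x K x D residuals.
        # Computes E[b,k,:] = sum_i A[b,i,k] * R[b,i,k,:], the same
        # reduction that aggregate()(A, R) performs on the GPU.
        B, N, K, D = R.size()
        return (A.view(B, N, K, 1).expand_as(R) * R).sum(1).view(B, K, D)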
@@ -8,7 +8,8 @@
 ## LICENSE file in the root directory of this source tree
 ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
-# Custom CMake rules for PyTorch (a hacky way)
+# No longer using the manual way to find the library.
+if(FALSE)
 FILE(GLOB TORCH_LIB_HINTS
     "/anaconda/lib/python3.6/site-packages/torch/lib"
     "/anaconda2/lib/python3.6/site-packages/torch/lib"
@@ -19,7 +20,12 @@ FIND_PATH(TORCH_BUILD_DIR
     NAMES "THNN.h"
     PATHS "${TORCH_LIB_HINTS}"
 )
+FIND_LIBRARY(THC_LIBRARIES NAMES THC THC.1 PATHS ${TORCH_BUILD_DIR} PATH_SUFFIXES lib)
+FIND_LIBRARY(TH_LIBRARIES NAMES TH TH.1 PATHS ${TORCH_BUILD_DIR} PATH_SUFFIXES lib)
+endif()
+# Set the environment variable via python
+SET(TORCH_BUILD_DIR "$ENV{TORCH_BUILD_DIR}")
 MESSAGE(STATUS "TORCH_BUILD_DIR: " ${TORCH_BUILD_DIR})
 # Find the include files
@@ -30,6 +36,5 @@ SET(TORCH_THC_UTILS_INCLUDE_DIR "$ENV{HOME}/pytorch/torch/lib/THC")
 SET(Torch_INSTALL_INCLUDE "${TORCH_BUILD_DIR}/include" ${TORCH_TH_INCLUDE_DIR} ${TORCH_THC_INCLUDE_DIR} ${TORCH_THC_UTILS_INCLUDE_DIR})
 # Find the libs. We need to find libraries one by one.
-FIND_LIBRARY(THC_LIBRARIES NAMES THC THC.1 PATHS ${TORCH_BUILD_DIR} PATH_SUFFIXES lib)
-FIND_LIBRARY(TH_LIBRARIES NAMES TH TH.1 PATHS ${TORCH_BUILD_DIR} PATH_SUFFIXES lib)
+SET(TH_LIBRARIES "$ENV{TH_LIBRARIES}")
+SET(THC_LIBRARIES "$ENV{THC_LIBRARIES}")
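FindTorch.cmake no longer searches for TH/THC itself: the old FIND_LIBRARY calls are fenced off in a dead if(FALSE) block, and the live code takes TORCH_BUILD_DIR, TH_LIBRARIES and THC_LIBRARIES from the environment. Since $ENV{...} is evaluated when cmake configures, the variables must be set in the process that launches it. A sketch of driving the configure step from Python, with hypothetical paths standing in for the real ones the build script exports:

    import os
    import subprocess
    # Hypothetical paths for illustration; in this repo the build script
    # exports the real ones before running encoding/make.sh.
    env = dict(os.environ,
               TORCH_BUILD_DIR='/path/to/torch/lib',
               TH_LIBRARIES='/path/to/torch/lib/libTH.so.1',
               THC_LIBRARIES='/path/to/torch/lib/libTHC.so.1')
    subprocess.call(['cmake', '.'], env=env)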
@@ -35,7 +35,7 @@ setup(
     setup_requires=["cffi>=1.0.0"],
     # Exclude the build files.
     packages=find_packages(exclude=["build"]),
     extra_compile_args=extra_compile_args,
     # Package where to put the extensions. Has to be a prefix of build.py.
     ext_package="",
     # Extensions to compile.
@@ -12,15 +12,22 @@ import torch
 import torch.nn as nn
 from torch.autograd import Variable
 from encoding import Aggregate
+from encoding import Encoding
 from torch.autograd import gradcheck
 # declare dims and variables
 B, N, K, D = 1, 2, 3, 4
 A = Variable(torch.randn(B,N,K).cuda(), requires_grad=True)
 R = Variable(torch.randn(B,N,K,D).cuda(), requires_grad=True)
+X = Variable(torch.randn(B,D,3,3).cuda(), requires_grad=True)
 # check Aggregate operation
 test = gradcheck(Aggregate(), (A, R), eps=1e-4, atol=1e-3)
 print('Gradcheck of Aggregate() returns ', test)
+# check Encoding operation
+encoding = Encoding(D=D, K=K).cuda()
+print(encoding)
+E = encoding(X)
+loss = E.view(B,-1).pow(2).sum()
+loss.backward()
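The new test exercises the Encoding layer's forward and backward passes but does not gradcheck it. As a usage sketch (assuming the extension is built and a CUDA device is available), the layer maps a B x D x H x W feature map to B x K x D encodings:

    import torch
    from torch.autograd import Variable
    from encoding import Encoding

    layer = Encoding(D=16, K=32).cuda()
    x = Variable(torch.randn(2, 16, 8, 8).cuda(), requires_grad=True)
    e = layer(x)
    print(e.size())  # expected: torch.Size([2, 32, 16])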