Commit 55dbd840 authored by Hang Zhang

tested

parent 984cce35
@@ -31,19 +31,22 @@ else:
 ENCODING_LIB = os.path.join(package_base, 'lib/libENCODING.so')
 
 def make_relative_rpath(path):
     if platform.system() == 'Darwin':
         return '-Wl,-rpath,' + path
     else:
         return '-Wl,-rpath,' + path
 
+extra_link_args = []
+
+
 ffi = create_extension(
     'encoding._ext.encoding_lib',
     package=True,
     headers=headers,
     sources=sources,
     define_macros=defines,
     relative_to=__file__,
     with_cuda=with_cuda,
     include_dirs = include_path,
     extra_link_args = [
         make_relative_rpath(os.path.join(package_base, 'lib')),
......
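As an aside on the hunk above: both branches of make_relative_rpath return the same GNU-style linker flag in this revision, so the Darwin special case is currently a no-op. A quick sketch of what the helper emits (the install prefix below is hypothetical, for illustration only):

import os

def make_relative_rpath(path):
    # identical on Darwin and Linux in this revision of build.py
    return '-Wl,-rpath,' + path

package_base = '/usr/local/lib/python2.7/site-packages/encoding'  # hypothetical
print(make_relative_rpath(os.path.join(package_base, 'lib')))
# prints: -Wl,-rpath,/usr/local/lib/python2.7/site-packages/encoding/lib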
@@ -49,12 +49,13 @@ IF(NOT ENCODING_INSTALL_LIB_SUBDIR)
 ENDIF()
 
 SET(CMAKE_MACOSX_RPATH 1)
+#SET(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -std=c++11")
 SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
 
 FILE(GLOB src-cuda kernel/*.cu)
 CUDA_INCLUDE_DIRECTORIES(
-    ${CMAKE_CURRENT_SOURCE_DIR}
+    ${CMAKE_CURRENT_SOURCE_DIR}/kernel
     ${Torch_INSTALL_INCLUDE}
 )
 CUDA_ADD_LIBRARY(ENCODING SHARED ${src-cuda})
@@ -63,11 +64,6 @@ IF(MSVC)
   SET_TARGET_PROPERTIES(ENCODING PROPERTIES PREFIX "lib" IMPORT_PREFIX "lib")
 ENDIF()
 
-INCLUDE_DIRECTORIES(
-    ./include
-    ${CMAKE_CURRENT_SOURCE_DIR}
-    ${Torch_INSTALL_INCLUDE}
-)
 TARGET_LINK_LIBRARIES(ENCODING
   ${THC_LIBRARIES}
   ${TH_LIBRARIES}
......
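Presumably the freestanding INCLUDE_DIRECTORIES block is dropped because the ENCODING target is built entirely from kernel/*.cu by CUDA_ADD_LIBRARY, so only the nvcc search path set through CUDA_INCLUDE_DIRECTORIES matters; pointing that path at ${CMAKE_CURRENT_SOURCE_DIR}/kernel is also what lets the shortened #define THC_GENERIC_FILE "generic/encoding_kernel.c" in the next file resolve.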
@@ -10,6 +10,7 @@
 import torch
 from torch.nn.modules.module import Module
+from torch.autograd import Function
 from ._ext import encoding_lib
 
 
 class aggregate(Function):
......
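Since class aggregate(Function) subclasses the newly imported Function, a minimal sketch of that legacy autograd pattern may help. The shapes come from the dimension checks in the kernel hunk below; the aggregation formula itself is an assumption (a pure-PyTorch stand-in, not the repo's CUDA path):

import torch
from torch.autograd import Function

class Aggregate(Function):
    # Legacy (pre-0.4) Function interface, which this repo targets:
    # forward/backward are instance methods and state is kept on self.
    # Shapes follow the kernel's checks: A is (B, N, K), R is (B, N, K, D),
    # and the output E is (B, K, D).
    def forward(self, A, R):
        self.save_for_backward(A, R)
        # Pure-PyTorch stand-in for the Aggregate_Forward CUDA kernel,
        # assuming E[b,k,d] = sum_n A[b,n,k] * R[b,n,k,d]
        return (A.unsqueeze(3) * R).sum(dim=1)

    def backward(self, gradE):
        A, R = self.saved_tensors
        gradA = (gradE.unsqueeze(1) * R).sum(dim=3)  # (B, N, K)
        gradR = gradE.unsqueeze(1) * A.unsqueeze(3)  # (B, N, K, D)
        return gradA, gradR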
@@ -9,8 +9,9 @@
  *+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
  */
 #ifndef THC_GENERIC_FILE
-#define THC_GENERIC_FILE "kernel/generic/encoding_kernel.c"
+#define THC_GENERIC_FILE "generic/encoding_kernel.c"
 #else
+/*
 template <int Dim>
 THCDeviceTensor<float, Dim> devicetensor(THCState *state, THCTensor *t) {
   if (!t) {
@@ -36,7 +37,7 @@ THCDeviceTensor<float, Dim> devicetensor(THCState *state, THCTensor *t) {
   }
   return THCDeviceTensor<float, Dim>(THCTensor_(data)(state, t), size);
 }
-
+*/
 __global__ void Encoding_(Aggregate_Forward_kernel) (
     THCDeviceTensor<real, 3> E,
     THCDeviceTensor<real, 3> A,
@@ -71,7 +72,7 @@ void Encoding_(Aggregate_Forward)(THCState *state, THCTensor *E_, THCTensor *A_,
     if (THCTensor_(nDimension)(state, E_) != 3 ||
         THCTensor_(nDimension)(state, A_) != 3 ||
         THCTensor_(nDimension)(state, R_) != 4)
-        perror("Encoding: incorrect input dims. \n");
+        THError("Encoding: incorrect input dims. \n");
     /* Device tensors */
     THCDeviceTensor<real, 3> E = devicetensor<3>(state, E_);
     THCDeviceTensor<real, 3> A = devicetensor<3>(state, A_);
......
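Two changes land in this generic file: the per-type devicetensor template is fenced off with /* ... */ (a single non-generic copy reappears in the .cu wrapper below), and perror is swapped for THError, which raises a real Torch error that propagates up to the Python caller instead of printing to stderr and letting the kernel run on bad dimensions.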
@@ -12,15 +12,48 @@
 #include "THCDeviceTensor.cuh"
 #include "THCDeviceTensorUtils.cuh"
+#include "thc_encoding.h"
 
 // this symbol will be resolved automatically from PyTorch libs
 extern THCState *state;
 
-//#define torch_(NAME) TH_CONCAT_3(torch_, Real, NAME)
-//#define torch_Tensor TH_CONCAT_STRING_3(torch., Real, Tensor)
 #define Encoding_(NAME) TH_CONCAT_4(Encoding_, Real, _, NAME)
 #define THCTensor TH_CONCAT_3(TH,CReal,Tensor)
 #define THCTensor_(NAME) TH_CONCAT_4(TH,CReal,Tensor_,NAME)
 
+template <int Dim>
+THCDeviceTensor<float, Dim> devicetensor(THCState *state, THCudaTensor *t) {
+  if (!t) {
+    return THCDeviceTensor<float, Dim>();
+  }
+  int inDim = THCudaTensor_nDimension(state, t);
+  if (inDim == Dim) {
+    return toDeviceTensor<float, Dim>(state, t);
+  }
+  // View in which the last dimensions are collapsed or expanded as needed
+  THAssert(THCudaTensor_isContiguous(state, t));
+  int size[Dim];
+  for (int i = 0; i < Dim || i < inDim; ++i) {
+    if (i < Dim && i < inDim) {
+      size[i] = t->size[i];
+    } else if (i < Dim) {
+      size[i] = 1;
+    } else {
+      size[Dim - 1] *= t->size[i];
+    }
+  }
+  return THCDeviceTensor<float, Dim>(THCudaTensor_data(state, t), size);
+}
+
+#ifdef __cplusplus
+extern "C" {
+#endif
 #include "generic/encoding_kernel.c"
 #include "THC/THCGenerateFloatType.h"
+#ifdef __cplusplus
+}
+#endif
@@ -15,12 +15,17 @@
 // this symbol will be resolved automatically from PyTorch libs
 extern THCState *state;
 
-//#define torch_(NAME) TH_CONCAT_3(torch_, Real, NAME)
-//#define torch_Tensor TH_CONCAT_STRING_3(torch., Real, Tensor)
 #define Encoding_(NAME) TH_CONCAT_4(Encoding_, Real, _, NAME)
 #define THCTensor TH_CONCAT_3(TH,CReal,Tensor)
 #define THCTensor_(NAME) TH_CONCAT_4(TH,CReal,Tensor_,NAME)
 
+#ifdef __cplusplus
+extern "C" {
+#endif
 #include "generic/encoding_kernel.h"
 #include "THC/THCGenerateFloatType.h"
+#ifdef __cplusplus
+}
+#endif
@@ -13,5 +13,13 @@
 extern THCState *state;
 
+#ifdef __cplusplus
+extern "C" {
+#endif
 #include "generic/encoding_generic.c"
 #include "THC/THCGenerateFloatType.h"
+#ifdef __cplusplus
+}
+#endif
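The extern "C" guards added in this and the two previous files all serve the same purpose: when the translation unit is compiled as C++ (as nvcc does for the .cu wrapper), the generated Encoding_Float_* functions would otherwise get C++-mangled names, whereas the CFFI extension produced by build.py resolves them by their plain C symbol names.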
@@ -17,22 +17,29 @@ import build
 this_file = os.path.dirname(__file__)
 
+extra_compile_args = ['-std=c++11', '-Wno-write-strings']
+if os.getenv('PYTORCH_BINARY_BUILD') and platform.system() == 'Linux':
+    print('PYTORCH_BINARY_BUILD found. Static linking libstdc++ on Linux')
+    extra_compile_args += ['-static-libstdc++']
+    extra_link_args += ['-static-libstdc++']
+
 setup(
     name="encoding",
     version="0.0.1",
     description="PyTorch Encoding Layer",
     url="https://github.com/zhanghang1989/PyTorch-Encoding-Layer",
     author="Hang Zhang",
     author_email="zhang.hang@rutgers.edu",
     # Require cffi.
     install_requires=["cffi>=1.0.0"],
     setup_requires=["cffi>=1.0.0"],
     # Exclude the build files.
     packages=find_packages(exclude=["build"]),
+    extra_compile_args=extra_compile_args,
     # Package where to put the extensions. Has to be a prefix of build.py.
     ext_package="",
     # Extensions to compile.
     cffi_modules=[
         os.path.join(this_file, "build.py:ffi")
     ],
 )
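The new flag logic in isolation, as a minimal sketch (note that extra_link_args is never initialized inside this hunk; the += lines assume it is defined as a list earlier in setup.py, outside the visible context):

import os
import platform

extra_link_args = []  # assumed to be defined earlier in setup.py
extra_compile_args = ['-std=c++11', '-Wno-write-strings']
if os.getenv('PYTORCH_BINARY_BUILD') and platform.system() == 'Linux':
    # Statically link libstdc++ so the binary does not depend on the
    # system C++ runtime, mirroring PyTorch's own binary builds.
    extra_compile_args += ['-static-libstdc++']
    extra_link_args += ['-static-libstdc++']

As an aside, setup() has no extra_compile_args keyword of its own, so distutils will most likely report it as an unknown distribution option and ignore it; flags that must reach the compiler would need to go through create_extension in build.py.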