OpenDAS / Pytorch-Encoding — Commits

Commit a3c3d942 ("aggregate")
Authored May 14, 2017 by Hang Zhang
Parent: c05c2a59

Showing 8 changed files with 99 additions and 38 deletions (+99 -38):
  encoding/__init__.py                        +8  -4
  encoding/kernel/generic/device_tensor.h     +37 -0
  encoding/kernel/generic/encoding_kernel.c   +29 -18
  encoding/kernel/generic/encoding_kernel.h   +2  -2
  encoding/make.sh                            +7  -0
  encoding/src/encoding_lib.h                 +2  -2
  encoding/src/generic/encoding_generic.c     +5  -5
  test/test.py                                +9  -7
encoding/__init__.py

```diff
@@ -16,16 +16,20 @@ from ._ext import encoding_lib

 class aggregate(Function):
     def forward(self, A, R):
         # A \in (B x N x K), R \in (B x N x K x D) => E \in (B x K x D)
         self.save_for_backward(A, R)
         B, N, K, D = R.size()
         E = A.new(B, K, D)
         # TODO support cpu backend
         print(encoding_lib)
         encoding_lib.Encoding_Float_aggregate_forward(E, A, R)
         return E

-    def backward(self, E):
-        # TODO FIXME this is test only
-        return E
+    def backward(self, gradE):
+        A, R = self.saved_tensors
+        gradA = A.clone()
+        gradR = R.clone()
+        encoding_lib.Encoding_Float_aggregate_backward(
+            gradA, gradR, gradE, A, R)
+        return gradA, gradR

 class Aggregate(Module):
```
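For orientation, the semantics of the forward op can be stated as a single einsum. This is a reference sketch inferred from the shapes in the diff, not code from the commit; `aggregate_reference` is a hypothetical name.

```python
import torch

def aggregate_reference(A, R):
    # Hypothetical reference, not part of the commit.
    # A: (B, N, K) assignment weights; R: (B, N, K, D) residuals.
    # E[b, k, d] = sum_i A[b, i, k] * R[b, i, k, d]  ->  E: (B, K, D)
    return torch.einsum('bik,bikd->bkd', A, R)

# Example with the same dims as test/test.py below:
A = torch.randn(1, 2, 3)
R = torch.randn(1, 2, 3, 4)
print(aggregate_reference(A, R).shape)  # torch.Size([1, 3, 4])
```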
encoding/kernel/generic/device_tensor.h  (new file, mode 100644)
```c
/*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 * Created by: Hang Zhang
 * ECE Department, Rutgers University
 * Email: zhang.hang@rutgers.edu
 * Copyright (c) 2017
 *
 * This source code is licensed under the MIT-style license found in the
 * LICENSE file in the root directory of this source tree
 *+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 */
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/device_tensor.h"
#else

template <int Dim>
THCDeviceTensor<float, Dim> devicetensor(THCState *state, THCTensor *t)
{
    if (!t) {
        return THCDeviceTensor<float, Dim>();
    }

    int inDim = THCTensor_(nDimension)(state, t);
    if (inDim == Dim) {
        return toDeviceTensor<float, Dim>(state, t);
    }

    // View in which the last dimensions are collapsed or expanded as needed
    THAssert(THCTensor_(isContiguous)(state, t));
    int size[Dim];
    for (int i = 0; i < Dim || i < inDim; ++i) {
        if (i < Dim && i < inDim) {
            size[i] = t->size[i];
        } else if (i < Dim) {
            size[i] = 1;
        } else {
            size[Dim - 1] *= t->size[i];
        }
    }
    return THCDeviceTensor<float, Dim>(THCTensor_(data)(state, t), size);
}
#endif
```
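The reshaping rule in `devicetensor` (copy the sizes that fit, pad with trailing 1s when the tensor has fewer than `Dim` dimensions, otherwise fold the extra trailing dimensions into the last kept one) may be easier to read in Python. This is an illustrative sketch under the same contiguity assumption; `collapse_shape` is a hypothetical helper name.

```python
def collapse_shape(shape, dim):
    # Mirrors devicetensor<Dim>: keep the sizes that fit, pad missing
    # trailing dims with 1, and fold any extra trailing dims into the
    # last slot (only valid for contiguous tensors).
    size = [1] * dim
    for i, s in enumerate(shape):
        if i < dim:
            size[i] = s
        else:
            size[dim - 1] *= s
    return size

print(collapse_shape((2, 3, 4, 5), 3))  # [2, 3, 20] -- collapsed
print(collapse_shape((2, 3), 3))        # [2, 3, 1]  -- expanded
```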
encoding/kernel/generic/encoding_kernel.c
```diff
@@ -17,7 +17,7 @@ __global__ void Encoding_(Aggregate_Forward_kernel) (
     THCDeviceTensor<real, 3> A,
     THCDeviceTensor<real, 4> R)
 /*
- * aggregating kernel function
+ * aggregating forward kernel function
  */
 {
     /* declarations of the variables */
@@ -41,7 +41,7 @@ __global__ void Encoding_(Aggregate_Forward_kernel) (
 void Encoding_(Aggregate_Forward)(THCState *state, THCTensor *E_,
     THCTensor *A_, THCTensor *R_)
 /*
- * aggregating the residuals with assignment weights
+ * aggregating forward the residuals with assignment weights
  */
 {
     /* Check the GPU index and tensor dims*/
@@ -63,12 +63,16 @@ void Encoding_(Aggregate_Forward)(THCState *state, THCTensor *E_,
     THCudaCheck(cudaGetLastError());
 }
 /*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++*/
 __global__ void Encoding_(Aggregate_Backward_kernel) (
-    THCDeviceTensor<real, 3> G,
+    THCDeviceTensor<real, 3> GA,
+    THCDeviceTensor<real, 4> GR,
     THCDeviceTensor<real, 3> L,
+    THCDeviceTensor<real, 3> A,
     THCDeviceTensor<real, 4> R)
 /*
  * aggregating backward kernel function
  * GA (dl/dA), GR (dl/dR), L (dl/dE), A (assignments)
  */
 {
     /* declarations of the variables */
@@ -76,42 +80,49 @@ __global__ void Encoding_(Aggregate_Backward_kernel) (
     real sum;
     /* Get the index and channels */
     b = blockIdx.z;
-    k = blockIdx.x * blockDim.x + threadIdx.x;
     i = blockIdx.y * blockDim.y + threadIdx.y;
+    k = blockIdx.x * blockDim.x + threadIdx.x;
     D = L.getSize(2);
-    /* boundary check for output */
-    if (k >= G.getSize(2) || i >= G.getSize(1)) return;
+    /* boundary check for output
+     * GR \in R^{BxNxKxD}
+     */
+    if (k >= GR.getSize(2) || i >= GR.getSize(1)) return;
     /* main operation */
     sum = 0;
     for (d = 0; d < D; d++) {
-        //sum += L[b][k][d].ldg() * R[b][i][k][d].ldg();
+        GR[b][i][k][d] = L[b][k][d] * A[b][i][k];
         sum += L[b][k][d].ldg() * R[b][i][k][d].ldg();
     }
-    G[b][i][k] = sum;
+    GA[b][i][k] = sum;
 }

-void Encoding_(Aggregate_Backward)(THCState *state, THCTensor *G_,
-    THCTensor *L_, THCTensor *R_)
+void Encoding_(Aggregate_Backward)(THCState *state, THCTensor *GA_,
+    THCTensor *GR_, THCTensor *L_, THCTensor *A_, THCTensor *R_)
 /*
  * aggregate backward to assignment weights and residuals
  * GA (dl/dA), GR (dl/dR), L (dl/dE), A (assignments)
  */
 {
     /* Check the GPU index and tensor dims*/
-    THCTensor_(checkGPU)(state, 3, G_, L_, R_);
-    if (THCTensor_(nDimension)(state, G_) != 3 ||
-        THCTensor_(nDimension)(state, L_) != 3 ||
-        THCTensor_(nDimension)(state, R_) != 4)
+    THCTensor_(checkGPU)(state, 5, GA_, GR_, L_, A_, R_);
+    if (THCTensor_(nDimension)(state, GA_) != 3 ||
+        THCTensor_(nDimension)(state, GR_) != 4 ||
+        THCTensor_(nDimension)(state, L_) != 3 ||
+        THCTensor_(nDimension)(state, A_) != 3 ||
+        THCTensor_(nDimension)(state, R_) != 4)
         THError("Encoding: incorrect input dims. \n");
     /* Device tensors */
-    THCDeviceTensor<real, 3> G = devicetensor<3>(state, G_);
+    THCDeviceTensor<real, 3> GA = devicetensor<3>(state, GA_);
+    THCDeviceTensor<real, 4> GR = devicetensor<4>(state, GR_);
     THCDeviceTensor<real, 3> L = devicetensor<3>(state, L_);
+    THCDeviceTensor<real, 3> A = devicetensor<3>(state, A_);
     THCDeviceTensor<real, 4> R = devicetensor<4>(state, R_);
     /* kernel function */
     cudaStream_t stream = THCState_getCurrentStream(state);
     dim3 threads(16, 16);
-    dim3 blocks(G.getSize(2)/16+1, G.getSize(1)/16+1, G.getSize(0));
-    Encoding_(Aggregate_Backward_kernel)<<<blocks, threads, 0, stream>>>
-        (G, L, R);
+    dim3 blocks(GA.getSize(2)/16+1, GA.getSize(1)/16+1, GA.getSize(0));
+    Encoding_(Aggregate_Backward_kernel)<<<blocks, threads, 0, stream>>>
+        (GA, GR, L, A, R);
     THCudaCheck(cudaGetLastError());
 }
 #endif
```
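The kernel's two writes are exactly the gradients of the aggregation. As a hedged pure-PyTorch restatement (not from the commit; `aggregate_backward_reference` is an illustrative name):

```python
import torch

def aggregate_backward_reference(L, A, R):
    # Hypothetical reference, not part of the commit.
    # L = dl/dE: (B, K, D); A: (B, N, K); R: (B, N, K, D).
    # Mirrors Encoding_(Aggregate_Backward_kernel):
    #   GR[b,i,k,d] = L[b,k,d] * A[b,i,k]            (dl/dR)
    #   GA[b,i,k]   = sum_d L[b,k,d] * R[b,i,k,d]    (dl/dA)
    GR = torch.einsum('bkd,bik->bikd', L, A)
    GA = torch.einsum('bkd,bikd->bik', L, R)
    return GA, GR
```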
encoding/kernel/generic/encoding_kernel.h
```diff
@@ -14,6 +14,6 @@
 void Encoding_(Aggregate_Forward)(THCState *state, THCTensor *E_,
     THCTensor *A_, THCTensor *R_);
-void Encoding_(Aggregate_Backward)(THCState *state, THCTensor *G_,
-    THCTensor *L_, THCTensor *R_);
+void Encoding_(Aggregate_Backward)(THCState *state, THCTensor *GA_,
+    THCTensor *GR_, THCTensor *L_, THCTensor *A_, THCTensor *R_);
 #endif
```
encoding/make.sh  (new file, mode 100644)
```bash
#!/usr/bin/env bash
mkdir -p encoding/build && cd encoding/build
# compile and install
cmake ..
make install
cd ..
```
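Presumably the script is invoked from the repository root (e.g. `bash encoding/make.sh`), so that `encoding/build` and the `cmake ..` source path resolve correctly; the commit itself does not say where it is meant to be run.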
encoding/src/encoding_lib.h
```diff
@@ -22,5 +22,5 @@
 int Encoding_Float_aggregate_forward(THCudaTensor *E, THCudaTensor *A,
     THCudaTensor *R);
-int Encoding_Float_aggregate_backward(THCudaTensor *G, THCudaTensor *L,
-    THCudaTensor *R);
+int Encoding_Float_aggregate_backward(THCudaTensor *GA, THCudaTensor *GR,
+    THCudaTensor *L, THCudaTensor *A, THCudaTensor *R);
```
encoding/src/generic/encoding_generic.c
```diff
@@ -23,15 +23,15 @@ int Encoding_(aggregate_forward)(THCudaTensor *E, THCudaTensor *A,
     return 0;
 }

-int Encoding_(aggregate_backward)(THCudaTensor *E, THCudaTensor *A,
-    THCudaTensor *R)
+int Encoding_(aggregate_backward)(THCudaTensor *GA, THCudaTensor *GR,
+    THCudaTensor *L, THCudaTensor *A, THCudaTensor *R)
 /*
- * Aggregate operation
+ * Aggregate backward operation to A and R
+ * GA (dl/dA), GR (dl/dR), L (dl/dE), A (assignments)
  */
 {
-    Encoding_(Aggregate_Backward)(state, E, A, R);
+    Encoding_(Aggregate_Backward)(state, GA, GR, L, A, R);
     /* C function return number of the outputs */
     return 0;
 }
 #endif
```
test/test.py
```diff
@@ -12,13 +12,15 @@ import torch
 import torch.nn as nn
 from torch.autograd import Variable
 from encoding import Aggregate
+from torch.autograd import gradcheck

 model = Aggregate()
 # declare dims and variables
 B, N, K, D = 1, 2, 3, 4
 # TODO cpu test
-A = Variable(torch.ones(B, N, K).cuda())
-R = Variable(torch.ones(B, N, K, D).cuda())
+A = Variable(torch.randn(B, N, K).cuda(), requires_grad=True)
+R = Variable(torch.randn(B, N, K, D).cuda(), requires_grad=True)
+
+# check Aggregate operation
+test = gradcheck(Aggregate(), (A, R), eps=1e-4, atol=1e-3)
+print('Gradcheck of Aggregate() returns ', test)

 E = model(A, R)
 print(E)
```
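A note on the tolerances: `gradcheck` builds a numerical Jacobian by finite differences, which is only reliable in double precision by default. Because the extension exposes only a single-precision CUDA path (`Encoding_Float_aggregate_*`), the loosened `eps=1e-4` and `atol=1e-3` are presumably chosen so the float32 comparison does not fail on rounding noise alone.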