OpenDAS / Fairseq · Commit 56f9ec3c

Use ATen built-in conv_tbc method (#66)

Remove custom ConvTBC code

Authored Mar 01, 2018 by James Reed; committed by Myle Ott, Mar 01, 2018
Parent: 6e4d370a
Showing 6 changed files with 20 additions and 269 deletions.
fairseq/clib/temporal_convolution_tbc/temporal_convolution_tbc.cpp   +0  -134
fairseq/clib/temporal_convolution_tbc/temporal_convolution_tbc.h     +0  -23
fairseq/models/fconv.py                                              +8  -7
fairseq/modules/conv_tbc.py                                          +2  -68
fairseq/modules/linearized_convolution.py                            +9  -8
setup.py                                                             +1  -29
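For orientation before the per-file diffs: the commit drops the hand-rolled temporal_convolution_tbc FFI extension and calls the conv_tbc operator that ships with ATen instead. A minimal sketch of that built-in, assuming a PyTorch build that exposes torch.conv_tbc and using the layout from the ConvTBC docstring further down (time x batch x channel input, kernel_width x in_channels x out_channels weight); the sizes are chosen only for illustration:

    import torch

    time, batch, in_ch, out_ch, kw = 20, 8, 256, 512, 3
    x = torch.randn(time, batch, in_ch)   # TBC layout: (time, batch, channel)
    w = torch.randn(kw, in_ch, out_ch)    # (kernel_width, in_channels, out_channels)
    b = torch.zeros(out_ch)

    y = torch.conv_tbc(x, w, b, 1)        # pad=1, so output length = 20 - 3 + 1 + 2*1 = 20
    print(y.shape)                        # torch.Size([20, 8, 512])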
fairseq/clib/temporal_convolution_tbc/temporal_convolution_tbc.cpp (deleted, 100644 → 0)
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <stdio.h>
#include <string.h>
#include <stdexcept>

#include <ATen/ATen.h>

using at::Tensor;

extern THCState* state;

at::Type& getDataType(const char* dtype) {
  if (strcmp(dtype, "torch.cuda.FloatTensor") == 0) {
    return at::getType(at::kCUDA, at::kFloat);
  } else if (strcmp(dtype, "torch.FloatTensor") == 0) {
    return at::getType(at::kCPU, at::kFloat);
  } else if (strcmp(dtype, "torch.cuda.DoubleTensor") == 0) {
    return at::getType(at::kCUDA, at::kDouble);
  } else if (strcmp(dtype, "torch.DoubleTensor") == 0) {
    return at::getType(at::kCPU, at::kDouble);
  } else {
    throw std::runtime_error(std::string("Unsupported data type: ") + dtype);
  }
}

inline at::Tensor t(at::Type& type, void* i) {
  return type.unsafeTensorFromTH(i, true);
}

void TemporalConvolutionTBC_forward(
    const char* dtype,
    void* _input,
    void* _output,
    void* _weight,
    void* _bias) {
  auto& type = getDataType(dtype);
  Tensor input = t(type, _input);
  Tensor output = t(type, _output);
  Tensor weight = t(type, _weight);
  Tensor bias = t(type, _bias);

  auto input_size = input.sizes();
  auto output_size = output.sizes();

  auto ilen = input_size[0];
  auto batchSize = input_size[1];
  auto inputPlanes = input_size[2];
  auto outputPlanes = output_size[2];
  auto olen = output_size[0];
  auto kw = weight.sizes()[0];
  int pad = (olen - ilen + kw - 1) / 2;

  // input * weights + bias -> output_features
  output.copy_(bias.expand(output.sizes()));
  for (int k = 0; k < kw; k++) {
    int iShift = std::max(0, k - pad);
    int oShift = std::max(0, pad - k);
    int t = std::min(ilen + pad - k, olen) - oShift;
    // Note: gemm assumes column-major matrices
    // input is l*m (row-major)
    // weight is m*r (row-major)
    // output is l*r (row-major)
    if (t > 0) {
      auto W = weight[k];
      auto I = input.narrow(0, iShift, t).view({t * batchSize, inputPlanes});
      auto O = output.narrow(0, oShift, t).view({t * batchSize, outputPlanes});
      O.addmm_(I, W);
    }
  }
}

void TemporalConvolutionTBC_backward(
    const char* dtype,
    void* _dOutput,
    void* _dInput,
    void* _dWeight,
    void* _dBias,
    void* _input,
    void* _weight) {
  auto& type = getDataType(dtype);
  Tensor dOutput = t(type, _dOutput);
  Tensor dInput = t(type, _dInput);
  Tensor dWeight = t(type, _dWeight);
  Tensor dBias = t(type, _dBias);
  Tensor input = t(type, _input);
  Tensor weight = t(type, _weight);

  auto input_size = input.sizes();
  auto output_size = dOutput.sizes();

  auto ilen = input_size[0];
  auto batchSize = input_size[1];
  auto inputPlanes = input_size[2];
  auto outputPlanes = output_size[2];
  auto olen = output_size[0];
  auto kw = weight.sizes()[0];
  int pad = (olen - ilen + kw - 1) / 2;

  for (int k = 0; k < kw; k++) {
    int iShift = std::max(0, k - pad);
    int oShift = std::max(0, pad - k);
    int t = std::min(ilen + pad - k, olen) - oShift;
    // dOutput * T(weight) -> dInput
    if (t > 0) {
      auto dO = dOutput.narrow(0, oShift, t).view({t * batchSize, outputPlanes});
      auto dI = dInput.narrow(0, iShift, t).view({t * batchSize, inputPlanes});
      dI.addmm_(dO, weight[k].t());
    }
  }

  for (int k = 0; k < kw; k++) {
    int iShift = std::max(0, k - pad);
    int oShift = std::max(0, pad - k);
    int t = std::min(ilen + pad - k, olen) - oShift;
    // T(input) * dOutput -> dWeight
    if (t > 0) {
      auto dW = dWeight[k];
      auto dO = dOutput.narrow(0, oShift, t).view({t * batchSize, outputPlanes});
      auto I = input.narrow(0, iShift, t).view({t * batchSize, inputPlanes}).t();
      dW.addmm_(I, dO);
    }
  }

  auto tmp = dOutput.sum(0, false);
  dBias.copy_(tmp.sum(0));
}
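The forward loop above computes the TBC convolution as one GEMM per kernel tap, accumulating into an output that is pre-filled with the bias. As a hedged reference, the same forward computation can be transcribed into a few lines of PyTorch; the helper name conv_tbc_reference is ours, not part of the commit, and it should agree with the ATen built-in up to floating-point error:

    import torch

    def conv_tbc_reference(input, weight, bias, pad=0):
        """input: (time, batch, in_ch); weight: (kernel_width, in_ch, out_ch); bias: (out_ch,)."""
        ilen, batch, in_ch = input.size()
        kw, _, out_ch = weight.size()
        olen = ilen - kw + 1 + 2 * pad                  # same padding rule as the C++ code
        output = bias.expand(olen, batch, out_ch).contiguous()
        for k in range(kw):
            i_shift = max(0, k - pad)                   # first input step touched by tap k
            o_shift = max(0, pad - k)                   # first output step touched by tap k
            t = min(ilen + pad - k, olen) - o_shift     # number of time steps for this tap
            if t > 0:
                I = input.narrow(0, i_shift, t).reshape(t * batch, in_ch)
                O = output.narrow(0, o_shift, t).view(t * batch, out_ch)  # writable view into output
                O.addmm_(I, weight[k])                  # one GEMM per kernel tap
        return output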
fairseq/clib/temporal_convolution_tbc/temporal_convolution_tbc.h (deleted, 100644 → 0)
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

void TemporalConvolutionTBC_forward(
    const char* dtype,
    void* input,
    void* output,
    void* weight,
    void* bias);

void TemporalConvolutionTBC_backward(
    const char* dtype,
    void* _dOutput,
    void* _dInput,
    void* _dWeight,
    void* _dBias,
    void* _input,
    void* _weight);
fairseq/models/fconv.py (modified)
@@ -91,12 +91,12 @@ class FConvEncoder(FairseqEncoder):
         self.projections = nn.ModuleList()
         self.convolutions = nn.ModuleList()
         for (out_channels, kernel_size) in convolutions:
-            pad = (kernel_size - 1) / 2
             self.projections.append(Linear(in_channels, out_channels)
                                     if in_channels != out_channels else None)
             self.convolutions.append(
-                ConvTBC(in_channels, out_channels * 2, kernel_size, padding=pad,
-                        dropout=dropout))
+                ConvTBC(in_channels, out_channels * 2, kernel_size,
+                        dropout=dropout)
+            )
             in_channels = out_channels
         self.fc2 = Linear(in_channels, embed_dim)
@@ -116,6 +116,9 @@ class FConvEncoder(FairseqEncoder):
         for proj, conv in zip(self.projections, self.convolutions):
             residual = x if proj is None else proj(x)
             x = F.dropout(x, p=self.dropout, training=self.training)
+            padding_l = (conv.kernel_size[0] - 1) // 2
+            padding_r = conv.kernel_size[0] // 2
+            x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r))
             x = conv(x)
             x = F.glu(x, dim=2)
             x = (x + residual) * math.sqrt(0.5)
@@ -211,12 +214,12 @@ class FConvDecoder(FairseqIncrementalDecoder):
         self.convolutions = nn.ModuleList()
         self.attention = nn.ModuleList()
         for i, (out_channels, kernel_size) in enumerate(convolutions):
-            pad = kernel_size - 1
             self.projections.append(Linear(in_channels, out_channels)
                                     if in_channels != out_channels else None)
             self.convolutions.append(
                 LinearizedConv1d(in_channels, out_channels * 2, kernel_size,
-                                 padding=pad, dropout=dropout))
+                                 padding=(kernel_size - 1), dropout=dropout)
+            )
             self.attention.append(AttentionLayer(out_channels, embed_dim)
                                   if attention[i] else None)
             in_channels = out_channels
@@ -254,8 +257,6 @@ class FConvDecoder(FairseqIncrementalDecoder):
             x = F.dropout(x, p=self.dropout, training=self.training)
             x = conv(x, incremental_state)
-            if incremental_state is None:
-                x = conv.remove_future_timesteps(x)
             x = F.glu(x, dim=2)

             # attention
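The encoder no longer relies on ConvTBC's internal padding; it pads the time axis explicitly with F.pad, splitting the padding asymmetrically so that even kernel widths still preserve the sequence length. A hedged sketch of just that step, with sizes chosen for illustration:

    import torch
    import torch.nn.functional as F

    kernel_size = 4                        # even kernel: asymmetric padding
    padding_l = (kernel_size - 1) // 2     # 1
    padding_r = kernel_size // 2           # 2

    x = torch.randn(15, 8, 256)            # (time, batch, channel)
    # F.pad lists pads from the last dim backwards: (chan_l, chan_r, batch_l, batch_r, time_l, time_r)
    x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r))
    print(x.shape)                         # (18, 8, 256); a width-4 kernel then yields 18 - 4 + 1 = 15 steps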
fairseq/modules/conv_tbc.py (modified)
@@ -6,18 +6,10 @@
 # can be found in the PATENTS file in the same directory.

 import torch
-from torch.autograd import Function
 from torch.nn.modules.utils import _single
-from fairseq import utils
-
-try:
-    from fairseq import temporal_convolution_tbc
-except ImportError as e:
-    import sys
-    sys.stderr.write('ERROR: missing temporal_convolution_tbc, run `python setup.py install`\n')
-    raise e


 class ConvTBC(torch.nn.Module):
     """1D convolution over an input of shape (time x batch x channel)
@@ -25,23 +17,19 @@ class ConvTBC(torch.nn.Module):
     The implementation uses gemm to perform the convolution. This implementation
     is faster than cuDNN for small kernel sizes.
     """
-    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
+    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
         super(ConvTBC, self).__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.kernel_size = _single(kernel_size)
-        self.stride = _single(stride)
         self.padding = _single(padding)

-        assert self.stride == (1,)
-
         self.weight = torch.nn.Parameter(torch.Tensor(
             self.kernel_size[0], in_channels, out_channels))
         self.bias = torch.nn.Parameter(torch.Tensor(out_channels))

     def forward(self, input):
-        return ConvTBCFunction.apply(
-            input.contiguous(), self.weight, self.bias, self.padding[0])
+        return input.contiguous().conv_tbc(self.weight, self.bias, self.padding[0])

     def __repr__(self):
         s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
@@ -50,57 +38,3 @@ class ConvTBC(torch.nn.Module):
             s += ', bias=False'
         s += ')'
         return s.format(name=self.__class__.__name__, **self.__dict__)
-
-
-class ConvTBCFunction(Function):
-    @staticmethod
-    def forward(ctx, input, weight, bias, pad):
-        input_size = input.size()
-        weight_size = weight.size()
-        kernel_size = weight_size[0]
-
-        output = input.new(
-            input_size[0] - kernel_size + 1 + int(pad * 2),
-            input_size[1],
-            weight_size[2])
-
-        ctx.input_size = input_size
-        ctx.weight_size = weight_size
-        ctx.save_for_backward(input, weight)
-
-        temporal_convolution_tbc.TemporalConvolutionTBC_forward(
-            input.type().encode('utf-8'),
-            input,
-            output,
-            weight,
-            bias)
-
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        input, weight = ctx.saved_tensors
-        grad_output = grad_output.data.contiguous()
-        grad_input = grad_output.new(ctx.input_size).zero_()
-        grad_weight = grad_output.new(ctx.weight_size).zero_()
-        grad_bias = grad_output.new(ctx.weight_size[2])
-
-        temporal_convolution_tbc.TemporalConvolutionTBC_backward(
-            input.type().encode('utf-8'),
-            grad_output,
-            grad_input,
-            grad_weight,
-            grad_bias,
-            input,
-            weight)
-
-        grad_input = utils.volatile_variable(grad_input)
-        grad_weight = utils.volatile_variable(grad_weight)
-        grad_bias = utils.volatile_variable(grad_bias)
-
-        return grad_input, grad_weight, grad_bias, None
-
-
-def conv_tbc(input, weight, bias=None, stride=1, padding=0):
-    return ConvTBCFunction.apply(input.contiguous(), weight, bias, padding[0])
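Deleting ConvTBCFunction leans on the built-in operator carrying its own derivative, so autograd now covers what the hand-written backward and the volatile_variable wrappers used to do. A hedged sanity check, written against current tensor-style autograd (the 2018 code still wrapped tensors in Variables):

    import torch

    x = torch.randn(12, 4, 8, requires_grad=True)     # (time, batch, in_channels)
    w = torch.randn(3, 8, 16, requires_grad=True)     # (kernel_width, in_channels, out_channels)
    b = torch.zeros(16, requires_grad=True)

    y = torch.conv_tbc(x, w, b, 1)                    # differentiable built-in
    y.sum().backward()
    print(x.grad.shape, w.grad.shape, b.grad.shape)   # gradients match x, w, and b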
fairseq/modules/linearized_convolution.py (modified)
@@ -18,6 +18,7 @@ class LinearizedConvolution(ConvTBC):
     At training time, this module uses ConvTBC, which is an optimized version
     of Conv1d. At inference time, it optimizes incremental generation (i.e.,
     one time step at a time) by replacing the convolutions with linear layers.
+    Note that the input order changes from training to inference.
     """

     def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
@@ -27,14 +28,20 @@ class LinearizedConvolution(ConvTBC):
     def forward(self, input, incremental_state=None):
         """
-        Input: Time x Batch x Channel.
+        Input:
+            Time x Batch x Channel during training
+            Batch x Time x Channel during inference

         Args:
             incremental_state: Used to buffer signal; if not None, then input is
                 expected to contain a single frame. If the input order changes
                 between time steps, call reorder_incremental_state.
         """
         if incremental_state is None:
-            return super().forward(input)
+            output = super().forward(input)
+            if self.kernel_size[0] > 1 and self.padding[0] > 0:
+                # remove future timesteps added by padding
+                output = output[:-self.padding[0], :, :]
+            return output

         # reshape weight
         weight = self._get_linearized_weight()
@@ -57,12 +64,6 @@ class LinearizedConvolution(ConvTBC):
         output = F.linear(input.view(bsz, -1), weight, self.bias)
         return output.view(bsz, 1, -1)

-    def remove_future_timesteps(self, x):
-        """Remove future time steps created by padding."""
-        if self.kernel_size[0] > 1 and self.padding[0] > 0:
-            x = x[:-self.padding[0], :, :]
-        return x
-
     def reorder_incremental_state(self, incremental_state, new_order):
         input_buffer = self._get_input_buffer(incremental_state)
         if input_buffer is not None:
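The cropping that replaces remove_future_timesteps exists because the decoder builds its convolutions with padding=(kernel_size - 1): ConvTBC pads both ends of the time axis, so the trailing padded steps see future input and have to be dropped in the non-incremental path. A hedged illustration with the built-in operator and made-up sizes:

    import torch

    kernel_size, padding = 3, 2             # decoder convs use padding = kernel_size - 1
    x = torch.randn(10, 4, 8)               # (time, batch, channel)
    w = torch.randn(kernel_size, 8, 16)
    b = torch.zeros(16)

    y = torch.conv_tbc(x, w, b, padding)    # output length 10 - 3 + 1 + 2*2 = 12
    y = y[:-padding, :, :]                  # drop the non-causal tail -> (10, 4, 16)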
setup.py (modified)
@@ -5,12 +5,9 @@
 # This source code is licensed under the license found in the LICENSE file in
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
 #

 from setuptools import setup, find_packages, Extension
-from setuptools.command.build_py import build_py
 import sys
-from torch.utils.ffi import create_extension


 if sys.version_info < (3,):
@@ -25,6 +22,7 @@ with open('LICENSE') as f:
 with open('requirements.txt') as f:
     reqs = f.read()

+
 bleu = Extension(
     'fairseq.libbleu',
     sources=[
@@ -34,23 +32,6 @@ bleu = Extension(
     extra_compile_args=['-std=c++11'],
 )

-conv_tbc = create_extension(
-    'fairseq.temporal_convolution_tbc',
-    relative_to='fairseq',
-    headers=['fairseq/clib/temporal_convolution_tbc/temporal_convolution_tbc.h'],
-    sources=['fairseq/clib/temporal_convolution_tbc/temporal_convolution_tbc.cpp'],
-    define_macros=[('WITH_CUDA', None)],
-    with_cuda=True,
-    extra_compile_args=['-std=c++11'],
-    source_extension='.cpp',
-)
-
-
-class build_py_hook(build_py):
-    def run(self):
-        conv_tbc.build()
-        build_py.run(self)
-

 setup(
     name='fairseq',
@@ -62,13 +43,4 @@ setup(
     packages=find_packages(),
     ext_modules=[bleu],
     test_suite='tests',
-
-    # build and install PyTorch extensions
-    package_data={
-        'fairseq': ['temporal_convolution_tbc/*.so'],
-    },
-    include_package_data=True,
-    cmdclass={
-        'build_py': build_py_hook,
-    },
 )
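With the FFI extension and the build_py hook gone, installing fairseq no longer compiles anything for conv_tbc; the only assumption is a PyTorch build that ships the ATen operator. A minimal, hedged environment check:

    import torch

    # If this raises, the installed PyTorch predates the ATen conv_tbc operator.
    assert hasattr(torch, 'conv_tbc'), 'this fairseq revision needs torch.conv_tbc'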