Commit 56f9ec3c authored by James Reed, committed by Myle Ott

Use ATen built-in conv_tbc method (#66)

Remove custom ConvTBC code
parent 6e4d370a
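The change in a nutshell: the hand-rolled FFI kernels below are deleted, and ConvTBC now calls the conv_tbc operator built into ATen. A minimal sketch of the new call, assuming only the TBC shape convention used throughout this diff (torch.conv_tbc is the functional spelling of the input.conv_tbc(...) method used below):

import torch

# TBC layout: input is (time, batch, channels); weight is (kernel, in, out).
x = torch.randn(7, 2, 4)
w = torch.randn(3, 4, 8)
b = torch.randn(8)

# One built-in op replaces the custom TemporalConvolutionTBC kernels;
# pad=1 with kernel=3 keeps the output length equal to the input length.
y = torch.conv_tbc(x.contiguous(), w, b, 1)
print(y.shape)  # torch.Size([7, 2, 8])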
deleted: fairseq/clib/temporal_convolution_tbc/temporal_convolution_tbc.cpp

/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <stdio.h>
#include <string.h>
#include <algorithm>  // std::max / std::min used below
#include <stdexcept>

#include <ATen/ATen.h>

using at::Tensor;

extern THCState* state;  // supplied by the TH FFI build wrapper
// Map a torch tensor type name ("torch.cuda.FloatTensor", ...) onto the
// corresponding ATen type.
at::Type& getDataType(const char* dtype) {
  if (strcmp(dtype, "torch.cuda.FloatTensor") == 0) {
    return at::getType(at::kCUDA, at::kFloat);
  } else if (strcmp(dtype, "torch.FloatTensor") == 0) {
    return at::getType(at::kCPU, at::kFloat);
  } else if (strcmp(dtype, "torch.cuda.DoubleTensor") == 0) {
    return at::getType(at::kCUDA, at::kDouble);
  } else if (strcmp(dtype, "torch.DoubleTensor") == 0) {
    return at::getType(at::kCPU, at::kDouble);
  } else {
    throw std::runtime_error(std::string("Unsupported data type: ") + dtype);
  }
}

// Wrap a raw TH tensor pointer in an ATen tensor, retaining the reference.
inline at::Tensor t(at::Type& type, void* i) {
  return type.unsafeTensorFromTH(i, true);
}
void TemporalConvolutionTBC_forward(
    const char* dtype,
    void* _input,
    void* _output,
    void* _weight,
    void* _bias) {
  auto& type = getDataType(dtype);
  Tensor input = t(type, _input);
  Tensor output = t(type, _output);
  Tensor weight = t(type, _weight);
  Tensor bias = t(type, _bias);

  auto input_size = input.sizes();
  auto output_size = output.sizes();

  auto ilen = input_size[0];
  auto batchSize = input_size[1];
  auto inputPlanes = input_size[2];
  auto outputPlanes = output_size[2];
  auto olen = output_size[0];
  auto kw = weight.sizes()[0];
  int pad = (olen - ilen + kw - 1) / 2;

  // input * weights + bias -> output_features
  output.copy_(bias.expand(output.sizes()));
  for (int k = 0; k < kw; k++) {
    int iShift = std::max(0, k - pad);
    int oShift = std::max(0, pad - k);
    int t = std::min(ilen + pad - k, olen) - oShift;
    // Note: gemm assumes column-major matrices
    // input is l*m (row-major)
    // weight is m*r (row-major)
    // output is l*r (row-major)
    if (t > 0) {
      auto W = weight[k];
      auto I = input.narrow(0, iShift, t).view({t * batchSize, inputPlanes});
      auto O = output.narrow(0, oShift, t).view({t * batchSize, outputPlanes});
      O.addmm_(I, W);
    }
  }
}
void TemporalConvolutionTBC_backward(
    const char* dtype,
    void* _dOutput,
    void* _dInput,
    void* _dWeight,
    void* _dBias,
    void* _input,
    void* _weight) {
  auto& type = getDataType(dtype);
  Tensor dOutput = t(type, _dOutput);
  Tensor dInput = t(type, _dInput);
  Tensor dWeight = t(type, _dWeight);
  Tensor dBias = t(type, _dBias);
  Tensor input = t(type, _input);
  Tensor weight = t(type, _weight);

  auto input_size = input.sizes();
  auto output_size = dOutput.sizes();

  auto ilen = input_size[0];
  auto batchSize = input_size[1];
  auto inputPlanes = input_size[2];
  auto outputPlanes = output_size[2];
  auto olen = output_size[0];
  auto kw = weight.sizes()[0];
  int pad = (olen - ilen + kw - 1) / 2;

  for (int k = 0; k < kw; k++) {
    int iShift = std::max(0, k - pad);
    int oShift = std::max(0, pad - k);
    int t = std::min(ilen + pad - k, olen) - oShift;
    // dOutput * T(weight) -> dInput
    if (t > 0) {
      auto dO = dOutput.narrow(0, oShift, t).view({t * batchSize, outputPlanes});
      auto dI = dInput.narrow(0, iShift, t).view({t * batchSize, inputPlanes});
      dI.addmm_(dO, weight[k].t());
    }
  }

  for (int k = 0; k < kw; k++) {
    int iShift = std::max(0, k - pad);
    int oShift = std::max(0, pad - k);
    int t = std::min(ilen + pad - k, olen) - oShift;
    // T(input) * dOutput -> dWeight
    if (t > 0) {
      auto dW = dWeight[k];
      auto dO = dOutput.narrow(0, oShift, t).view({t * batchSize, outputPlanes});
      auto I = input.narrow(0, iShift, t).view({t * batchSize, inputPlanes}).t();
      dW.addmm_(I, dO);
    }
  }

  // Bias gradient: sum dOutput over the time and batch dimensions.
  auto tmp = dOutput.sum(0, false);
  dBias.copy_(tmp.sum(0));
}
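For readers skimming the deleted kernel: the loop above computes the convolution as one gemm per kernel tap, accumulating into shifted slices of the output. A rough Python transcription of the forward pass (not part of this commit; pad is assumed to satisfy the same formula the C++ recovers from the output size):

import torch

def conv_tbc_forward_reference(input, weight, bias, pad):
    # input: (ilen, batch, in_planes); weight: (kw, in_planes, out_planes)
    ilen, batch_size, in_planes = input.size()
    kw, _, out_planes = weight.size()
    olen = ilen - kw + 1 + 2 * pad
    # Broadcast the bias into the output, as output.copy_(bias.expand(...)) does.
    output = bias.expand(olen, batch_size, out_planes).contiguous()
    for k in range(kw):
        i_shift = max(0, k - pad)
        o_shift = max(0, pad - k)
        t = min(ilen + pad - k, olen) - o_shift
        if t > 0:
            # Flatten (t, batch) so each tap is a single matrix multiply.
            I = input[i_shift:i_shift + t].contiguous().view(t * batch_size, in_planes)
            O = output[o_shift:o_shift + t].view(t * batch_size, out_planes)
            O.addmm_(I, weight[k])  # writes through the view into output
    return output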
deleted: fairseq/clib/temporal_convolution_tbc/temporal_convolution_tbc.h

/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

void TemporalConvolutionTBC_forward(
    const char* dtype,
    void* input,
    void* output,
    void* weight,
    void* bias);

void TemporalConvolutionTBC_backward(
    const char* dtype,
    void* _dOutput,
    void* _dInput,
    void* _dWeight,
    void* _dBias,
    void* _input,
    void* _weight);
fairseq/models/fconv.py

@@ -91,12 +91,12 @@ class FConvEncoder(FairseqEncoder):
         self.projections = nn.ModuleList()
         self.convolutions = nn.ModuleList()
         for (out_channels, kernel_size) in convolutions:
-            pad = (kernel_size - 1) / 2
             self.projections.append(Linear(in_channels, out_channels)
                                     if in_channels != out_channels else None)
             self.convolutions.append(
-                ConvTBC(in_channels, out_channels * 2, kernel_size, padding=pad,
-                        dropout=dropout))
+                ConvTBC(in_channels, out_channels * 2, kernel_size,
+                        dropout=dropout)
+            )
             in_channels = out_channels
         self.fc2 = Linear(in_channels, embed_dim)
@@ -116,6 +116,9 @@ class FConvEncoder(FairseqEncoder):
         for proj, conv in zip(self.projections, self.convolutions):
             residual = x if proj is None else proj(x)
             x = F.dropout(x, p=self.dropout, training=self.training)
+            padding_l = (conv.kernel_size[0] - 1) // 2
+            padding_r = conv.kernel_size[0] // 2
+            x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r))
             x = conv(x)
             x = F.glu(x, dim=2)
             x = (x + residual) * math.sqrt(0.5)
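Since the encoder's ConvTBC is now constructed without padding, the explicit F.pad above must reproduce it. F.pad orders its pad tuple from the last dimension backwards, so with TBC input the final (padding_l, padding_r) pair pads the time axis; splitting the two sides keeps the sequence length fixed even for even kernel widths, which a single symmetric padding value cannot do. A quick check:

import torch
import torch.nn.functional as F

x = torch.randn(10, 2, 4)                    # (time, batch, channels)
k = 4                                        # even kernel width
padding_l, padding_r = (k - 1) // 2, k // 2  # 1 left, 2 right
x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r))
print(x.shape)                               # torch.Size([13, 2, 4]); 13 - k + 1 == 10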
@@ -211,12 +214,12 @@ class FConvDecoder(FairseqIncrementalDecoder):
         self.convolutions = nn.ModuleList()
         self.attention = nn.ModuleList()
         for i, (out_channels, kernel_size) in enumerate(convolutions):
-            pad = kernel_size - 1
             self.projections.append(Linear(in_channels, out_channels)
                                     if in_channels != out_channels else None)
             self.convolutions.append(
                 LinearizedConv1d(in_channels, out_channels * 2, kernel_size,
-                                 padding=pad, dropout=dropout))
+                                 padding=(kernel_size - 1), dropout=dropout)
+            )
             self.attention.append(AttentionLayer(out_channels, embed_dim)
                                   if attention[i] else None)
             in_channels = out_channels
@@ -254,8 +257,6 @@ class FConvDecoder(FairseqIncrementalDecoder):
             x = F.dropout(x, p=self.dropout, training=self.training)
             x = conv(x, incremental_state)
-            if incremental_state is None:
-                x = conv.remove_future_timesteps(x)
             x = F.glu(x, dim=2)

             # attention
...
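The decoder keeps its convolutions causal a different way: conv_tbc pads kernel_size - 1 steps on both ends, and the kernel_size - 1 trailing "future" steps are now trimmed inside LinearizedConvolution.forward (see that diff below) instead of by the removed remove_future_timesteps call. The length arithmetic, with illustrative numbers:

# conv_tbc output length: olen = ilen - kw + 1 + 2 * pad, with pad = kw - 1
ilen, kw = 10, 3
olen = ilen - kw + 1 + 2 * (kw - 1)   # 12
assert olen - (kw - 1) == ilen        # trimming kw - 1 steps restores ilen,
                                      # and output step t sees only inputs <= t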
fairseq/modules/conv_tbc.py

@@ -6,18 +6,10 @@
 # can be found in the PATENTS file in the same directory.

 import torch
-from torch.autograd import Function
 from torch.nn.modules.utils import _single

 from fairseq import utils

-try:
-    from fairseq import temporal_convolution_tbc
-except ImportError as e:
-    import sys
-    sys.stderr.write('ERROR: missing temporal_convolution_tbc, run `python setup.py install`\n')
-    raise e
-

 class ConvTBC(torch.nn.Module):
     """1D convolution over an input of shape (time x batch x channel)
@@ -25,23 +17,19 @@ class ConvTBC(torch.nn.Module):
     The implementation uses gemm to perform the convolution. This implementation
     is faster than cuDNN for small kernel sizes.
     """

-    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
-                 padding=0):
+    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
         super(ConvTBC, self).__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.kernel_size = _single(kernel_size)
-        self.stride = _single(stride)
         self.padding = _single(padding)

-        assert self.stride == (1,)

         self.weight = torch.nn.Parameter(torch.Tensor(
             self.kernel_size[0], in_channels, out_channels))
         self.bias = torch.nn.Parameter(torch.Tensor(out_channels))

     def forward(self, input):
-        return ConvTBCFunction.apply(
-            input.contiguous(), self.weight, self.bias, self.padding[0])
+        return input.contiguous().conv_tbc(self.weight, self.bias, self.padding[0])

     def __repr__(self):
         s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
@@ -50,57 +38,3 @@ class ConvTBC(torch.nn.Module):
             s += ', bias=False'
         s += ')'
         return s.format(name=self.__class__.__name__, **self.__dict__)
-
-
-class ConvTBCFunction(Function):
-    @staticmethod
-    def forward(ctx, input, weight, bias, pad):
-        input_size = input.size()
-        weight_size = weight.size()
-        kernel_size = weight_size[0]
-
-        output = input.new(
-            input_size[0] - kernel_size + 1 + int(pad * 2),
-            input_size[1],
-            weight_size[2])
-
-        ctx.input_size = input_size
-        ctx.weight_size = weight_size
-        ctx.save_for_backward(input, weight)
-
-        temporal_convolution_tbc.TemporalConvolutionTBC_forward(
-            input.type().encode('utf-8'),
-            input,
-            output,
-            weight,
-            bias)
-
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        input, weight = ctx.saved_tensors
-        grad_output = grad_output.data.contiguous()
-        grad_input = grad_output.new(ctx.input_size).zero_()
-        grad_weight = grad_output.new(ctx.weight_size).zero_()
-        grad_bias = grad_output.new(ctx.weight_size[2])
-
-        temporal_convolution_tbc.TemporalConvolutionTBC_backward(
-            input.type().encode('utf-8'),
-            grad_output,
-            grad_input,
-            grad_weight,
-            grad_bias,
-            input,
-            weight)
-
-        grad_input = utils.volatile_variable(grad_input)
-        grad_weight = utils.volatile_variable(grad_weight)
-        grad_bias = utils.volatile_variable(grad_bias)
-
-        return grad_input, grad_weight, grad_bias, None
-
-
-def conv_tbc(input, weight, bias=None, stride=1, padding=0):
-    return ConvTBCFunction.apply(
-        input.contiguous(), weight, bias, padding[0])
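With ConvTBCFunction gone, autograd differentiates through the built-in op directly. A small usage sketch of the slimmed-down module; the import path and the manual weight init (the module leaves its parameters uninitialized) are assumptions, not part of this diff:

import torch
from fairseq.modules.conv_tbc import ConvTBC  # path assumed

conv = ConvTBC(in_channels=4, out_channels=8, kernel_size=3, padding=1)
torch.nn.init.normal_(conv.weight, std=0.1)   # params start uninitialized
torch.nn.init.zeros_(conv.bias)

x = torch.randn(7, 2, 4, requires_grad=True)  # (time, batch, channels)
y = conv(x)                                   # -> (7, 2, 8)
y.sum().backward()                            # gradients via ATen's conv_tbc
print(x.grad.shape)                           # torch.Size([7, 2, 4])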
fairseq/modules/linearized_convolution.py

@@ -18,6 +18,7 @@ class LinearizedConvolution(ConvTBC):
     At training time, this module uses ConvTBC, which is an optimized version
     of Conv1d. At inference time, it optimizes incremental generation (i.e.,
     one time step at a time) by replacing the convolutions with linear layers.
+    Note that the input order changes from training to inference.
     """

     def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
@@ -27,14 +28,20 @@ class LinearizedConvolution(ConvTBC):
     def forward(self, input, incremental_state=None):
         """
-        Input: Time x Batch x Channel.
+        Input:
+            Time x Batch x Channel during training
+            Batch x Time x Channel during inference

         Args:
             incremental_state: Used to buffer signal; if not None, then input is
                 expected to contain a single frame. If the input order changes
                 between time steps, call reorder_incremental_state.
         """
         if incremental_state is None:
-            return super().forward(input)
+            output = super().forward(input)
+            if self.kernel_size[0] > 1 and self.padding[0] > 0:
+                # remove future timesteps added by padding
+                output = output[:-self.padding[0], :, :]
+            return output

         # reshape weight
         weight = self._get_linearized_weight()
@@ -57,12 +64,6 @@ class LinearizedConvolution(ConvTBC):
         output = F.linear(input.view(bsz, -1), weight, self.bias)
         return output.view(bsz, 1, -1)

-    def remove_future_timesteps(self, x):
-        """Remove future time steps created by padding."""
-        if self.kernel_size[0] > 1 and self.padding[0] > 0:
-            x = x[:-self.padding[0], :, :]
-        return x
-
     def reorder_incremental_state(self, incremental_state, new_order):
         input_buffer = self._get_input_buffer(incremental_state)
         if input_buffer is not None:
...
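A hedged sketch of the two call patterns the updated docstring describes: full TBC sequences during training (with the future steps now trimmed inside forward), and single BTC frames during incremental inference. The import path and the plain-dict incremental_state are assumptions based on how fairseq threads this state through modules:

import torch
from fairseq.modules.linearized_convolution import LinearizedConvolution  # path assumed

conv = LinearizedConvolution(in_channels=4, out_channels=8, kernel_size=3, padding=2)
torch.nn.init.normal_(conv.weight, std=0.1)
torch.nn.init.zeros_(conv.bias)

# Training: (time, batch, channels); the padding[0] future steps are trimmed,
# so the output keeps the input length.
x = torch.randn(10, 2, 4)
print(conv(x).shape)                  # torch.Size([10, 2, 8])

# Inference: one (batch, 1, channels) frame per step; the state dict buffers
# the last kernel_size - 1 input frames between calls.
incremental_state = {}
for step in range(10):
    frame = torch.randn(2, 1, 4)
    y_t = conv(frame, incremental_state)  # -> (2, 1, 8)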
setup.py

@@ -5,12 +5,9 @@
 # This source code is licensed under the license found in the LICENSE file in
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
-#

 from setuptools import setup, find_packages, Extension
-from setuptools.command.build_py import build_py
 import sys
-from torch.utils.ffi import create_extension


 if sys.version_info < (3,):
@@ -25,6 +22,7 @@ with open('LICENSE') as f:
 with open('requirements.txt') as f:
     reqs = f.read()

+
 bleu = Extension(
     'fairseq.libbleu',
     sources=[
@@ -34,23 +32,6 @@ bleu = Extension(
     extra_compile_args=['-std=c++11'],
 )

-conv_tbc = create_extension(
-    'fairseq.temporal_convolution_tbc',
-    relative_to='fairseq',
-    headers=['fairseq/clib/temporal_convolution_tbc/temporal_convolution_tbc.h'],
-    sources=['fairseq/clib/temporal_convolution_tbc/temporal_convolution_tbc.cpp'],
-    define_macros=[('WITH_CUDA', None)],
-    with_cuda=True,
-    extra_compile_args=['-std=c++11'],
-    source_extension='.cpp',
-)
-
-
-class build_py_hook(build_py):
-    def run(self):
-        conv_tbc.build()
-        build_py.run(self)
-

 setup(
     name='fairseq',
@@ -62,13 +43,4 @@ setup(
     packages=find_packages(),
     ext_modules=[bleu],
     test_suite='tests',
-    # build and install PyTorch extensions
-    package_data={
-        'fairseq': ['temporal_convolution_tbc/*.so'],
-    },
-    include_package_data=True,
-    cmdclass={
-        'build_py': build_py_hook,
-    },
 )