Commit 56f9ec3c authored by James Reed, committed by Myle Ott

Use ATen built-in conv_tbc method (#66)

Remove custom ConvTBC code
parent 6e4d370a
fairseq/clib/temporal_convolution_tbc/temporal_convolution_tbc.cpp (deleted):
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */
#include <stdio.h>
#include <string.h>
#include <algorithm>  // for std::max / std::min used below
#include <stdexcept>

#include <ATen/ATen.h>

using at::Tensor;

extern THCState* state;
// Map a tensor type name (as produced by tensor.type()) to the ATen type.
at::Type& getDataType(const char* dtype) {
  if (strcmp(dtype, "torch.cuda.FloatTensor") == 0) {
    return at::getType(at::kCUDA, at::kFloat);
  } else if (strcmp(dtype, "torch.FloatTensor") == 0) {
    return at::getType(at::kCPU, at::kFloat);
  } else if (strcmp(dtype, "torch.cuda.DoubleTensor") == 0) {
    return at::getType(at::kCUDA, at::kDouble);
  } else if (strcmp(dtype, "torch.DoubleTensor") == 0) {
    return at::getType(at::kCPU, at::kDouble);
  } else {
    throw std::runtime_error(std::string("Unsupported data type: ") + dtype);
  }
}

// Wrap a raw TH/THC tensor handle in an ATen Tensor (retain = true).
inline at::Tensor t(at::Type& type, void* i) {
  return type.unsafeTensorFromTH(i, true);
}
// Forward pass: input/output are (time, batch, channel); weight is (kernel, in, out).
void TemporalConvolutionTBC_forward(
    const char* dtype,
    void* _input,
    void* _output,
    void* _weight,
    void* _bias) {
  auto& type = getDataType(dtype);
  Tensor input = t(type, _input);
  Tensor output = t(type, _output);
  Tensor weight = t(type, _weight);
  Tensor bias = t(type, _bias);

  auto input_size = input.sizes();
  auto output_size = output.sizes();
  auto ilen = input_size[0];
  auto batchSize = input_size[1];
  auto inputPlanes = input_size[2];
  auto outputPlanes = output_size[2];
  auto olen = output_size[0];
  auto kw = weight.sizes()[0];
  // Recover the implicit zero padding from the output length:
  // olen = ilen - kw + 1 + 2 * pad.
  int pad = (olen - ilen + kw - 1) / 2;

  // input * weights + bias -> output_features
  output.copy_(bias.expand(output.sizes()));
  for (int k = 0; k < kw; k++) {
    int iShift = std::max(0, k - pad);
    int oShift = std::max(0, pad - k);
    int t = std::min(ilen + pad - k, olen) - oShift;
    // Note: gemm assumes column-major matrices
    // input is l*m (row-major)
    // weight is m*r (row-major)
    // output is l*r (row-major)
    if (t > 0) {
      auto W = weight[k];
      auto I = input.narrow(0, iShift, t).view({t * batchSize, inputPlanes});
      auto O = output.narrow(0, oShift, t).view({t * batchSize, outputPlanes});
      O.addmm_(I, W);
    }
  }
}
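
The forward above is the trick that ATen's conv_tbc op implements: one gemm per kernel position k, multiplying a time-shifted (t * batch) x inputPlanes slice of the input by the inputPlanes x outputPlanes matrix weight[k] and accumulating into the matching output slice. A minimal PyTorch sketch of the same decomposition, checked against torch.conv_tbc, the functional form of the op this commit switches to (sizes are illustrative):

```python
import torch

T, B, C_in, C_out, kw, pad = 7, 2, 4, 5, 3, 1  # illustrative sizes
x = torch.randn(T, B, C_in)            # (time, batch, channel)
w = torch.randn(kw, C_in, C_out)       # (kernel, in, out)
b = torch.randn(C_out)

olen = T - kw + 1 + 2 * pad
out = b.expand(olen, B, C_out).clone()  # start from the broadcast bias
for k in range(kw):
    i_shift = max(0, k - pad)
    o_shift = max(0, pad - k)
    t = min(T + pad - k, olen) - o_shift
    if t > 0:
        I = x.narrow(0, i_shift, t).reshape(t * B, C_in)
        O = out.narrow(0, o_shift, t).view(t * B, C_out)
        O.add_(I @ w[k])               # one gemm per kernel position

ref = torch.conv_tbc(x, w, b, pad)     # ATen computes the same contraction
print(torch.allclose(out, ref, atol=1e-5))  # True
```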
// Backward pass: computes dInput, dWeight, and dBias from dOutput.
void TemporalConvolutionTBC_backward(
    const char* dtype,
    void* _dOutput,
    void* _dInput,
    void* _dWeight,
    void* _dBias,
    void* _input,
    void* _weight) {
  auto& type = getDataType(dtype);
  Tensor dOutput = t(type, _dOutput);
  Tensor dInput = t(type, _dInput);
  Tensor dWeight = t(type, _dWeight);
  Tensor dBias = t(type, _dBias);
  Tensor input = t(type, _input);
  Tensor weight = t(type, _weight);

  auto input_size = input.sizes();
  auto output_size = dOutput.sizes();
  auto ilen = input_size[0];
  auto batchSize = input_size[1];
  auto inputPlanes = input_size[2];
  auto outputPlanes = output_size[2];
  auto olen = output_size[0];
  auto kw = weight.sizes()[0];
  int pad = (olen - ilen + kw - 1) / 2;

  for (int k = 0; k < kw; k++) {
    int iShift = std::max(0, k - pad);
    int oShift = std::max(0, pad - k);
    int t = std::min(ilen + pad - k, olen) - oShift;
    // dOutput * T(weight) -> dInput
    if (t > 0) {
      auto dO = dOutput.narrow(0, oShift, t).view({t * batchSize, outputPlanes});
      auto dI = dInput.narrow(0, iShift, t).view({t * batchSize, inputPlanes});
      dI.addmm_(dO, weight[k].t());
    }
  }

  for (int k = 0; k < kw; k++) {
    int iShift = std::max(0, k - pad);
    int oShift = std::max(0, pad - k);
    int t = std::min(ilen + pad - k, olen) - oShift;
    // T(input) * dOutput -> dWeight
    if (t > 0) {
      auto dW = dWeight[k];
      auto dO = dOutput.narrow(0, oShift, t).view({t * batchSize, outputPlanes});
      auto I = input.narrow(0, iShift, t).view({t * batchSize, inputPlanes}).t();
      dW.addmm_(I, dO);
    }
  }

  // dBias: reduce dOutput over the time and batch dimensions.
  auto tmp = dOutput.sum(0, false);
  dBias.copy_(tmp.sum(0));
}
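
The backward mirrors the forward: dOutput times weight[k] transposed accumulates into dInput, the transposed input slice times dOutput accumulates into dWeight[k], and dBias is dOutput summed over time and batch. With the custom kernels gone, these derivatives come from ATen; one way to sanity-check them against finite differences (a sketch, assuming torch.conv_tbc supports double precision on CPU):

```python
import torch
from torch.autograd import gradcheck

# Small double-precision tensors so finite differences are tight.
x = torch.randn(6, 2, 3, dtype=torch.double, requires_grad=True)
w = torch.randn(3, 3, 4, dtype=torch.double, requires_grad=True)
b = torch.randn(4, dtype=torch.double, requires_grad=True)

# Checks dInput, dWeight, and dBias in one call.
print(gradcheck(lambda x, w, b: torch.conv_tbc(x, w, b, 1), (x, w, b)))  # True
```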
fairseq/clib/temporal_convolution_tbc/temporal_convolution_tbc.h (deleted):
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */
void TemporalConvolutionTBC_forward(
    const char* dtype,
    void* input,
    void* output,
    void* weight,
    void* bias);

void TemporalConvolutionTBC_backward(
    const char* dtype,
    void* _dOutput,
    void* _dInput,
    void* _dWeight,
    void* _dBias,
    void* _input,
    void* _weight);
fairseq/models/fconv.py:
@@ -91,12 +91,12 @@ class FConvEncoder(FairseqEncoder):
         self.projections = nn.ModuleList()
         self.convolutions = nn.ModuleList()
         for (out_channels, kernel_size) in convolutions:
-            pad = (kernel_size - 1) / 2
             self.projections.append(Linear(in_channels, out_channels)
                                     if in_channels != out_channels else None)
             self.convolutions.append(
-                ConvTBC(in_channels, out_channels * 2, kernel_size, padding=pad,
-                        dropout=dropout))
+                ConvTBC(in_channels, out_channels * 2, kernel_size,
+                        dropout=dropout)
+            )
             in_channels = out_channels
         self.fc2 = Linear(in_channels, embed_dim)
@@ -116,6 +116,9 @@ class FConvEncoder(FairseqEncoder):
         for proj, conv in zip(self.projections, self.convolutions):
             residual = x if proj is None else proj(x)
             x = F.dropout(x, p=self.dropout, training=self.training)
+            padding_l = (conv.kernel_size[0] - 1) // 2
+            padding_r = conv.kernel_size[0] // 2
+            x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r))
             x = conv(x)
             x = F.glu(x, dim=2)
             x = (x + residual) * math.sqrt(0.5)
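
Padding now happens outside the convolution, which allows it to be asymmetric: padding_l equals padding_r for odd kernels, while even kernels put the extra frame on the right, so the output length still equals the input length (e.g. kernel_size=4 gives padding_l=1, padding_r=2, and T + 1 + 2 - 4 + 1 = T). A small sketch of the F.pad call on a TBC tensor (sizes illustrative; F.pad takes pads for the last dimensions first):

```python
import torch
import torch.nn.functional as F

T, B, C = 10, 2, 8
x = torch.randn(T, B, C)  # (time, batch, channel)

kernel_size = 4                       # even kernel: padding must be asymmetric
padding_l = (kernel_size - 1) // 2    # 1
padding_r = kernel_size // 2          # 2

# Pad order: (channel_l, channel_r, batch_l, batch_r, time_l, time_r)
y = F.pad(x, (0, 0, 0, 0, padding_l, padding_r))
print(y.shape)  # torch.Size([13, 2, 8]); 13 - kernel_size + 1 == 10 == T
```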
@@ -211,12 +214,12 @@ class FConvDecoder(FairseqIncrementalDecoder):
         self.convolutions = nn.ModuleList()
         self.attention = nn.ModuleList()
         for i, (out_channels, kernel_size) in enumerate(convolutions):
-            pad = kernel_size - 1
             self.projections.append(Linear(in_channels, out_channels)
                                     if in_channels != out_channels else None)
             self.convolutions.append(
                 LinearizedConv1d(in_channels, out_channels * 2, kernel_size,
-                                 padding=pad, dropout=dropout))
+                                 padding=(kernel_size - 1), dropout=dropout)
+            )
             self.attention.append(AttentionLayer(out_channels, embed_dim)
                                   if attention[i] else None)
             in_channels = out_channels
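
The decoder keeps padding=(kernel_size - 1) inside the convolution. ConvTBC pads both ends by that amount, so each output frame would see kernel_size - 1 future frames; LinearizedConvolution now slices those off after the conv (see its forward further below), leaving a causal convolution. A sketch of why trimming the right end makes it causal (illustrative sizes, using the torch.conv_tbc functional form):

```python
import torch

T, B, C_in, C_out, kw = 8, 1, 4, 4, 3
x = torch.randn(T, B, C_in)
w = torch.randn(kw, C_in, C_out)
b = torch.zeros(C_out)

pad = kw - 1                          # pad both sides by k-1...
y = torch.conv_tbc(x, w, b, pad)      # length T + kw - 1
y = y[:-pad]                          # ...then drop the right side: causal, length T

# Output frame t now only depends on x[max(0, t-kw+1) : t+1].
x2 = x.clone(); x2[5:] = 0.           # zero out the "future"
y2 = torch.conv_tbc(x2, w, b, pad)[:-pad]
print(torch.allclose(y[:5], y2[:5]))  # True: frames 0..4 unaffected
```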
@@ -254,8 +257,6 @@ class FConvDecoder(FairseqIncrementalDecoder):
             x = F.dropout(x, p=self.dropout, training=self.training)
             x = conv(x, incremental_state)
-            if incremental_state is None:
-                x = conv.remove_future_timesteps(x)
             x = F.glu(x, dim=2)

             # attention
......
fairseq/modules/conv_tbc.py:
@@ -6,18 +6,10 @@
 # can be found in the PATENTS file in the same directory.

 import torch
-from torch.autograd import Function
 from torch.nn.modules.utils import _single

-from fairseq import utils
-try:
-    from fairseq import temporal_convolution_tbc
-except ImportError as e:
-    import sys
-    sys.stderr.write('ERROR: missing temporal_convolution_tbc, run `python setup.py install`\n')
-    raise e

 class ConvTBC(torch.nn.Module):
     """1D convolution over an input of shape (time x batch x channel)
@@ -25,23 +17,19 @@ class ConvTBC(torch.nn.Module):
     The implementation uses gemm to perform the convolution. This implementation
     is faster than cuDNN for small kernel sizes.
     """

-    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
-                 padding=0):
+    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
         super(ConvTBC, self).__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.kernel_size = _single(kernel_size)
-        self.stride = _single(stride)
         self.padding = _single(padding)
-        assert self.stride == (1,)

         self.weight = torch.nn.Parameter(torch.Tensor(
             self.kernel_size[0], in_channels, out_channels))
         self.bias = torch.nn.Parameter(torch.Tensor(out_channels))

     def forward(self, input):
-        return ConvTBCFunction.apply(
-            input.contiguous(), self.weight, self.bias, self.padding[0])
+        return input.contiguous().conv_tbc(self.weight, self.bias, self.padding[0])

     def __repr__(self):
         s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
@@ -50,57 +38,3 @@ class ConvTBC(torch.nn.Module):
             s += ', bias=False'
         s += ')'
         return s.format(name=self.__class__.__name__, **self.__dict__)
-
-class ConvTBCFunction(Function):
-    @staticmethod
-    def forward(ctx, input, weight, bias, pad):
-        input_size = input.size()
-        weight_size = weight.size()
-        kernel_size = weight_size[0]
-
-        output = input.new(
-            input_size[0] - kernel_size + 1 + int(pad * 2),
-            input_size[1],
-            weight_size[2])
-
-        ctx.input_size = input_size
-        ctx.weight_size = weight_size
-        ctx.save_for_backward(input, weight)
-
-        temporal_convolution_tbc.TemporalConvolutionTBC_forward(
-            input.type().encode('utf-8'),
-            input,
-            output,
-            weight,
-            bias)
-
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        input, weight = ctx.saved_tensors
-        grad_output = grad_output.data.contiguous()
-        grad_input = grad_output.new(ctx.input_size).zero_()
-        grad_weight = grad_output.new(ctx.weight_size).zero_()
-        grad_bias = grad_output.new(ctx.weight_size[2])
-
-        temporal_convolution_tbc.TemporalConvolutionTBC_backward(
-            input.type().encode('utf-8'),
-            grad_output,
-            grad_input,
-            grad_weight,
-            grad_bias,
-            input,
-            weight)
-
-        grad_input = utils.volatile_variable(grad_input)
-        grad_weight = utils.volatile_variable(grad_weight)
-        grad_bias = utils.volatile_variable(grad_bias)
-
-        return grad_input, grad_weight, grad_bias, None
-
-
-def conv_tbc(input, weight, bias=None, stride=1, padding=0):
-    return ConvTBCFunction.apply(
-        input.contiguous(), weight, bias, padding[0])
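
With the autograd Function and FFI bindings above deleted, ConvTBC.forward is a single ATen call (written as the Tensor method input.conv_tbc(...) in this commit; the same op is also exposed as torch.conv_tbc). The op computes the same cross-correlation as F.conv1d, just on a (time, batch, channel) input with a (kernel, in, out) weight; a quick equivalence sketch (illustrative sizes):

```python
import torch
import torch.nn.functional as F

T, B, C_in, C_out, kw, pad = 9, 2, 4, 6, 3, 1
x_tbc = torch.randn(T, B, C_in)
w_tbc = torch.randn(kw, C_in, C_out)
b = torch.randn(C_out)

y_tbc = torch.conv_tbc(x_tbc, w_tbc, b, pad)   # (T, B, C_out) for kw=3, pad=1

# Same operation with conv1d on a (batch, channel, time) layout:
x_bct = x_tbc.permute(1, 2, 0)                 # B x C_in x T
w_oik = w_tbc.permute(2, 1, 0)                 # C_out x C_in x kw
y_bct = F.conv1d(x_bct, w_oik, b, padding=pad)
print(torch.allclose(y_tbc, y_bct.permute(2, 0, 1), atol=1e-5))  # True
```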
fairseq/modules/linearized_convolution.py:
@@ -18,6 +18,7 @@ class LinearizedConvolution(ConvTBC):
     At training time, this module uses ConvTBC, which is an optimized version
     of Conv1d. At inference time, it optimizes incremental generation (i.e.,
     one time step at a time) by replacing the convolutions with linear layers.
+    Note that the input order changes from training to inference.
     """

     def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
@@ -27,14 +28,20 @@ class LinearizedConvolution(ConvTBC):
     def forward(self, input, incremental_state=None):
         """
-        Input: Time x Batch x Channel.
+        Input:
+            Time x Batch x Channel during training
+            Batch x Time x Channel during inference
         Args:
             incremental_state: Used to buffer signal; if not None, then input is
                 expected to contain a single frame. If the input order changes
                 between time steps, call reorder_incremental_state.
         """
         if incremental_state is None:
-            return super().forward(input)
+            output = super().forward(input)
+            if self.kernel_size[0] > 1 and self.padding[0] > 0:
+                # remove future timesteps added by padding
+                output = output[:-self.padding[0], :, :]
+            return output

         # reshape weight
         weight = self._get_linearized_weight()
@@ -57,12 +64,6 @@ class LinearizedConvolution(ConvTBC):
         output = F.linear(input.view(bsz, -1), weight, self.bias)
         return output.view(bsz, 1, -1)

-    def remove_future_timesteps(self, x):
-        """Remove future time steps created by padding."""
-        if self.kernel_size[0] > 1 and self.padding[0] > 0:
-            x = x[:-self.padding[0], :, :]
-        return x
-
     def reorder_incremental_state(self, incremental_state, new_order):
         input_buffer = self._get_input_buffer(incremental_state)
         if input_buffer is not None:
......
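
For reference, the linearization that the incremental path relies on: the module flattens the (kernel, in, out) ConvTBC weight into an (out, kernel * in) matrix, so each generation step is one F.linear over a buffer of the last kernel_size input frames instead of a convolution. A sketch of that equivalence (the permute/flatten order here is an assumption matching the TBC weight layout; torch.conv_tbc serves as the reference):

```python
import torch
import torch.nn.functional as F

B, C_in, C_out, kw = 2, 4, 6, 3
w = torch.randn(kw, C_in, C_out)       # ConvTBC weight layout
b = torch.randn(C_out)
frames = torch.randn(kw, B, C_in)      # the last kw input frames (oldest first)

# Convolutional view: causal conv evaluated at the newest timestep.
y_conv = torch.conv_tbc(frames, w, b, 0)[-1]                     # (B, C_out)

# Linearized view: one matrix product per step.
w_lin = w.permute(2, 0, 1).contiguous().view(C_out, kw * C_in)   # (out, kw*in)
x_lin = frames.transpose(0, 1).contiguous().view(B, kw * C_in)   # (batch, kw*in)
y_lin = F.linear(x_lin, w_lin, b)
print(torch.allclose(y_conv, y_lin, atol=1e-5))  # True
```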
setup.py:
@@ -5,12 +5,9 @@
 # This source code is licensed under the license found in the LICENSE file in
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
 #

 from setuptools import setup, find_packages, Extension
-from setuptools.command.build_py import build_py
 import sys
-from torch.utils.ffi import create_extension
-

 if sys.version_info < (3,):
@@ -25,6 +22,7 @@ with open('LICENSE') as f:
 with open('requirements.txt') as f:
     reqs = f.read()
+

 bleu = Extension(
     'fairseq.libbleu',
     sources=[
@@ -34,23 +32,6 @@ bleu = Extension(
     extra_compile_args=['-std=c++11'],
 )

-conv_tbc = create_extension(
-    'fairseq.temporal_convolution_tbc',
-    relative_to='fairseq',
-    headers=['fairseq/clib/temporal_convolution_tbc/temporal_convolution_tbc.h'],
-    sources=['fairseq/clib/temporal_convolution_tbc/temporal_convolution_tbc.cpp'],
-    define_macros=[('WITH_CUDA', None)],
-    with_cuda=True,
-    extra_compile_args=['-std=c++11'],
-    source_extension='.cpp',
-)
-
-
-class build_py_hook(build_py):
-    def run(self):
-        conv_tbc.build()
-        build_py.run(self)
-

 setup(
     name='fairseq',
@@ -62,13 +43,4 @@ setup(
     packages=find_packages(),
     ext_modules=[bleu],
     test_suite='tests',
-
-    # build and install PyTorch extensions
-    package_data={
-        'fairseq': ['temporal_convolution_tbc/*.so'],
-    },
-    include_package_data=True,
-    cmdclass={
-        'build_py': build_py_hook,
-    },
 )