Unverified Commit d86d1b09 authored by Nan Zheng, committed by GitHub

Initial check-in of the transducer extensions (#1069)

* Initial check-in of the transducer extension.

* Added more comments to help explain the code

* Corrected minor typos

* 1. Renamed variable in tests to match the extension
2. Disabled ninja build option
parent e2083df5
#include <torch/extension.h>
#include <ATen/Functions.h>
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
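// Host-side bindings for the transducer joint extension: validate that the inputs are
// contiguous CUDA tensors, then dispatch to the CUDA implementations declared below.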
torch::Tensor transducer_joint_cuda_forward(
torch::Tensor f,
torch::Tensor g,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int64_t packedBatch,
int opt,
bool packOutput,
int tileSize);
std::vector<torch::Tensor> transducer_joint_cuda_backward(
torch::Tensor grad,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int maxFLen,
int maxGLen,
bool packOutput);
torch::Tensor transducer_joint_forward(
torch::Tensor f,
torch::Tensor g,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int64_t packedBatch,
int opt,
bool packOutput,
int tileSize) {
CHECK_INPUT(f);
CHECK_INPUT(g);
CHECK_INPUT(fLen);
CHECK_INPUT(gLen);
if (packOutput)
CHECK_INPUT(batchOffset);
return transducer_joint_cuda_forward(
f,
g,
fLen,
gLen,
batchOffset,
packedBatch,
opt,
packOutput,
tileSize);
}
std::vector<torch::Tensor> transducer_joint_backward(
torch::Tensor grad,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int maxFLen,
int maxGLen,
bool packOutput) {
CHECK_INPUT(grad);
CHECK_INPUT(fLen);
CHECK_INPUT(gLen);
if (packOutput)
CHECK_INPUT(batchOffset);
return transducer_joint_cuda_backward(
grad,
fLen,
gLen,
batchOffset,
maxFLen,
maxGLen,
packOutput);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &transducer_joint_forward, "transducer joint forward (CUDA)");
m.def("backward", &transducer_joint_backward, "transducer joint backward (CUDA)");
}
\ No newline at end of file
#include <torch/extension.h>
#include <vector>
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
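// Host-side bindings for the transducer loss extension: validate the inputs and dispatch
// to the CUDA implementations declared below.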
std::vector<torch::Tensor> transducer_loss_cuda_forward(
torch::Tensor x,
torch::Tensor label,
torch::Tensor audLen,
torch::Tensor txtLen,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool packedInput);
torch::Tensor transducer_loss_cuda_backward(
torch::Tensor x,
torch::Tensor lossGrad,
torch::Tensor alpha,
torch::Tensor beta,
torch::Tensor audLen,
torch::Tensor txtLen,
torch::Tensor label,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool fuseSoftmaxBackward,
bool packedInput);
std::vector<torch::Tensor> transducer_loss_forward(
torch::Tensor x,
torch::Tensor label,
torch::Tensor fLen,
torch::Tensor yLen,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool packedInput
) {
CHECK_INPUT(x);
CHECK_INPUT(label);
CHECK_INPUT(fLen);
CHECK_INPUT(yLen);
if (packedInput)
CHECK_INPUT(batchOffset);
return transducer_loss_cuda_forward(
x,
label,
fLen,
yLen,
batchOffset,
maxFLen,
blankIdx,
opt,
packedInput);
}
torch::Tensor transducer_loss_backward(
torch::Tensor x,
torch::Tensor lossGrad,
torch::Tensor alpha,
torch::Tensor beta,
torch::Tensor fLen,
torch::Tensor yLen,
torch::Tensor label,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool fuseSoftmaxBackward,
bool packedInput){
CHECK_INPUT(x);
CHECK_INPUT(label);
CHECK_INPUT(lossGrad);
CHECK_INPUT(alpha);
CHECK_INPUT(beta);
CHECK_INPUT(fLen);
CHECK_INPUT(yLen);
if (packedInput)
CHECK_INPUT(batchOffset);
return transducer_loss_cuda_backward(
x,
lossGrad,
alpha,
beta,
fLen,
yLen,
label,
batchOffset,
maxFLen,
blankIdx,
opt,
fuseSoftmaxBackward,
packedInput);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &transducer_loss_forward, "transducer loss forward (CUDA)");
m.def("backward", &transducer_loss_backward, "transducer loss backward (CUDA)");
}
import torch
import unittest
from apex.contrib.transducer import TransducerJoint
import transducer_ref
class TransducerJointTest(unittest.TestCase):
def setUp(self, seed=1234):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def gen_input(self, for_vector_kernel):
self.B = 4
T_min = 51
T_max = 101
U_min = 12
U_max = 25
if for_vector_kernel:
H = 512
else:
H = 509
dtype = torch.float16
device = "cuda"
self.f_tst = torch.randn((self.B, T_max, H), dtype=dtype, requires_grad=True, device=device)
self.g_tst = torch.randn((self.B, U_max, H), dtype=dtype, requires_grad=True, device=device)
self.h_grad = torch.randn(self.B, T_max, U_max, H, dtype=dtype, device=device)
self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device)
self.g_len = torch.randint(U_min, U_max+1, (self.B,), dtype=torch.int, device=device)
self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max
self.g_len[torch.randint(0, self.B, (1,)).item()] = U_max
# Make sure gradients from out-of-bound locations are zero. This should be guaranteed by
# the loss function
for b in range(self.B):
self.h_grad[b, self.f_len[b]:, :, :] = 0
self.h_grad[b, :, self.g_len[b]:, :] = 0
self.h_grad_packed = self._pack(self.h_grad, self.f_len, self.g_len)
def _pack(self, x, f_len, g_len):
B = x.size(0)
list_x = []
for b in range(B):
list_x_row = [x[b, t, :g_len[b]] for t in range(f_len[b])]
x_row = torch.cat(list_x_row)
list_x.append(x_row)
x_packed = torch.cat(list_x).data.clone()
x_packed.requires_grad = True
batch_offset = torch.cumsum(f_len * g_len, dim=0)
return x_packed
def run_transducer_joint(self, for_vector_kernel, pack_output):
self.gen_input(for_vector_kernel=for_vector_kernel)
# Generate reference
f_ref = self.f_tst.data.clone()
g_ref = self.g_tst.data.clone()
f_ref.requires_grad = True
g_ref.requires_grad = True
h_ref, f_grad_ref, g_grad_ref \
= transducer_ref.transducer_joint_reference(f=f_ref,
g=g_ref,
h_grad=self.h_grad,
f_len=self.f_len,
g_len=self.g_len,
pack_output=pack_output)
my_joint = TransducerJoint(pack_output=pack_output)
if not pack_output:
h_tst = my_joint( f=self.f_tst,
g=self.g_tst,
f_len=self.f_len,
g_len=self.g_len)
h_tst.backward(self.h_grad)
else:
batch_offset = torch.cumsum(self.f_len * self.g_len, dim=0)
h_tst = my_joint( f=self.f_tst,
g=self.g_tst,
f_len=self.f_len,
g_len=self.g_len,
batch_offset=batch_offset,
packed_batch=batch_offset[-1])
h_tst.backward(self.h_grad_packed)
f_grad_tst = self.f_tst.grad
g_grad_tst = self.g_tst.grad
self.assertTrue(torch.allclose(h_ref, h_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(f_grad_ref, f_grad_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(g_grad_ref, g_grad_tst, atol=1e-4, rtol=1e-4))
def test_transducer_joint(self):
self.run_transducer_joint(for_vector_kernel=False, pack_output=False)
def test_transducer_joint_vec(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=False)
def test_transducer_joint_pack(self):
self.run_transducer_joint(for_vector_kernel=False, pack_output=True)
def test_transducer_joint_vec_pack(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True)
if __name__ == '__main__':
unittest.main()
\ No newline at end of file
import torch
import unittest
from apex.contrib.transducer import TransducerLoss
import transducer_ref
class TransducerLossTest(unittest.TestCase):
def setUp(self, seed=1234):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def gen_input(self, scalar_t):
self.B = 5
T_min = 23
T_max = 51
U_min = 12
U_max = 25
V = 16
self.blank_idx = V - 1
device = "cuda"
self.x_tst = torch.randn((self.B, T_max, U_max, V), dtype=scalar_t, requires_grad=True,
device=device)
self.y = torch.randint(0, self.blank_idx, (self.B, U_max-1), dtype=torch.int, device=device)
self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device)
self.y_len = torch.randint(U_min-1, U_max, (self.B,), dtype=torch.int, device=device)
self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max
self.y_len[torch.randint(0, self.B, (1,)).item()] = U_max-1
self.x_tst_packed, self.batch_offset = self._pack(self.x_tst)
# Generate reference
x_ref = self.x_tst.data.clone()
x_ref.requires_grad = True
loss_grad = torch.ones(x_ref.size(0), dtype=x_ref.dtype, device=x_ref.device)/x_ref.size(0)
_, _, self.grad_ref, self.loss_ref \
= transducer_ref.transducer_loss_reference( x=x_ref,
label=self.y,
f_len=self.f_len,
y_len=self.y_len,
blank_idx=self.blank_idx,
loss_grad=loss_grad)
def _pack(self, x):
list_x = []
for b in range(self.B):
list_x_row = [x[b, t, : self.y_len[b]+1] for t in range(self.f_len[b])]
x_row = torch.cat(list_x_row)
list_x.append(x_row)
x_packed = torch.cat(list_x).data.clone()
x_packed.requires_grad = True
batch_offset = torch.cumsum(self.f_len * (self.y_len+1), dim=0)
return x_packed, batch_offset
def _unpack(self, x):
x_unpacked = torch.zeros(self.B, self.f_len.max(), self.y_len.max()+1, x.size(-1),
dtype=x.dtype, device=x.device)
for b in range(self.B):
my_batch_offset = 0 if b == 0 else self.batch_offset[b-1]
my_f_len = self.f_len[b]
my_g_len = self.y_len[b] + 1
for t in range(my_f_len):
for u in range(my_g_len):
x_unpacked[b, t, u] = x[my_batch_offset + t*my_g_len + u]
return x_unpacked
def run_transducer_loss(self, scalar_t, fuse_softmax_backward, packed_input):
self.gen_input(scalar_t)
my_loss = TransducerLoss( fuse_softmax_backward=fuse_softmax_backward,
packed_input=packed_input)
if not packed_input:
loss_tst = my_loss( x=self.x_tst,
label=self.y,
f_len=self.f_len,
y_len=self.y_len,
blank_idx=self.blank_idx)
loss_tst.mean().backward()
grad_tst = self.x_tst.grad
else:
loss_tst = my_loss( x=self.x_tst_packed,
label=self.y,
f_len=self.f_len,
y_len=self.y_len,
blank_idx=self.blank_idx,
batch_offset=self.batch_offset,
max_f_len=max(self.f_len))
loss_tst.mean().backward()
grad_tst_packed = self.x_tst_packed.grad
grad_tst = self._unpack(grad_tst_packed)
return loss_tst, grad_tst
def test_transducer_loss_fp32(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float32,
fuse_softmax_backward=False,
packed_input=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-5, rtol=1e-5))
def test_transducer_loss_fp16(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=False,
packed_input=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
def test_transducer_loss_fp16_backward_fusion(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=True,
packed_input=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
def test_transducer_loss_fp16_backward_fusion_packed(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=True,
packed_input=True)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
if __name__ == '__main__':
unittest.main()
\ No newline at end of file
import torch
def transducer_loss_reference(x, label, f_len, y_len, blank_idx, loss_grad):
def log_sum_exp(a, b):
if (a >= b):
return a + torch.log(1 + torch.exp(b-a))
else:
return b + torch.log(1 + torch.exp(a-b))
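# forward_alpha computes the forward variable of the transducer recursion:
# alpha[b, t, u] accumulates the log-probability of all alignment paths that reach
# time step t having emitted the first u labels, via
#   alpha[t, u] = logsumexp(alpha[t-1, u] + x[t-1, u, blank],
#                           alpha[t, u-1] + x[t, u-1, label[u-1]])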
def forward_alpha(x, label, f_len, y_len, blank_idx):
B, T, U, V = x.size()
acc_t = torch.float32 if x.dtype in [torch.float16, torch.float32] else x.dtype
alpha = torch.zeros((B, T, U), dtype=acc_t, device=x.device)
for b in range(B):
alpha[b, 0, 0] = 0
for t in range(1, f_len[b]):
alpha[b, t, 0] = alpha[b, t-1, 0] + x[b, t-1, 0, blank_idx]
for u in range(1, y_len[b]+1):
alpha[b, 0, u] = alpha[b, 0, u-1] + x[b, 0, u-1, label[b, u-1]]
for t in range(1, f_len[b]):
for u in range(1, y_len[b]+1):
curr_ = alpha[b, t-1, u] + x[b, t-1, u, blank_idx]
next_ = alpha[b, t, u-1] + x[b, t, u-1, label[b, u-1]]
alpha[b, t, u] = log_sum_exp(curr_, next_)
return alpha
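# forward_beta computes the backward variable: beta[b, t, u] is the log-probability of
# completing the alignment from (t, u) to the final blank emission at (f_len-1, y_len),
# mirroring the alpha recursion in the reverse direction.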
def forward_beta(x, label, f_len, y_len, blank_idx):
B, T, U, V = x.shape
acc_t = torch.float32 if x.dtype in [torch.float16, torch.float32] else x.dtype
beta = torch.zeros((B, T, U), dtype=acc_t, device=x.device)
for b in range(B):
beta[b, f_len[b]-1, y_len[b]] = x[b, f_len[b]-1, y_len[b], blank_idx]
for t in range(f_len[b]-2, -1, -1):
beta[b, t, y_len[b]] = beta[b, t+1, y_len[b]] + x[b, t, y_len[b], blank_idx]
for u in range(y_len[b]-1, -1, -1):
beta[b, f_len[b]-1, u] = beta[b, f_len[b]-1, u+1] + x[b, f_len[b]-1, u, label[b, u]]
for t in range(f_len[b]-2, -1, -1):
for u in range(y_len[b]-1, -1, -1):
curr_ = beta[b, t+1, u] + x[b, t, u, blank_idx]
next_ = beta[b, t, u+1] + x[b, t, u, label[b, u]]
beta[b, t, u] = log_sum_exp(curr_, next_)
return beta
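# backward computes d(loss)/d(log_softmax(x)):
#   grad[b, t, u, k] = -loss_grad[b] * exp(alpha[t, u] + beta[next] + x[t, u, k] - beta[0, 0])
# where "next" is (t, u+1) for the emitted label and (t+1, u) for blank, and beta[0, 0]
# is the total log-likelihood of the utterance.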
def backward(x, label, f_len, y_len, alpha, beta, loss_grad, blank_idx):
grad = torch.zeros_like(x)
B, T, U, V = x.size()
for b in range(B):
common_factor = torch.log(loss_grad[b]) + alpha - beta[b, 0, 0]
# next
for u in range(y_len[b]):
grad[b, :f_len[b], u, label[b, u]] = -torch.exp(common_factor[b, :f_len[b], u]
+ beta[b, :f_len[b], u+1]
+ x[b, :f_len[b], u, label[b, u]])
# current
grad[b, :f_len[b]-1, :y_len[b]+1, blank_idx] \
= -torch.exp(common_factor[b, :f_len[b]-1, :y_len[b]+1]
+ beta[b, 1:f_len[b], :y_len[b]+1]
+ x[b, :f_len[b]-1, :y_len[b]+1, blank_idx])
grad[b, f_len[b]-1, y_len[b], blank_idx] = -torch.exp(common_factor[b, f_len[b]-1, y_len[b]]
+ x[b, f_len[b]-1, y_len[b], blank_idx])
return grad
x_log = torch.nn.functional.log_softmax(x, dim=-1)
alpha = forward_alpha(x_log, label, f_len, y_len, blank_idx)
beta = forward_beta(x_log, label, f_len, y_len, blank_idx)
grad = backward(x_log, label, f_len, y_len, alpha, beta,
loss_grad, blank_idx)
x_log.backward(grad)
loss = -beta[:, 0, 0]
loss = loss.to(x.dtype)
return alpha, beta, x.grad, loss
def transducer_joint_reference(f, g, h_grad, f_len, g_len, pack_output):
B, T, H = f.size()
U = g.size(1)
f_expand = f.unsqueeze(dim=2)
g_expand = g.unsqueeze(dim=1)
h = f_expand + g_expand
h.backward(h_grad)
if not pack_output:
# Intentionally set the don't-care regions to -1 to test whether the transducer
# joint writes to these regions in order to avoid NaN and inf.
for b in range(B):
h[b, f_len[b]:] = -1
h[b, :, g_len[b]:] = -1
return h, f.grad, g.grad
# packing
list_to_pack = []
for b in range(B):
list_to_pack.append(h[b, :f_len[b], :g_len[b], :].reshape(-1, H))
h_packed = torch.cat(list_to_pack)
return h_packed, f.grad, g.grad
from .transducer import TransducerJoint
from .transducer import TransducerLoss
\ No newline at end of file
import torch
import transducer_loss_cuda
import transducer_joint_cuda
class TransducerJoint(torch.nn.Module):
"""Transducer joint
Detail of this operation can be found in: Sequence Transduction with Recurrent Neural
Networks (Graves 2012, https://arxiv.org/abs/1211.3711).
Arguments:
pack_output (bool, optional): whether to pack the output in a compact form with don't-care
data being removed. (default: False)
opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a tiled algorithm.
(default: 1)
fwd_tile_size (int, optional): tile size used in forward operation. This argument will be
ignored if opt != 1. (default: 4)
"""
def __init__(self, pack_output=False, opt=1, fwd_tile_size=4):
super(TransducerJoint, self).__init__()
self.pack_output = pack_output
self.opt = opt
self.fwd_tile_size = fwd_tile_size
self.dummy_batch_offset = torch.empty(0)
def forward(self, f, g, f_len, g_len, batch_offset=None, packed_batch=0):
"""Forward operation of transducer joint
Arguments:
f (tensor): transcription vector from the encode block, of shape (B, T, H).
g (tensor): prediction vector from the predict block, of shape (B, U, H).
f_len (tensor): length of transcription vector for each batch.
g_len (tensor): length of prediction vector minus 1 for each batch.
batch_offset (tensor, optional): tensor containing the offset of each batch
in the results. For example, batch offset can be obtained from:
batch_offset = torch.cumsum(f_len*g_len, dim=0)
This argument is required if pack_output == True, and is ignored if
pack_output == False. (default: None)
packed_batch (int, optional): the batch size after packing. This argument is
ignored if pack_output == False. (default: 0)
"""
my_batch_offset = batch_offset if self.pack_output else self.dummy_batch_offset
if self.pack_output and (batch_offset is None or packed_batch == 0):
raise Exception("Please specify batch_offset and packed_batch when packing is enabled")
return TransducerJointFunc.apply(f, g, f_len, g_len, self.pack_output, my_batch_offset,
packed_batch, self.opt, self.fwd_tile_size)
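# Illustrative usage (a minimal sketch, not part of the extension itself): the joint
# broadcasts f of shape (B, T, H) against g of shape (B, U, H) into an output of shape
# (B, T, U, H). With pack_output=True, the don't-care positions are dropped and
# batch_offset / packed_batch describe the packed layout, e.g.
#   joint = TransducerJoint(pack_output=True)
#   batch_offset = torch.cumsum(f_len * g_len, dim=0)
#   h = joint(f, g, f_len, g_len, batch_offset=batch_offset, packed_batch=batch_offset[-1])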
class TransducerLoss(torch.nn.Module):
"""Transducer loss
Detail of this loss function can be found in: Sequence Transduction with Recurrent Neural
Networks (Graves 2012, https://arxiv.org/abs/1211.3711).
Arguments:
fuse_softmax_backward (bool, optional): whether to fuse the backward of the transducer loss
with softmax. (default: True)
opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a more optimized
algorithm. In some cases, opt=1 might fall back to opt=0. (default: 1)
packed_input (bool, optional): whether the input is provided in a packed, compact form with
the don't-care data removed. (default: False)
"""
def __init__(self, fuse_softmax_backward=True, opt=1, packed_input=False):
super(TransducerLoss, self).__init__()
self.fuse_softmax_backward = fuse_softmax_backward
self.opt = opt
self.packed_input = packed_input
self.dummy_batch_offset = torch.empty(0)
def forward(self, x, label, f_len, y_len, blank_idx, batch_offset=None, max_f_len=None,
debug_list=None):
"""Forward operation of transducer joint
Arguments:
x (tensor): input tensor to the loss function with a shape of (B, T, U, H).
label (tensor): labels for the input data.
f_len (tensor): lengths of the inputs in the time dimension for each batch.
y_len (tensor): lengths of the labels for each batch.
blank_idx (int): index for the null symbol.
batch_offset (tensor, optional): tensor containing the offset of each batch
in the input. For example, batch offset can be obtained from:
batch_offset = torch.cumsum(f_len*(y_len+1), dim=0)
This argument is required if packed_input == True, and is ignored if
packed_input == False. (default: None)
max_f_len (int, optional): maximum length of the input in the time dimension.
For example, it can be obtained as
max_f_len = max(f_len)
This argument is required if packed_input == True, and is ignored if
packed_input == False. (default: None)
debug_list (list, optional): when an empty list is supplied, the Alpha and Beta tensors
generated in the forward operation will be appended to this list for debugging purposes.
(default: None)
"""
if self.packed_input:
if batch_offset is None or max_f_len is None:
raise Exception("Please specify batch_offset and max_f_len when packing is \
enabled")
my_batch_offset = batch_offset
my_max_f_len = max_f_len
else:
my_batch_offset = self.dummy_batch_offset
my_max_f_len = x.size(1)
return TransducerLossFunc.apply(x, label, f_len, y_len, my_batch_offset, my_max_f_len,
blank_idx, self.fuse_softmax_backward, debug_list,
self.opt, self.packed_input)
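# Illustrative usage (a minimal sketch): x is typically the joint output of shape
# (B, T, U, V), with V the vocabulary size (H in the docstring above); the loss returns
# one value per batch element, e.g.
#   loss_fn = TransducerLoss()
#   loss = loss_fn(x, label, f_len, y_len, blank_idx=V - 1)   # blank_idx as in the unit tests
#   loss.mean().backward()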
class TransducerLossFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, x, label, f_len, y_len, batch_offset, max_f_len, blank_idx,
fuse_softmax_backward, debug_list, opt, packed_input):
if fuse_softmax_backward == False:
# Keep the log_softmax graph so that backward() can propagate the gradient through
# the softmax explicitly (see TransducerLossFunc.backward).
with torch.enable_grad():
x = torch.nn.functional.log_softmax(x, dim=-1)
else:
# The softmax backward is fused into the loss backward kernel.
x = torch.nn.functional.log_softmax(x, dim=-1)
alpha, beta, loss = transducer_loss_cuda.forward( x, label, f_len, y_len, batch_offset,
max_f_len, blank_idx, opt, packed_input)
if debug_list == []:
debug_list += [alpha, beta]
ctx.save_for_backward(x, alpha, beta, f_len, y_len, label, batch_offset)
ctx.blank_idx = blank_idx
ctx.fuse_softmax_backward = fuse_softmax_backward
ctx.opt = opt
ctx.packed_input = packed_input
ctx.max_f_len = max_f_len
return loss
@staticmethod
def backward(ctx, loss_grad):
x, alpha, beta, f_len, y_len, label, batch_offset = ctx.saved_tensors
x_grad = transducer_loss_cuda.backward( x, loss_grad, alpha, beta, f_len, y_len, label,
batch_offset, ctx.max_f_len, ctx.blank_idx, ctx.opt,
ctx.fuse_softmax_backward, ctx.packed_input)
if ctx.fuse_softmax_backward == False:
# Here x is the log_softmax output built under enable_grad in forward(); backpropagating
# through it reaches the original input (note that Tensor.backward returns None).
x_grad = x.backward(x_grad)
return x_grad, None, None, None, None, None, None, None, None, None, None
class TransducerJointFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, f, g, f_len, g_len, pack_output, batch_offset, packed_batch, opt,
fwd_tile_size):
h = transducer_joint_cuda.forward(f, g, f_len, g_len, batch_offset, packed_batch, opt,
pack_output, fwd_tile_size)
ctx.save_for_backward(f_len, g_len, batch_offset)
ctx.pack_output = pack_output
ctx.max_f_len = f.size(1)
ctx.max_g_len = g.size(1)
return h
@staticmethod
def backward(ctx, loss_grad):
f_len, g_len, batch_offset = ctx.saved_tensors
f_grad, g_grad = transducer_joint_cuda.backward(loss_grad, f_len, g_len, batch_offset,
ctx.max_f_len, ctx.max_g_len,
ctx.pack_output)
return f_grad, g_grad, None, None, None, None, None, None, None, None, None, None
@@ -453,6 +453,31 @@ if "--fast_multihead_attn" in sys.argv:
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
if "--transducer" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--transducer")
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--transducer was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
ext_modules.append(
CUDAExtension(name='transducer_joint_cuda',
sources=['apex/contrib/csrc/transducer/transducer_joint.cpp',
'apex/contrib/csrc/transducer/transducer_joint_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
ext_modules.append(
CUDAExtension(name='transducer_loss_cuda',
sources=['apex/contrib/csrc/transducer/transducer_loss.cpp',
'apex/contrib/csrc/transducer/transducer_loss_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
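# Note (assumption, not shown in this diff): with this hook in place, the transducer
# extensions are typically built by forwarding the new flag through pip, e.g.
#   pip install -v --no-cache-dir --global-option="--transducer" ./
# mirroring how the other apex contrib extensions above are enabled.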
setup(
name='apex',
version='0.1',