import torch
import mhalib

###########################################################################################

class FastSoftmaxFunction(torch.autograd.Function):
    @staticmethod
    def forward(cxt, input, dim, batch, seqlen, heads, stream, sync, timers):
        if timers: timers['start_fprop'].record()
        # The softmax is computed in place: the kernel overwrites `input`, which is returned below.
        mhalib.FastSoftmaxFprop(input, batch, seqlen, heads, stream, sync)
        if timers: timers['stop_fprop'].record()

        cxt.save_for_backward(input, seqlen)
        cxt.dim = dim
        cxt.batch = batch
        cxt.heads = heads
        cxt.stream = stream
        cxt.sync = sync
        cxt.timers = timers

        return input

    @staticmethod
    def backward(cxt, grad_output):
        output, seqlen = cxt.saved_tensors
        batch = cxt.batch
        heads = cxt.heads

        if cxt.timers: cxt.timers['start_dgrad'].record()
        # The gradient is computed in place on `grad_output`, which is returned below.
        mhalib.FastSoftmaxBprop(output, grad_output, batch, seqlen, heads, cxt.stream, cxt.sync)
        if cxt.timers: cxt.timers['stop_dgrad'].record()

        # One gradient per forward input: input, dim, batch, seqlen, heads, stream, sync, timers.
        return grad_output, None, None, None, None, None, None, None

class FastSoftmax(torch.nn.Module):
    def __init__(self, dim=None, stream=True, sync=True, timer=False):
        super(FastSoftmax, self).__init__()
        self.dim = dim
        self.stream = stream
        self.sync = sync
        if timer:
            self.timers = {'start_fprop': torch.cuda.Event(enable_timing=True),
                           'start_dgrad': torch.cuda.Event(enable_timing=True),
                           'stop_fprop': torch.cuda.Event(enable_timing=True),
                           'stop_dgrad': torch.cuda.Event(enable_timing=True)}
        else:
            self.timers = None

    def forward(self, input, batch, seqlen, heads):
        return FastSoftmaxFunction.apply(input, self.dim, batch, seqlen, heads,
                                         self.stream, self.sync, self.timers)

###########################################################################################

class FastMaskSoftmaxFunction(torch.autograd.Function):
    @staticmethod
    def forward(cxt, input, mask, dim, batch, seqlen, heads, stream, sync, timers):
        if timers: timers['start_fprop'].record()
        mhalib.FastMaskSoftmaxFprop(input, mask, batch, seqlen, heads, stream, sync)
        if timers: timers['stop_fprop'].record()

        cxt.save_for_backward(input, seqlen)
        cxt.dim = dim
        cxt.batch = batch
        cxt.heads = heads
        cxt.stream = stream
        cxt.sync = sync
        cxt.timers = timers

        return input

    @staticmethod
    def backward(cxt, grad_output):
        output, seqlen = cxt.saved_tensors
        batch = cxt.batch
        heads = cxt.heads

        if cxt.timers: cxt.timers['start_dgrad'].record()
        # The mask was already applied in the forward pass, so the plain softmax
        # backward kernel is reused here.
        mhalib.FastSoftmaxBprop(output, grad_output, batch, seqlen, heads, cxt.stream, cxt.sync)
        if cxt.timers: cxt.timers['stop_dgrad'].record()

        # One gradient per forward input: input, mask, dim, batch, seqlen, heads, stream, sync, timers.
        return grad_output, None, None, None, None, None, None, None, None

class FastMaskSoftmax(torch.nn.Module):
    def __init__(self, dim=None, stream=True, sync=True, timer=False):
        super(FastMaskSoftmax, self).__init__()
        self.dim = dim
        self.stream = stream
        self.sync = sync
        if timer:
            self.timers = {'start_fprop': torch.cuda.Event(enable_timing=True),
                           'start_dgrad': torch.cuda.Event(enable_timing=True),
                           'stop_fprop': torch.cuda.Event(enable_timing=True),
                           'stop_dgrad': torch.cuda.Event(enable_timing=True)}
        else:
            self.timers = None

    def forward(self, input, mask, batch, seqlen, heads):
        return FastMaskSoftmaxFunction.apply(input, mask, self.dim, batch, seqlen, heads,
                                             self.stream, self.sync, self.timers)

###########################################################################################

class FastMaskSoftmaxDropoutFunction(torch.autograd.Function):
    @staticmethod
    def forward(cxt, input, mask, dim, batch, seqlen, heads, dropout_prob, stream, sync, timers, is_training):
        if timers: timers['start_fprop'].record()
        output, dropout_mask = mhalib.FastMaskSoftmaxDropoutFprop(input, mask, batch, seqlen, heads,
                                                                  dropout_prob, stream, sync, is_training)
        if timers: timers['stop_fprop'].record()

        cxt.save_for_backward(input, dropout_mask, seqlen)
        cxt.dim = dim
        cxt.batch = batch
        cxt.heads = heads
        cxt.dropout_prob = dropout_prob
        cxt.stream = stream
        cxt.sync = sync
        cxt.timers = timers

        return output

    @staticmethod
    def backward(cxt, grad_output):
        output, dropout_mask, seqlen = cxt.saved_tensors
        batch = cxt.batch
        heads = cxt.heads
        dropout_prob = cxt.dropout_prob

        if cxt.timers: cxt.timers['start_dgrad'].record()
        mhalib.FastMaskSoftmaxDropoutBprop(output, grad_output, dropout_mask, batch, seqlen, heads,
                                           dropout_prob, cxt.stream, cxt.sync)
        if cxt.timers: cxt.timers['stop_dgrad'].record()

        # One gradient per forward input: input, mask, dim, batch, seqlen, heads,
        # dropout_prob, stream, sync, timers, is_training.
        return grad_output, None, None, None, None, None, None, None, None, None, None

class FastMaskSoftmaxDropout(torch.nn.Module):
    def __init__(self, dim=None, dropout_prob=None, stream=True, sync=True, timer=False):
        super(FastMaskSoftmaxDropout, self).__init__()
        self.dim = dim
        self.dropout_prob = dropout_prob
        self.stream = stream
        self.sync = sync
        if timer:
            self.timers = {'start_fprop': torch.cuda.Event(enable_timing=True),
                           'start_dgrad': torch.cuda.Event(enable_timing=True),
                           'stop_fprop': torch.cuda.Event(enable_timing=True),
                           'stop_dgrad': torch.cuda.Event(enable_timing=True)}
        else:
            self.timers = None

    def forward(self, input, mask, batch, seqlen, heads, is_training):
        return FastMaskSoftmaxDropoutFunction.apply(input, mask, self.dim, batch, seqlen, heads,
                                                    self.dropout_prob, self.stream, self.sync,
                                                    self.timers, is_training)

###########################################################################################
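
# ------------------------------------------------------------------------------------------
# Minimal usage sketch (not part of the library API). It only exercises the wrappers defined
# in this file; the tensor shapes, dtypes, and the packed score layout below are assumptions
# made for illustration and must be checked against the actual mhalib kernels before use.
if __name__ == '__main__':
    if torch.cuda.is_available():
        batch, heads, max_seqlen = 2, 16, 128

        softmax = FastMaskSoftmaxDropout(dim=-1, dropout_prob=0.1, timer=True)

        # Hypothetical layout: one row of half-precision attention scores per
        # (batch * heads * query position), padding mask per sample, and a per-sample
        # sequence-length tensor (a tensor so it can go through save_for_backward).
        scores = torch.randn(batch * heads * max_seqlen, max_seqlen,
                             device='cuda', dtype=torch.float16, requires_grad=True)
        mask = torch.zeros(batch, max_seqlen, device='cuda', dtype=torch.float16)
        seqlen = torch.full((batch,), max_seqlen, device='cuda', dtype=torch.int32)

        probs = softmax(scores, mask, batch, seqlen, heads, is_training=True)
        probs.sum().backward()

        # With timer=True, elapsed kernel time can be read from the recorded CUDA events.
        torch.cuda.synchronize()
        print('fprop ms:', softmax.timers['start_fprop'].elapsed_time(softmax.timers['stop_fprop']))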