import torch

import mhalib  # custom CUDA extension with fused multi-head-attention batched-matmul kernels

###########################################################################################


class Bmm2Function(torch.autograd.Function):
    """Autograd wrapper for mhalib's second batched matmul (Bmm2) kernels,
    operating on variable-length sequences packed without padding."""

    @staticmethod
    def forward(ctx, batch1, batch2, seqlen, batch, maxseqlen, heads, embed, stream, sync):
        ctx.save_for_backward(batch1, batch2, seqlen)
        ctx.batch = batch
        ctx.maxseqlen = maxseqlen
        ctx.heads = heads
        ctx.embed = embed
        ctx.stream = stream
        ctx.sync = sync
        # Total number of tokens across the whole (unpadded) batch.
        ntokens = seqlen.sum().item()
        ctx.ntokens = ntokens
        output = torch.empty([ntokens, heads, embed], device="cuda", dtype=torch.float16)
        mhalib.FastBmm2Fprop(batch2.flatten().contiguous(), batch1.flatten().contiguous(),
                             output, batch, seqlen, heads, embed, False, False, stream, sync)
        return output[:ntokens]

    @staticmethod
    def backward(ctx, grad_output):
        batch1, batch2, seqlen = ctx.saved_tensors
        batch = ctx.batch
        heads = ctx.heads
        embed = ctx.embed
        ntokens = ctx.ntokens
        # Total number of attention-matrix entries: sum over the batch of seqlen_i**2.
        ntokens2 = 0
        for i in range(batch):
            ntokens2 += seqlen[i] * seqlen[i]
        grad_batch1 = torch.empty([ntokens2 * heads], device="cuda", dtype=torch.float16)
        grad_batch2 = torch.empty([ntokens, heads * embed], device="cuda", dtype=torch.float16)
        mhalib.FastBmm2Dgrad1(batch2.flatten().contiguous(), grad_output, grad_batch1,
                              batch, seqlen, heads, embed, False, False, ctx.stream, ctx.sync)
        mhalib.FastBmm2Dgrad2(grad_output, batch1, grad_batch2,
                              batch, seqlen, heads, embed, False, False, ctx.stream, ctx.sync)
        # One gradient per differentiable input; the non-tensor inputs get None.
        return grad_batch1[:ntokens2 * heads], grad_batch2[:ntokens], None, None, None, None, None, None, None


class Bmm2(torch.nn.Module):
    """Module wrapper that binds a fixed heads/embed configuration to Bmm2Function."""

    def __init__(self, batch, seqlen, heads, embed, stream=True, sync=True):
        super(Bmm2, self).__init__()
        self.heads = heads
        self.embed = embed
        self.maxseqlen = seqlen
        self.stream = stream
        self.sync = sync

    def forward(self, batch1, batch2, batch, seqlen):
        return Bmm2Function.apply(batch1, batch2, seqlen, batch, self.maxseqlen,
                                  self.heads, self.embed, self.stream, self.sync)
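# Usage sketch (an illustration, not part of the original module). The input
# layouts are inferred from the allocations above, not from a documented mhalib
# contract: `batch1` is assumed to hold the per-head attention probabilities
# flattened to [sum_i seqlen_i**2 * heads], and `batch2` the packed per-token
# values, [ntokens, heads, embed]:
#
#   seqlen = torch.tensor([128, 64], dtype=torch.int32)   # per-sequence lengths
#   bmm2 = Bmm2(batch=2, seqlen=512, heads=16, embed=64)
#   context = bmm2(probs, values, 2, seqlen)              # -> [ntokens, heads, embed] fp16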
###########################################################################################


class Bmm2StridedFunction(torch.autograd.Function):
    """Bmm2 variant whose second operand lives inside the interleaved QKV `mixed`
    tensor (strided layout), with optional CUDA-event timing of each pass."""

    @staticmethod
    def forward(ctx, batch1, mixed, seqlen, batch, maxseqlen, heads, embed, stream, sync, timers):
        ctx.save_for_backward(batch1, mixed, seqlen)
        ctx.batch = batch
        ctx.maxseqlen = maxseqlen
        ctx.heads = heads
        ctx.embed = embed
        ctx.stream = stream
        ctx.sync = sync
        ctx.timers = timers
        ntokens = seqlen.sum().item()
        ctx.ntokens = ntokens
        output = torch.empty([ntokens, heads, embed], device="cuda", dtype=torch.float16)
        if timers:
            timers['start_fprop'].record()
        # Second boolean flag is True here (False in the non-strided variant),
        # presumably selecting the strided QKV-interleaved kernel path.
        mhalib.FastBmm2Fprop(mixed, batch1, output, batch, seqlen, heads, embed,
                             False, True, stream, sync)
        if timers:
            timers['stop_fprop'].record()
        return output[:ntokens]

    @staticmethod
    def backward(ctx, grad_output):
        batch1, mixed, seqlen = ctx.saved_tensors
        batch = ctx.batch
        heads = ctx.heads
        embed = ctx.embed
        ntokens = ctx.ntokens
        ntokens2 = 0
        for i in range(batch):
            ntokens2 += seqlen[i] * seqlen[i]
        grad_batch1 = torch.empty(ntokens2 * heads, device="cuda", dtype=torch.float16)
        # Gradient w.r.t. the full interleaved QKV tensor, hence heads * 3 * embed.
        grad_mixed = torch.empty([ntokens, heads * 3 * embed], device="cuda", dtype=torch.float16)
        if ctx.timers:
            ctx.timers['start_dgrad'].record()
        mhalib.FastBmm2Dgrad1(mixed, grad_output, grad_batch1,
                              batch, seqlen, heads, embed, False, True, ctx.stream, ctx.sync)
        if ctx.timers:
            ctx.timers['stop_dgrad'].record()
        if ctx.timers:
            ctx.timers['start_wgrad'].record()
        mhalib.FastBmm2Dgrad2(grad_output, batch1, grad_mixed,
                              batch, seqlen, heads, embed, False, True, ctx.stream, ctx.sync)
        if ctx.timers:
            ctx.timers['stop_wgrad'].record()
        return grad_batch1[:ntokens2 * heads], grad_mixed[:ntokens], None, None, None, None, None, None, None, None


class Bmm2Strided(torch.nn.Module):
    """Module wrapper for Bmm2StridedFunction; `timer=True` allocates CUDA events
    for timing the forward, dgrad, and wgrad kernels."""

    def __init__(self, batch, seqlen, heads, embed, stream=True, sync=True, timer=False):
        super(Bmm2Strided, self).__init__()
        self.heads = heads
        self.embed = embed
        self.maxseqlen = seqlen
        self.stream = stream
        self.sync = sync
        if timer:
            self.timers = {'start_fprop': torch.cuda.Event(enable_timing=True),
                           'start_dgrad': torch.cuda.Event(enable_timing=True),
                           'start_wgrad': torch.cuda.Event(enable_timing=True),
                           'stop_fprop': torch.cuda.Event(enable_timing=True),
                           'stop_dgrad': torch.cuda.Event(enable_timing=True),
                           'stop_wgrad': torch.cuda.Event(enable_timing=True)}
        else:
            self.timers = None

    def forward(self, batch1, mixed, batch, seqlen):
        return Bmm2StridedFunction.apply(batch1, mixed, seqlen, batch, self.maxseqlen,
                                         self.heads, self.embed, self.stream, self.sync, self.timers)

###########################################################################################
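
# A hedged smoke-test sketch, not part of the original module. It assumes the
# mhalib extension is built and a CUDA device is available; the input layouts
# ([sum_i seqlen_i**2 * heads] probabilities, [ntokens, heads * 3 * embed]
# interleaved QKV) are inferred from the gradient allocations above rather
# than from a documented kernel contract.
if __name__ == "__main__":
    heads, embed = 16, 64
    seqlen = torch.tensor([128, 64], dtype=torch.int32)
    batch = seqlen.shape[0]
    ntokens = int(seqlen.sum())
    ntokens2 = int((seqlen * seqlen).sum())

    probs = torch.randn(ntokens2 * heads, device="cuda", dtype=torch.float16, requires_grad=True)
    mixed = torch.randn(ntokens, heads * 3 * embed, device="cuda", dtype=torch.float16, requires_grad=True)

    bmm2 = Bmm2Strided(batch, 512, heads, embed, timer=True)
    out = bmm2(probs, mixed, batch, seqlen)    # expected: [ntokens, heads, embed] fp16
    out.sum().backward()
    print(out.shape, probs.grad.shape, mixed.grad.shape)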