Commit 1753aa04 authored by dbaranchuk

refactoring

parent 8ae9bb23
@@ -245,10 +245,11 @@ class MatMul8bitLt(torch.autograd.Function):
                 subA = A[:, idx]
                 state.subB = B[:, idx].t().contiguous()
                 state.idx = idx
-            elif state.CxB is None:
-                # B is in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
-                # we also need to convert it to the turing/ampere format
-                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+            else:
+                if state.CxB is None:
+                    # B is in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
+                    # we also need to convert it to the turing/ampere format
+                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
         else:
             if not state.has_fp16_weights and state.CxB is None:
                 state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
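
A toy, plain-PyTorch sketch of the idea the comments in this hunk describe: dequantize the 8-bit row-major weight back to 16-bit so the outlier columns can be sliced out, as the state.subB = B[:, idx].t().contiguous() context line does for fp16 weights. Shapes, scale values, and the row-wise B ~= CB * SCB convention are assumptions for illustration only; the actual layout conversion in this branch is performed by F.transform.

import torch

# Toy stand-ins for state.CB, state.SCB and the outlier index list idx.
# Assumed shapes: CB is an int8 row-major weight [out, in], SCB holds one
# fp16 scale per output row.
CB = torch.randint(-127, 128, (64, 128), dtype=torch.int8)
SCB = torch.rand(64).half() * 0.02
idx = torch.tensor([3, 17, 42])

# Dequantize the 8-bit weight back to 16-bit (B ~= CB * SCB, the row-wise
# convention used in the backward pass below) ...
B_fp16 = CB.half() * SCB.unsqueeze(1)
# ... and slice out the outlier input columns, mirroring
# state.subB = B[:, idx].t().contiguous() in the hunk above.
subB = B_fp16[:, idx].t().contiguous()
print(subB.shape)  # torch.Size([3, 64])
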
@@ -355,19 +356,24 @@ class MatMul8bitLt(torch.autograd.Function):
         if req_gradA:
             C32grad, Sgrad = F.transform(Cgrad, "col32")
-            if state.CxBt is None and state.has_fp16_weights:
-                CBt = state.CBt
-            elif state.CxBt is None:
-                assert state.CBt is None
-                CB = state.CB.half()
-                SCB = state.SCB.unsqueeze(1).half()
-                SCBt = state.SCBt.unsqueeze(1).half()
-                Bt = (CB * SCB).t().contiguous()
-                CBt = (Bt / SCBt).t().to(torch.int8)
-
-            CxBt, SBt = F.transform(
-                CBt, to_order=formatB, transpose=True
-            )
+            if state.CxBt is None:
+                if state.has_fp16_weights:
+                    CBt = state.CBt
+                else:
+                    # Restore CBt from CB
+                    assert state.CBt is None, "CBt should not be stored in state"
+                    CB = state.CB.half()
+                    SCB = state.SCB.unsqueeze(1).half()
+                    SCBt = state.SCBt.unsqueeze(1).half()
+                    Bt = (CB * SCB).t().contiguous()
+                    CBt = (Bt / SCBt).t().to(torch.int8)
+
+                # intentionally, do not store CxBt into state
+                CxBt, SBt = F.transform(
+                    CBt, to_order=formatB, transpose=True
+                )
+            else:
+                CxBt = state.CxBt
             gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt)
             grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
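
The "Restore CBt from CB" branch added above re-derives the column-wise quantized weight from the row-wise one. Below is a self-contained sketch of that arithmetic in plain PyTorch; the toy shapes, the single fp16 weight the scales are derived from, and the B ~= CB * SCB (row-wise) / B ~= CBt * SCBt (column-wise) conventions are assumptions taken from the code in this hunk, not the bitsandbytes implementation itself.

import torch

# Toy stand-ins for state.CB / state.SCB / state.SCBt, derived from a single
# weight so that the row-wise and column-wise scale sets are consistent.
B = torch.randn(64, 128)                              # [out, in]
SCB = B.abs().amax(dim=1) / 127                       # one scale per output row
SCBt = B.abs().amax(dim=0) / 127                      # one scale per input column
CB = (B / SCB.unsqueeze(1)).round().to(torch.int8)    # row-wise quantized weight

# "Restore CBt from CB": dequantize with the row scales, transpose, then
# requantize with the column scales -- the same steps as the branch above.
CB_h = CB.half()
SCB_h = SCB.unsqueeze(1).half()
SCBt_h = SCBt.unsqueeze(1).half()
Bt = (CB_h * SCB_h).t().contiguous()                  # ~= B.t() in fp16
CBt = (Bt / SCBt_h).t().to(torch.int8)                # column-wise quantized, [out, in]

As the new comments note, CxBt is deliberately not cached in state, so the transposed copy is rebuilt on each backward pass instead of occupying persistent memory.
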