Commit 42b5fc9a authored by dbaranchuk

add memory efficient backward option

parent 843ad063
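The change below threads a new `memory_efficient_backward` flag from `Linear8bitLt` into `MatmulLtState` and, in `MatMul8bitLt.backward`, replaces the old `assert not req_gradB` with a weight-gradient (`grad_B`) path plus an int8 matmul path for `grad_A`. A minimal usage sketch, assuming the `bnb.nn.Linear8bitLt` constructor shown in the diff (sizes, threshold, and dtype choices are illustrative, not prescribed by the commit):

```python
import torch
import bitsandbytes as bnb

# Enable the option added by this commit; the other arguments follow the
# Linear8bitLt signature visible in the diff below.
layer = bnb.nn.Linear8bitLt(
    1024, 4096, bias=True,
    has_fp16_weights=False,
    threshold=6.0,
    memory_efficient_backward=True,
).cuda().half()

x = torch.randn(8, 1024, dtype=torch.float16, device="cuda", requires_grad=True)
layer(x).sum().backward()  # grad for x now flows through the int8 backward path
```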
bitsandbytes/autograd/_functions.py

 import operator
 import torch
+import bitsandbytes as bnb
 import bitsandbytes.functional as F

 from dataclasses import dataclass
@@ -187,6 +188,8 @@ class MatmulLtState:
     use_pool = False
     formatB = F.get_special_format_str()

+    memory_efficient_backward = False
+
     def reset_grads(self):
         self.CB = None
         self.CxB = None
@@ -283,6 +286,12 @@ class MatMul8bitLt(torch.autograd.Function):
                 outlier_idx = torch.unique(coo_tensorA.colidx)
                 state.idx = outlier_idx
+                # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
+                # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
+                #     # do not use pool for 2nd FFN layer
+                #     state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
+                # else:
+                #     state.idx = outlier_idx
                 outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
                 state.subB = (
                     (outliers * state.SCB.view(-1, 1) / 127.0)
@@ -332,13 +341,15 @@ class MatMul8bitLt(torch.autograd.Function):
         clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
         return clone_func(output.view(output_shape))

     @staticmethod
     def backward(ctx, grad_output):
         if ctx.is_empty:
             bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None

         req_gradA, req_gradB, req_gradBias = ctx.req_grads
-        assert not req_gradB, "TODO: support weight updates as well"
+        CAt, subA = ctx.tensors
+        SCAt, idx = ctx.tensor_states
+        formatB = ctx.formatB
         state = ctx.state

         # Cast grad_output to fp16
@@ -352,11 +363,31 @@ class MatMul8bitLt(torch.autograd.Function):
         grad_A = grad_B = grad_bias = None

+        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
+        if req_gradB:
+            CxAt, SAt = F.transform(CAt, formatB, transpose=True)
+            C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
+            gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
+            grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
+            if state.threshold > 0.0 and subA is not None:
+                grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
+
         if req_gradA:
+            if state.CBt is not None:
+                C32grad, Sgrad = F.transform(Cgrad, "col32")
+                if state.CxBt is None:
+                    state.CxBt, state.SBt = F.transform(
+                        state.CBt, to_order=formatB, transpose=True
+                    )
+                gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
+                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
+            elif state.CB is not None:
                 CB = state.CB.half()
                 SCB = (state.SCB.unsqueeze(1) / 127.0).half()
                 CB *= SCB
                 grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape)
+            else:
+                raise Exception('State must contain either CBt or CB matrix')

         if req_gradBias:
             grad_bias = grad_output.sum(0)
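For orientation, the quantized products above mirror a plain fp16 backward: grad_A = grad_output @ W and grad_B = grad_output.t() @ A, with W reconstructed from the int8 weight as CB * SCB / 127. A hedged fp16 reference sketch (function name and shapes are illustrative; only the formulas come from the code above):

```python
import torch

def reference_backward(grad_output, A, CB, SCB):
    # Dequantize the row-major int8 weight: W ≈ CB * SCB / 127, per output row.
    W = CB.half() * (SCB.unsqueeze(1) / 127.0).half()
    grad_A = torch.mm(grad_output, W)      # what mm_dequant reconstructs for grad_A
    grad_B = torch.mm(grad_output.t(), A)  # what the igemmlt call computes for grad_B
    return grad_A, grad_B
```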
@@ -367,6 +398,9 @@ class MatMul8bitLt(torch.autograd.Function):
         return grad_A, grad_B, None, grad_bias, None

+matmul = MatMul8bitLt.apply
+
 def matmul(
     A: tensor,
     B: tensor,
bitsandbytes/nn/modules.py
@@ -223,6 +223,7 @@ class Linear8bitLt(nn.Linear):
         has_fp16_weights=True,
         threshold=0.0,
         index=None,
+        memory_efficient_backward=False
     ):
         super(Linear8bitLt, self).__init__(
             input_features, output_features, bias
@@ -232,6 +233,7 @@ class Linear8bitLt(nn.Linear):
         self.state.threshold = threshold
         self.state.has_fp16_weights = has_fp16_weights
+        self.state.memory_efficient_backward = memory_efficient_backward
         if threshold > 0.0 and not has_fp16_weights:
             self.state.use_pool = True
@@ -255,9 +257,15 @@ class Linear8bitLt(nn.Linear):
         out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

-        if not self.state.has_fp16_weights and self.state.CxB is not None:
-            # In this version, we convert 8-bit row major to turing/ampere format at each inference pass
-            # Thus, we delete CxB from the state. TODO: do not store it in the state in the first place.
-            del self.state.CxB
+        if not self.state.has_fp16_weights:
+            if not self.state.memory_efficient_backward and self.state.CB is not None:
+                # we converted 8-bit row major to turing/ampere format in the first inference pass
+                # we no longer need the row-major weight
+                del self.state.CB
+                self.weight.data = self.state.CxB
+            elif self.state.memory_efficient_backward and self.state.CxB is not None:
+                # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
+                # Thus, we delete CxB from the state.
+                del self.state.CxB

         return out
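The branch above is the memory trade-off the new flag controls; a sketch of what each mode retains after a forward pass (layer sizes are illustrative, and the retained-attribute comments are read off the diff rather than verified at runtime):

```python
import torch
import bitsandbytes as bnb

layer_default = bnb.nn.Linear8bitLt(256, 256, bias=False,
                                    has_fp16_weights=False).cuda().half()
layer_memeff  = bnb.nn.Linear8bitLt(256, 256, bias=False,
                                    has_fp16_weights=False,
                                    memory_efficient_backward=True).cuda().half()

x = torch.randn(4, 256, dtype=torch.float16, device="cuda")
_ = layer_default(x)  # drops row-major CB, keeps tiled CxB as the weight
_ = layer_memeff(x)   # drops CxB each pass, keeps CB/CBt for the backward
```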