"docs/vscode:/vscode.git/clone" did not exist on "7948db81c58cc8ce3c6070088389b28ff487b02a"
Unverified commit dcfb6f81, authored by Benjamin Warner, committed by GitHub

Initial FSDP Support for QLoRA Finetuning (#970)



This PR adds initial FSDP support for training QLoRA models. It enables basic FSDP and CPU offload support; low-memory training via FSDP's sync_module_states option is not yet supported.

This PR builds on #840 (commit 8278fca) and on the BNB FSDP work by @TimDettmers and @Titus-von-Koeller.

An example of using this PR to finetune QLoRA models with FSDP can be found in the demo repo: AnswerDotAi/fsdp_qlora.
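To make the intent concrete, here is a minimal, hypothetical sketch (not the exact recipe from the demo repo) of how the new quant_storage argument is meant to be used. FSDP flattens every parameter of a wrapped module into one buffer of a single dtype, so storing the packed 4-bit weights in the same dtype as the rest of a bfloat16 model lets Params4bit live alongside the unquantized parameters. The helper name replace_linear_with_4bit is illustrative only, not part of the library.

# Sketch under the assumptions above: swap nn.Linear layers for Linear4bit with a
# chosen storage dtype; quantization happens later, on the first move to a CUDA device.
import torch
import torch.nn as nn
import bitsandbytes as bnb

def replace_linear_with_4bit(module: nn.Module, quant_storage: torch.dtype = torch.bfloat16):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            qlinear = bnb.nn.Linear4bit(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                compute_dtype=torch.bfloat16,
                quant_type='nf4',
                quant_storage=quant_storage,  # new in this PR
            )
            # carry over the pretrained weights; they are packed when moved to CUDA
            qlinear.weight = bnb.nn.Params4bit(
                child.weight.data, requires_grad=False,
                quant_type='nf4', quant_storage=quant_storage,
            )
            if child.bias is not None:
                qlinear.bias = nn.Parameter(child.bias.data, requires_grad=False)
            setattr(module, name, qlinear)
        else:
            replace_linear_with_4bit(child, quant_storage)
    return module

Moving the converted model to the GPU (e.g. model.to('cuda')) then triggers the 4-bit quantization, after which the model can be wrapped with FSDP; the full training recipe lives in the AnswerDotAi/fsdp_qlora repo.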

* Minimal changes for fp32 4bit storage from BNB commit 8278fca

* Params4bit with selectable storage dtype

* possible fix for double quantizing linear weight & quant storage dtype

* minor fixes in Params4bit for peft tests

* remove redundant

* add float16

* update test

* Remove float16 quant cast as there are fp32, bf16, & fp16 quant kernels

---------
Co-authored-by: Kerem Turgutlu <keremturgutlu@gmail.com>
parent 64a28d02
@@ -884,13 +884,13 @@ def get_4bit_type(typename, device=None, blocksize=64):
     return data.to(device)
 
-def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
-    return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4')
+def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8):
+    return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4', quant_storage)
 
-def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
-    return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4')
+def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8):
+    return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4', quant_storage)
 
-def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor:
+def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4', quant_storage=torch.uint8) -> Tensor:
     """
     Quantize tensor A in blocks of 4-bit values.
@@ -903,7 +903,7 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
     absmax : torch.Tensor
         The absmax values.
     out : torch.Tensor
-        The output tensor (8-bit).
+        The output tensor.
     blocksize : int
         The blocksize used in quantization.
     quant_type : str
@@ -912,7 +912,7 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
     Returns
     -------
     torch.Tensor:
-        The 8-bit tensor with packed 4-bit values.
+        Tensor with packed 4-bit values.
     tuple(torch.Tensor, torch.Size, torch.dtype, int):
         The quantization state to undo the quantization.
     """
@@ -931,7 +931,8 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
     if out is None:
-        out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device)
+        mod = dtype2bytes[quant_storage] * 2
+        out = torch.zeros(((n+1)//mod, 1), dtype=quant_storage, device=A.device)
 
     assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
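For illustration, a small sketch (assuming a CUDA device and this branch) of what the new allocation means in practice: the packed tensor keeps the same total byte size regardless of quant_storage, only the element count and dtype change.

# Quantize the same weight with two storage dtypes and compare layouts.
import torch
import bitsandbytes.functional as F

W = torch.randn(4096, 4096, dtype=torch.bfloat16, device='cuda')

q_u8, s_u8 = F.quantize_4bit(W, quant_type='nf4', quant_storage=torch.uint8)
q_bf16, s_bf16 = F.quantize_4bit(W, quant_type='nf4', quant_storage=torch.bfloat16)

assert q_u8.dtype == torch.uint8 and q_bf16.dtype == torch.bfloat16
assert q_u8.numel() == W.numel() // 2      # 2 packed 4-bit values per uint8 element
assert q_bf16.numel() == W.numel() // 4    # 4 packed 4-bit values per bfloat16 element
assert q_u8.numel() * q_u8.element_size() == q_bf16.numel() * q_bf16.element_size()

# Either packing dequantizes back to the same values.
assert torch.allclose(F.dequantize_4bit(q_u8, s_u8), F.dequantize_4bit(q_bf16, s_bf16))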
@@ -985,7 +986,7 @@ def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor =
     Parameters
     ----------
     A : torch.Tensor
-        The input 8-bit tensor (packed 4-bit values).
+        The input tensor (packed 4-bit values).
     quant_state : QuantState
         object with quantisation stats, incl. absmax values, original tensor shape and original dtype.
     absmax : torch.Tensor
@@ -1626,7 +1627,7 @@ def gemv_4bit(
     ldb = ct.c_int32(ldb)
     ldc = ct.c_int32(ldc)
 
-    if B.dtype == torch.uint8:
+    if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]:
         if A.dtype == torch.float16:
             lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
         elif A.dtype == torch.bfloat16:
...
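A short, hedged sketch of what the widened dtype check enables: 4-bit inference through gemv_4bit now also works when the packed weight is stored in a non-uint8 dtype (assumes a CUDA device).

import math
import torch
import bitsandbytes.functional as F

dim = 1024
A = torch.randn(1, dim, dtype=torch.bfloat16, device='cuda')
B = torch.randn(dim, dim, dtype=torch.bfloat16, device='cuda') / math.sqrt(dim)

# Pack B with bfloat16 storage; previously gemv_4bit required qB.dtype == torch.uint8.
qB, state = F.quantize_4bit(B, quant_type='nf4', quant_storage=torch.bfloat16)
out = F.gemv_4bit(A, qB.t(), state=state)
ref = torch.matmul(A, F.dequantize_4bit(qB, state).t())
print((out - ref).abs().max())   # small quantization error, not exact equality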
@@ -141,8 +141,18 @@ class Embedding(torch.nn.Embedding):
 
 class Params4bit(torch.nn.Parameter):
-    def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad=True, quant_state: QuantState = None, blocksize: int = 64, compress_statistics: bool = True, quant_type: str = 'fp4') -> "Params4bit":
+    def __new__(
+            cls,
+            data: Optional[torch.Tensor] = None,
+            requires_grad=True,
+            quant_state: QuantState = None,
+            blocksize: int = 64,
+            compress_statistics: bool = True,
+            quant_type: str = 'fp4',
+            quant_storage: torch.dtype = torch.uint8,
+            module: Optional["Linear4bit"] = None,
+            bnb_quantized: bool = False
+    ) -> "Params4bit":
         if data is None:
             data = torch.empty(0)
@@ -151,7 +161,10 @@ class Params4bit(torch.nn.Parameter):
         self.compress_statistics = compress_statistics
         self.quant_type = quant_type
         self.quant_state = quant_state
+        self.quant_storage = quant_storage
+        self.bnb_quantized = bnb_quantized
         self.data = data
+        self.module = module
         return self
 
     @classmethod
@@ -162,16 +175,23 @@ class Params4bit(torch.nn.Parameter):
         self.blocksize = self.quant_state.blocksize
         self.compress_statistics = self.quant_state.nested
         self.quant_type = self.quant_state.quant_type
+        self.bnb_quantized = True
         return self
 
-    def cuda(self, device):
-        w = self.data.contiguous().half().cuda(device)
-        w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type)
+    def _quantize(self, device):
+        w = self.data.contiguous().cuda(device)
+        w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics,
+                                                           quant_type=self.quant_type, quant_storage=self.quant_storage)
         self.data = w_4bit
         self.quant_state = quant_state
+        if self.module is not None:
+            self.module.quant_state = quant_state
+        self.bnb_quantized = True
         return self
 
+    def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
+        return self.to(device='cuda' if device is None else device, non_blocking=non_blocking)
+
     @overload
     def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T:
         ...
@@ -187,8 +207,8 @@ class Params4bit(torch.nn.Parameter):
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
 
-        if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"):
-            return self.cuda(device)
+        if (device is not None and device.type == "cuda" and not self.bnb_quantized):
+            return self._quantize(device)
         else:
            if self.quant_state is not None:
                self.quant_state.to(device)
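As a quick sketch of the new control flow (assuming a CUDA device): quantization now happens on the first move to a CUDA device and is guarded by the bnb_quantized flag, rather than by checking whether the data currently lives on the CPU.

import torch
import bitsandbytes as bnb

p = bnb.nn.Params4bit(torch.randn(512, 512), requires_grad=False,
                      quant_type='nf4', quant_storage=torch.bfloat16)
assert not p.bnb_quantized

p = p.to('cuda')   # first CUDA move: _quantize() packs the weight using quant_storage
assert p.bnb_quantized
assert p.data.dtype == torch.bfloat16 and p.data.shape[1] == 1
assert p.quant_state is not None   # holds absmax, original shape and dtype, etc.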
@@ -203,12 +223,14 @@ class Params4bit(torch.nn.Parameter):
 
 class Linear4bit(nn.Linear):
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None):
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', quant_storage=torch.uint8, device=None):
         super().__init__(input_features, output_features, bias, device)
-        self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
+        self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self)
         # self.persistent_buffers = [] # TODO consider as way to save quant state
         self.compute_dtype = compute_dtype
         self.compute_type_is_set = False
+        self.quant_state = None
+        self.quant_storage = quant_storage
 
     def set_compute_type(self, x):
         if x.dtype in [torch.float32, torch.bfloat16]:
@@ -243,6 +265,14 @@ class Linear4bit(nn.Linear):
             self.bias.data = self.bias.data.to(x.dtype)
 
         if getattr(self.weight, 'quant_state', None) is None:
-            print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
+            if getattr(self, 'quant_state', None) is not None:
+                # the quant state got lost when the parameter got converted. This happens for example for fsdp
+                # since we registered the module, we can recover the state here
+                assert self.weight.shape[1] == 1
+                if not isinstance(self.weight, Params4bit):
+                    self.weight = Params4bit(self.weight, quant_storage=self.quant_storage)
+                self.weight.quant_state = self.quant_state
+            else:
+                print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
         if not self.compute_type_is_set:
             self.set_compute_type(x)
@@ -261,8 +291,8 @@ class Linear4bit(nn.Linear):
 
 class LinearFP4(Linear4bit):
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
-        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device)
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
+        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', quant_storage, device)
 
 class LinearNF4(Linear4bit):
@@ -276,8 +306,8 @@ class LinearNF4(Linear4bit):
     Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
     the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
     '''
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
-        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
+        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', quant_storage, device)
 
 class Int8Params(torch.nn.Parameter):
...
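The forward-pass change above is what makes FSDP workable: when parameters are gathered and re-scattered, the weight can come back as a plain tensor without its Params4bit wrapper and quant_state. Because the layer now stores quant_state and quant_storage on itself, it can re-wrap the weight on the next forward. A hedged illustration of that recovery path follows; the scenario is simulated by hand and is not how FSDP literally behaves.

import torch
import bitsandbytes as bnb

layer = bnb.nn.LinearNF4(256, 512, compute_dtype=torch.bfloat16,
                         quant_storage=torch.bfloat16)
layer = layer.to('cuda')          # quantizes the weight and records layer.quant_state

# Simulate losing the Params4bit wrapper (roughly what flat-parameter sharding does):
packed = layer.weight.data        # bfloat16 tensor of packed 4-bit values, shape (n, 1)
layer.weight = torch.nn.Parameter(packed, requires_grad=False)

x = torch.randn(4, 256, dtype=torch.bfloat16, device='cuda')
y = layer(x)                      # forward re-wraps the weight and reattaches quant_state
assert y.shape == (4, 512)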
@@ -2370,7 +2370,8 @@ def test_normal_map_tree():
 @pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
 @pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'], ids=['fc1', 'fc2', 'attn', 'attn_packed'])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
-def test_gemv_4bit(dtype, storage_type, double_quant, kind):
+@pytest.mark.parametrize("quant_storage", [torch.uint8, torch.float16, torch.bfloat16, torch.float32], ids=['uint8', 'fp16', 'bf16', 'fp32'])
+def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
     for dim in [128, 256, 512, 1024]:
     #for dim in [4*1024]:
     #for dim in [1*16]:
@@ -2399,7 +2400,7 @@ def test_gemv_4bit(dtype, storage_type, double_quant, kind):
             A = torch.randn(1, dim, dtype=dtype, device='cuda')
             B = torch.randn(dim*3, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
 
-            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
+            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant, quant_storage=quant_storage)
             C3 = torch.matmul(A, B.t())
             C2 = F.gemv_4bit(A, qB.t(), state=state)
             A.requires_grad = True
...
@@ -8,13 +8,19 @@ import torch
 
 import bitsandbytes as bnb
 
+storage = {
+    'uint8': torch.uint8,
+    'float16': torch.float16,
+    'bfloat16': torch.bfloat16,
+    'float32': torch.float32
+}
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
 @pytest.mark.parametrize(
-    "quant_type, compress_statistics, bias",
-    list(product(["nf4", "fp4"], [False, True], [False, True])),
+    "quant_type, compress_statistics, bias, quant_storage",
+    list(product(["nf4", "fp4"], [False, True], [False, True], ['uint8', 'float16', 'bfloat16', 'float32'])),
 )
-def test_linear_serialization(quant_type, compress_statistics, bias):
+def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage):
     original_dtype = torch.float16
     compute_dtype = None
     device = "cuda"
@@ -32,7 +38,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
         quant_type=quant_type,
         device="meta",
     )
-    new_weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False)
+    new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False)
     linear_q.weight = new_weight
     if bias:
         linear_q.bias = torch.nn.Parameter(linear.bias)
@@ -65,6 +71,22 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
     # MATCHING
     a, b = linear_q.weight, linear_q2.weight
 
+    # Quantizing original layer with specified quant_storage type
+    linear_qs = bnb.nn.Linear4bit(
+        linear.in_features,
+        linear.out_features,
+        bias=bias,
+        compute_dtype=compute_dtype,
+        compress_statistics=compress_statistics,
+        quant_type=quant_type,
+        quant_storage=storage[quant_storage],
+        device="meta",
+    )
+    linear_qs.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False, quant_type=quant_type, quant_storage=storage[quant_storage])
+    if bias:
+        linear_qs.bias = torch.nn.Parameter(linear.bias)
+    linear_qs = linear_qs.to(device)
+
     assert a.device == b.device
     assert a.dtype == b.dtype
     assert torch.equal(a, b)
@@ -96,9 +118,21 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
     x = torch.rand(42, layer_shape[0], device=device)
     a = linear_q(x)
     b = linear_q2(x)
+    c = linear_qs(x)
     assert a.device == b.device
     assert a.dtype == b.dtype
+    assert a.device == c.device
+    assert a.dtype == c.dtype
     assert torch.equal(a, b)
+    assert torch.equal(a, c)
+
+    # Test moving to CPU and back to GPU
+    linear_q2.to('cpu')
+    linear_q2.to(device)
+    d = linear_qs(x)
+    assert c.dtype == d.dtype
+    assert c.device == d.device
+    assert torch.equal(c, d)
 
     # Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias
     with TemporaryDirectory() as tmpdir:
...