Unverified commit dcfb6f81 authored by Benjamin Warner, committed by GitHub

Initial FSDP Support for QLoRA Finetuning (#970)



This PR adds initial FSDP support for training QLoRA models. It enables basic FSDP and CPU Offload support; low-memory training via FSDP's `sync_module_states` option is not yet supported.

This PR builds on #840 (commit 8278fca) and the BNB FSDP work by @TimDettmers and @Titus-von-Koeller.

An example of using this PR to finetune QLoRA models with FSDP can be found in the demo repo: AnswerDotAi/fsdp_qlora.
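
As a rough sketch of how the new `quant_storage` option is intended to combine with FSDP (the layer sizes, the bf16 choice, and the commented-out wrapping call are illustrative assumptions, not code from this PR or the demo repo): storing the packed 4-bit weights in the same dtype as the unquantized parameters lets FSDP flatten and shard them together.

```python
import torch
import torch.nn as nn
import bitsandbytes as bnb

# Toy model; in practice the Linear layers of an LLM would be replaced.
model = nn.Sequential(
    bnb.nn.Linear4bit(
        1024, 4096,
        bias=False,
        compute_dtype=torch.bfloat16,
        quant_type='nf4',
        quant_storage=torch.bfloat16,  # pack the 4-bit weights into a bf16 tensor
    ),
    nn.Linear(4096, 1024, dtype=torch.bfloat16),
)
model = model.cuda()  # moving to GPU triggers the 4-bit quantization

# With every parameter stored as bf16, the model can be wrapped with FSDP,
# e.g. (sharding and wrapping policy omitted):
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#   model = FSDP(model)
```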

* Minimal changes for fp32 4bit storage from BNB commit 8278fca

* Params4bit with selectable storage dtype (see the sketch after this list)

* possible fix for double quantizing linear weight & quant storage dtype

* minor fixes in Params4bit for peft tests

* remove redundant

* add float16

* update test

* Remove float16 quant cast as there are fp32, bf16, & fp16 quant kernels
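
The sketch below (tensor sizes are arbitrary; it assumes a CUDA device is available) illustrates the functional-level API these commits add: `quantize_4bit` accepts a `quant_storage` dtype for the packed output, and dequantization is driven by the returned `QuantState`, so the recovered values do not depend on the storage dtype.

```python
import torch
import bitsandbytes.functional as F

W = torch.randn(4096, 4096, dtype=torch.bfloat16, device='cuda')

# Pack the 4-bit NF4 values into a bf16-typed tensor instead of uint8.
qW, state = F.quantize_4bit(W, quant_type='nf4', quant_storage=torch.bfloat16)
assert qW.dtype == torch.bfloat16

# The QuantState carries shape, dtype, absmax, etc., so dequantization
# reconstructs the original-shaped bf16 tensor (up to quantization error).
W_dq = F.dequantize_4bit(qW, state)
assert W_dq.shape == W.shape and W_dq.dtype == W.dtype
```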

---------
Co-authored-by: Kerem Turgutlu <keremturgutlu@gmail.com>
parent 64a28d02
@@ -607,7 +607,7 @@ class QuantState:
         qs_dict: based on state_dict, with only relevant keys, striped of prefixes.
         item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items.
         """
         # unpacking tensor with non-tensor components
@@ -802,7 +802,7 @@ def dequantize_blockwise(
     if quant_state is None:
         quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32)

     absmax = quant_state.absmax
     if quant_state.nested:
         absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
@@ -884,13 +884,13 @@ def get_4bit_type(typename, device=None, blocksize=64):
     return data.to(device)

-def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
-    return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4')
+def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8):
+    return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4', quant_storage)

-def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
-    return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4')
+def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8):
+    return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4', quant_storage)

-def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor:
+def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4', quant_storage=torch.uint8) -> Tensor:
     """
     Quantize tensor A in blocks of 4-bit values.
@@ -903,7 +903,7 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
     absmax : torch.Tensor
         The absmax values.
     out : torch.Tensor
-        The output tensor (8-bit).
+        The output tensor.
     blocksize : int
         The blocksize used in quantization.
     quant_type : str
@@ -912,7 +912,7 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
     Returns
     -------
     torch.Tensor:
-        The 8-bit tensor with packed 4-bit values.
+        Tensor with packed 4-bit values.
     tuple(torch.Tensor, torch.Size, torch.dtype, int):
         The quantization state to undo the quantization.
     """
@@ -931,7 +931,8 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
     if out is None:
-        out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device)
+        mod = dtype2bytes[quant_storage] * 2
+        out = torch.zeros(((n+1)//mod, 1), dtype=quant_storage, device=A.device)

     assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
@@ -985,7 +986,7 @@ def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor =
     Parameters
     ----------
     A : torch.Tensor
-        The input 8-bit tensor (packed 4-bit values).
+        The input tensor (packed 4-bit values).
     quant_state : QuantState
         object with quantisation stats, incl. absmax values, original tensor shape and original dtype.
     absmax : torch.Tensor
@@ -1626,7 +1627,7 @@ def gemv_4bit(
     ldb = ct.c_int32(ldb)
     ldc = ct.c_int32(ldc)

-    if B.dtype == torch.uint8:
+    if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]:
         if A.dtype == torch.float16:
             lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
         elif A.dtype == torch.bfloat16:
...
@@ -141,8 +141,18 @@ class Embedding(torch.nn.Embedding):

class Params4bit(torch.nn.Parameter):
-    def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad=True, quant_state: QuantState = None, blocksize: int = 64, compress_statistics: bool = True, quant_type: str = 'fp4') -> "Params4bit":
+    def __new__(
+            cls,
+            data: Optional[torch.Tensor] = None,
+            requires_grad=True,
+            quant_state: QuantState = None,
+            blocksize: int = 64,
+            compress_statistics: bool = True,
+            quant_type: str = 'fp4',
+            quant_storage: torch.dtype = torch.uint8,
+            module: Optional["Linear4bit"] = None,
+            bnb_quantized: bool = False
+    ) -> "Params4bit":
        if data is None:
            data = torch.empty(0)
@@ -151,7 +161,10 @@ class Params4bit(torch.nn.Parameter):
        self.compress_statistics = compress_statistics
        self.quant_type = quant_type
        self.quant_state = quant_state
+        self.quant_storage = quant_storage
+        self.bnb_quantized = bnb_quantized
        self.data = data
+        self.module = module
        return self

    @classmethod
@@ -162,16 +175,23 @@ class Params4bit(torch.nn.Parameter):
        self.blocksize = self.quant_state.blocksize
        self.compress_statistics = self.quant_state.nested
        self.quant_type = self.quant_state.quant_type
+        self.bnb_quantized = True
        return self

-    def cuda(self, device):
-        w = self.data.contiguous().half().cuda(device)
-        w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type)
+    def _quantize(self, device):
+        w = self.data.contiguous().cuda(device)
+        w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics,
+                                                           quant_type=self.quant_type, quant_storage=self.quant_storage)
        self.data = w_4bit
        self.quant_state = quant_state
+        if self.module is not None:
+            self.module.quant_state = quant_state
+        self.bnb_quantized = True
        return self

+    def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
+        return self.to(device='cuda' if device is None else device, non_blocking=non_blocking)

    @overload
    def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T:
        ...
@@ -187,8 +207,8 @@ class Params4bit(torch.nn.Parameter):
    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

-        if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"):
-            return self.cuda(device)
+        if (device is not None and device.type == "cuda" and not self.bnb_quantized):
+            return self._quantize(device)
        else:
            if self.quant_state is not None:
                self.quant_state.to(device)
@@ -203,12 +223,14 @@ class Params4bit(torch.nn.Parameter):

class Linear4bit(nn.Linear):
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None):
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', quant_storage=torch.uint8, device=None):
        super().__init__(input_features, output_features, bias, device)
-        self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
+        self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self)
        # self.persistent_buffers = []  # TODO consider as way to save quant state
        self.compute_dtype = compute_dtype
        self.compute_type_is_set = False
+        self.quant_state = None
+        self.quant_storage = quant_storage

    def set_compute_type(self, x):
        if x.dtype in [torch.float32, torch.bfloat16]:
@@ -243,7 +265,15 @@ class Linear4bit(nn.Linear):
            self.bias.data = self.bias.data.to(x.dtype)

        if getattr(self.weight, 'quant_state', None) is None:
-            print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
+            if getattr(self, 'quant_state', None) is not None:
+                # the quant state got lost when the parameter got converted. This happens for example for fsdp
+                # since we registered the module, we can recover the state here
+                assert self.weight.shape[1] == 1
+                if not isinstance(self.weight, Params4bit):
+                    self.weight = Params4bit(self.weight, quant_storage=self.quant_storage)
+                self.weight.quant_state = self.quant_state
+            else:
+                print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
        if not self.compute_type_is_set:
            self.set_compute_type(x)
            self.compute_type_is_set = True
@@ -261,8 +291,8 @@ class Linear4bit(nn.Linear):

class LinearFP4(Linear4bit):
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
-        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device)
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
+        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', quant_storage, device)


class LinearNF4(Linear4bit):
@@ -276,8 +306,8 @@ class LinearNF4(Linear4bit):
    Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
    the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
    '''
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
-        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
+        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', quant_storage, device)


class Int8Params(torch.nn.Parameter):
...
@@ -2370,7 +2370,8 @@ def test_normal_map_tree():
@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
@pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'], ids=['fc1', 'fc2', 'attn', 'attn_packed'])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
-def test_gemv_4bit(dtype, storage_type, double_quant, kind):
+@pytest.mark.parametrize("quant_storage", [torch.uint8, torch.float16, torch.bfloat16, torch.float32], ids=['uint8', 'fp16', 'bf16', 'fp32'])
+def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
    for dim in [128, 256, 512, 1024]:
    #for dim in [4*1024]:
    #for dim in [1*16]:
@@ -2399,7 +2400,7 @@ def test_gemv_4bit(dtype, storage_type, double_quant, kind):
            A = torch.randn(1, dim, dtype=dtype, device='cuda')
            B = torch.randn(dim*3, dim, dtype=dtype, device='cuda')/math.sqrt(dim)

-            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
+            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant, quant_storage=quant_storage)
            C3 = torch.matmul(A, B.t())
            C2 = F.gemv_4bit(A, qB.t(), state=state)
            A.requires_grad = True
...
@@ -8,13 +8,19 @@ import torch

import bitsandbytes as bnb

+storage = {
+    'uint8': torch.uint8,
+    'float16': torch.float16,
+    'bfloat16': torch.bfloat16,
+    'float32': torch.float32
+}

@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize(
-    "quant_type, compress_statistics, bias",
-    list(product(["nf4", "fp4"], [False, True], [False, True])),
+    "quant_type, compress_statistics, bias, quant_storage",
+    list(product(["nf4", "fp4"], [False, True], [False, True], ['uint8', 'float16', 'bfloat16', 'float32'])),
)
-def test_linear_serialization(quant_type, compress_statistics, bias):
+def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage):
    original_dtype = torch.float16
    compute_dtype = None
    device = "cuda"
@@ -32,7 +38,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
        quant_type=quant_type,
        device="meta",
    )
-    new_weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False)
+    new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False)
    linear_q.weight = new_weight
    if bias:
        linear_q.bias = torch.nn.Parameter(linear.bias)
@@ -65,6 +71,22 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
    # MATCHING
    a, b = linear_q.weight, linear_q2.weight

+    # Quantizing original layer with specified quant_storage type
+    linear_qs = bnb.nn.Linear4bit(
+        linear.in_features,
+        linear.out_features,
+        bias=bias,
+        compute_dtype=compute_dtype,
+        compress_statistics=compress_statistics,
+        quant_type=quant_type,
+        quant_storage=storage[quant_storage],
+        device="meta",
+    )
+    linear_qs.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False, quant_type=quant_type, quant_storage=storage[quant_storage])
+    if bias:
+        linear_qs.bias = torch.nn.Parameter(linear.bias)
+    linear_qs = linear_qs.to(device)

    assert a.device == b.device
    assert a.dtype == b.dtype
    assert torch.equal(a, b)
@@ -96,9 +118,21 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
    x = torch.rand(42, layer_shape[0], device=device)
    a = linear_q(x)
    b = linear_q2(x)
+    c = linear_qs(x)
    assert a.device == b.device
    assert a.dtype == b.dtype
+    assert a.device == c.device
+    assert a.dtype == c.dtype
    assert torch.equal(a, b)
+    assert torch.equal(a, c)

+    # Test moving to CPU and back to GPU
+    linear_q2.to('cpu')
+    linear_q2.to(device)
+    d = linear_qs(x)
+    assert c.dtype == d.dtype
+    assert c.device == d.device
+    assert torch.equal(c, d)

    # Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias
    with TemporaryDirectory() as tmpdir:
...