Unverified Commit 42653921 authored by Matthew Douglas, committed by GitHub

Merge pull request #1719 from ved1beta/fsdp_integration2

Fix Params4bit tensor subclass handling
parents e54dc125 0ecb8fb4
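
Background for the fix: `torch.nn.Parameter` disables `__torch_function__`, so tensor ops on a `Parameter` subclass such as `Params4bit` return plain tensors by default, dropping the 4-bit metadata (`quant_state`, `quant_type`, ...) that FSDP2 needs when it shards parameters with `torch.chunk`/`torch.split`. A minimal sketch of the symptom, as a hypothetical standalone repro that is not part of the diff below:

```python
import torch
import bitsandbytes as bnb

# A small half-precision weight; it is only quantized once moved to an accelerator.
p = bnb.nn.Params4bit(
    data=torch.randn(8, 4, dtype=torch.float16),
    quant_type="nf4",
    requires_grad=False,
)

halves = torch.chunk(p, 2, dim=0)
# Before this fix: each half is a plain torch.Tensor and quant_type is gone.
# After this fix:  each half is still a bnb.nn.Params4bit carrying its metadata.
print(type(halves[0]), getattr(halves[0], "quant_type", None))
```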
@@ -356,6 +356,46 @@ class Params4bit(torch.nn.Parameter):
        return new_param

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        if func in [torch.chunk, torch.split]:
            tensor = args[0]
            result = super().__torch_function__(func, types, args, kwargs)

            # Re-wrap each shard as a Params4bit so the 4-bit metadata
            # (quant_state, blocksize, quant_type, ...) survives sharding.
            if isinstance(result, tuple):
                return tuple(
                    cls(
                        data=chunk,
                        requires_grad=tensor.requires_grad,
                        quant_state=tensor.quant_state,
                        blocksize=tensor.blocksize,
                        compress_statistics=tensor.compress_statistics,
                        quant_type=tensor.quant_type,
                        quant_storage=tensor.quant_storage,
                        module=tensor.module,
                        bnb_quantized=tensor.bnb_quantized,
                    )
                    for chunk in result
                )
            else:
                return cls(
                    data=result,
                    requires_grad=tensor.requires_grad,
                    quant_state=tensor.quant_state,
                    blocksize=tensor.blocksize,
                    compress_statistics=tensor.compress_statistics,
                    quant_type=tensor.quant_type,
                    quant_storage=tensor.quant_storage,
                    module=tensor.module,
                    bnb_quantized=tensor.bnb_quantized,
                )

        return super().__torch_function__(func, types, args, kwargs)
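
Two points worth noting about the approach: the re-wrapped shards share the parent tensor's `quant_state` by reference rather than recomputing per-shard statistics, and only `torch.chunk` and `torch.split` are intercepted (the ops FSDP2 uses for sharding, per the test below); every other function falls through to the default `Parameter` behavior.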


def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]):
    if getattr(module.weight, "quant_state", None) is not None:
...
@@ -212,6 +212,41 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics):
    assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
def test_params4bit_torch_chunk_split(device, quant_type):
"""Test that torch.chunk and torch.split preserve Params4bit subclass for FSDP2 compatibility."""
if device == "hpu" and not is_supported_on_hpu(quant_type, torch.float16, torch.uint8):
pytest.skip("This configuration is not supported on HPU.")
if device == "cpu":
pytest.skip("CPU quantization causes segfault, skipping CPU test")
original_tensor = torch.randn(8, 4, dtype=torch.float16, device="cpu")
params4bit = bnb.nn.Params4bit(data=original_tensor, quant_type=quant_type, requires_grad=False)
if device != "cpu":
params4bit = params4bit.to(device)
chunks = torch.chunk(params4bit, 2, dim=0)
assert isinstance(chunks, tuple), "torch.chunk should return tuple"
for chunk in chunks:
assert isinstance(chunk, bnb.nn.Params4bit), "Chunk should preserve Params4bit subclass"
assert hasattr(chunk, "quant_type"), "Should preserve metadata"
assert chunk.quant_type == params4bit.quant_type, "Should preserve quant_type value"
splits = torch.split(params4bit, 2, dim=0)
assert isinstance(splits, tuple), "torch.split should return tuple"
assert len(splits) > 0, "Should have at least one split"
for split in splits:
assert isinstance(split, bnb.nn.Params4bit), "Split should preserve Params4bit subclass"
assert hasattr(split, "quant_type"), "Should preserve metadata"
assert split.quant_type == params4bit.quant_type, "Should preserve quant_type value"
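
To exercise just this test locally (assuming the repo's usual pytest setup): `pytest -k test_params4bit_torch_chunk_split`.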
@pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])
...