Commit 2938c739 authored by ved1beta's avatar ved1beta
Browse files

test_params4bit_torch_chunk_split

parent 1dbe6021
...@@ -212,6 +212,41 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics): ...@@ -212,6 +212,41 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics):
assert param.data.data_ptr() == shallow_copy_param.data.data_ptr() assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
def test_params4bit_torch_chunk_split(device, quant_type):
"""Test that torch.chunk and torch.split preserve Params4bit subclass for FSDP2 compatibility."""
if device == "hpu" and not is_supported_on_hpu(quant_type, torch.float16, torch.uint8):
pytest.skip("This configuration is not supported on HPU.")
if device == "cpu":
pytest.skip("CPU quantization causes segfault, skipping CPU test")
original_tensor = torch.randn(8, 4, dtype=torch.float16, device="cpu")
params4bit = bnb.nn.Params4bit(data=original_tensor, quant_type=quant_type, requires_grad=False)
if device != "cpu":
params4bit = params4bit.to(device)
chunks = torch.chunk(params4bit, 2, dim=0)
assert isinstance(chunks, tuple), "torch.chunk should return tuple"
for chunk in chunks:
assert isinstance(chunk, bnb.nn.Params4bit), "Chunk should preserve Params4bit subclass"
assert hasattr(chunk, "quant_type"), "Should preserve metadata"
assert chunk.quant_type == params4bit.quant_type, "Should preserve quant_type value"
splits = torch.split(params4bit, 2, dim=0)
assert isinstance(splits, tuple), "torch.split should return tuple"
assert len(splits) > 0, "Should have at least one split"
for split in splits:
assert isinstance(split, bnb.nn.Params4bit), "Split should preserve Params4bit subclass"
assert hasattr(split, "quant_type"), "Should preserve metadata"
assert split.quant_type == params4bit.quant_type, "Should preserve quant_type value"
@pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment