Unverified commit b8d1c261 authored by Matthew Douglas, committed by GitHub

Linear8bitLt: support device movement after forward() (#1769)

parent 42e8abc3
@@ -679,19 +679,27 @@ class Int8Params(torch.nn.Parameter):
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
 
-        if device is not None and device.type != "meta" and self.data.device.type == "cpu":
-            if device.type != "cpu" or self.data.dtype != torch.int8:
-                return self._quantize(device)
-            elif self.data.dtype == torch.int8 and device.type == "cpu":
-                self.CB = self.data
+        is_quantized = self.data.dtype == torch.int8
+
+        if not is_quantized and device is not None and device.type != "meta" and self.data.device.type == "cpu":
+            # We're moving from a CPU device to a non-meta device.
+            # In this circumstance, we want to quantize if we haven't already.
+            return self._quantize(device)
 
+        # Create a new parameter on the target device.
         new_param = Int8Params(
             super().to(device=device, dtype=dtype, non_blocking=non_blocking),
             requires_grad=self.requires_grad,
             has_fp16_weights=self.has_fp16_weights,
         )
-        new_param.CB = self.CB
-        new_param.SCB = self.SCB
+
+        # If we had already quantized, move the statistics appropriately.
+        if is_quantized and device is not None:
+            if self.CB is not None:
+                new_param.CB = new_param.data
+            if self.SCB is not None:
+                new_param.SCB = self.SCB.to(device)
 
         return new_param
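
For context (not part of the diff), a minimal sketch of the parameter-level semantics this hunk implements, assuming a CUDA accelerator is available; the tensor shape and the `"cuda"` target are illustrative only:

```python
import torch
import bitsandbytes as bnb

# Start with fp16 weights on CPU, not yet quantized.
w = torch.randn(128, 32, dtype=torch.float16)
p = bnb.nn.Int8Params(w, requires_grad=False, has_fp16_weights=False)

# CPU fp16 -> accelerator: the first branch runs _quantize(), so data becomes
# int8 and the quantization statistics (CB/SCB) are populated.
p_cuda = p.to("cuda")
assert p_cuda.data.dtype == torch.int8

# Already quantized: a new Int8Params is created on the target device, CB is
# re-pointed at the moved data, and SCB is moved with .to(device) instead of
# being left behind on the old device.
p_cpu = p_cuda.to("cpu")
p_back = p_cpu.to("cuda")
assert p_back.SCB is None or p_back.SCB.device.type == "cuda"
```

The key change is that a quantized parameter can now round-trip between devices without its `CB`/`SCB` statistics going stale or staying pinned to the original device.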
@@ -1037,6 +1045,21 @@ class Linear8bitLt(nn.Linear):
         self.weight.CB = None
         self.weight.SCB = None
 
+    def to(self, *args, **kwargs):
+        # Call the parent to() method to handle standard parameter/buffer movement.
+        result = super().to(*args, **kwargs)
+
+        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
+
+        # Handle state tensors if needed.
+        if device is not None:
+            if result.state.CB is not None:
+                result.state.CB = result.state.CB.to(device)
+            if result.state.SCB is not None:
+                result.state.SCB = result.state.SCB.to(device)
+
+        return result
+
     def forward(self, x: torch.Tensor):
         self.state.is_training = self.training
         if self.weight.CB is not None:
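
A short usage sketch (again not part of the diff) of why `Linear8bitLt` needs its own `to()`: after `forward()`, the quantization state lives on `self.state` (its `CB`/`SCB` tensors) rather than only on the weight parameter, so a plain `nn.Module.to()` would leave those tensors on the old device. Assumes a CUDA accelerator; layer sizes are arbitrary:

```python
import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear8bitLt(32, 128, bias=False, has_fp16_weights=False).to("cuda")

# Running forward() moves the quantized weight statistics into layer.state.
_ = layer(torch.randn(4, 32, dtype=torch.float16, device="cuda"))

# With the override above, state.CB / state.SCB follow the module to the CPU.
layer = layer.to("cpu")
if layer.state.CB is not None:
    assert layer.state.CB.device.type == "cpu"
if layer.state.SCB is not None:
    assert layer.state.SCB.device.type == "cpu"
```

This mirrors what the new test below exercises by bouncing the layer between CPU and an accelerator and checking that the outputs stay consistent.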
@@ -293,3 +293,41 @@ def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
         grad_compiled = x.grad.clone()
         torch.testing.assert_close(grad_compiled, grad_ref)
 
+
+@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
+@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No accelerator device")
+def test_linear8bitlt_device_movement(device):
+    """Test moving a Linear8bitLt layer between CPU and an accelerator device."""
+
+    # Create a Linear8bitLt layer on CPU.
+    layer = bnb.nn.Linear8bitLt(32, 128, bias=False, has_fp16_weights=False)
+    torch.nn.init.xavier_uniform_(layer.weight)
+
+    # Create a sample input.
+    x = torch.randn(4, 32, dtype=torch.float16, device="cpu")
+
+    # Move to the device. This should quantize the weights.
+    layer = layer.to(device)
+    assert layer.weight.data.dtype == torch.int8
+
+    # Call the layer on the accelerator device.
+    out_accelerator = layer(x.to(device))
+
+    # Move back to CPU and call again.
+    layer = layer.to("cpu")
+    out_cpu = layer(x)
+
+    # Move back to the accelerator device and call again.
+    layer = layer.to(device)
+    out_accelerator_2 = layer(x.to(device))
+
+    # Move back to the CPU and call one last time.
+    layer = layer.to("cpu")
+    out_cpu_2 = layer(x)
+
+    # CPU outputs should match both times.
+    torch.testing.assert_close(out_cpu_2, out_cpu, rtol=1e-8, atol=1e-8)
+
+    # Accelerator outputs should match both times.
+    torch.testing.assert_close(out_accelerator_2, out_accelerator, rtol=1e-8, atol=1e-8)