Commit e9f3605f authored by Matthew Douglas

Fix Linear4bit warnings/test for compute dtype

parent 812ef06a
@@ -455,14 +455,14 @@ class Linear4bit(nn.Linear):
             self.compute_dtype = x.dtype
         elif x.dtype == torch.float16:
             # we take the compute dtype passed into the layer
-            if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]):
+            if self.compute_dtype in [None, torch.float32] and (x.numel() == x.shape[-1]):
                 # single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
                 # warn the user about this
                 warnings.warn(
                     "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.",
                 )
                 warnings.filterwarnings("ignore", message=".*inference.")
-            if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]):
+            if self.compute_dtype in [None, torch.float32] and (x.numel() != x.shape[-1]):
                 warnings.warn(
                     "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.",
                 )
...
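For context, the single-batch check in the warning condition relies on the fact that a 2-D activation with batch size 1 has exactly x.shape[-1] elements. A minimal sketch of the default path this commit now warns about (a CUDA-capable device is assumed; the constructor arguments mirror the test below):

    import torch
    import bitsandbytes as bnb

    # x.numel() == x.shape[-1] holds exactly for a single-row input,
    # which is how the layer detects single-batch inference.
    x = torch.rand(1, 64, dtype=torch.float16)
    assert x.numel() == x.shape[-1]

    # compute_dtype is left unset (None), so it falls back to the
    # float32 default -- the case the widened condition now catches.
    layer = bnb.nn.Linear4bit(64, 64, quant_type="nf4").to("cuda")
    layer(x.to("cuda"))  # expected: UserWarning about slow inference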
@@ -440,31 +440,23 @@ def test_4bit_linear_warnings(device):
     dim1 = 64
     with pytest.warns(UserWarning, match=r"inference or training"):
-        net = nn.Sequential(
-            *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
-        )
+        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
         net = net.to(device)
         inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
         net(inp)
     with pytest.warns(UserWarning, match=r"inference."):
-        net = nn.Sequential(
-            *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
-        )
+        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
         net = net.to(device)
         inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
         net(inp)
     with pytest.warns(UserWarning) as record:
-        net = nn.Sequential(
-            *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
-        )
+        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
         net = net.to(device)
         inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
         net(inp)
-        net = nn.Sequential(
-            *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
-        )
+        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
         net = net.to(device)
         inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
         net(inp)
...
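As the warning text suggests, passing an explicit half-precision compute dtype avoids the slow float32 default entirely; a minimal sketch (same assumptions as above):

    import torch
    import bitsandbytes as bnb

    # Sketch: an explicit compute_dtype sidesteps both warning paths
    # exercised by the test (CUDA-capable device assumed).
    layer = bnb.nn.Linear4bit(64, 64, quant_type="nf4", compute_dtype=torch.float16).to("cuda")
    x = torch.rand(10, 64, device="cuda", dtype=torch.float16)
    out = layer(x)  # no UserWarning: compute dtype matches the input dtype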