Commit 2cd047e3 authored by justheuristic

run backward

parent 591f6039
@@ -554,11 +554,22 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
             assert mlp.fc1.state.idx is not None
         if threshold > 0:
             assert mlp.fc2.state.idx is not None
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
     assert mlp.fc1.weight.device.type == "cuda"
     assert mlp.fc2.weight.device.type == "cuda"
+    if memory_efficient_backward:
+        b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
+        o1 = mlp(b1)
+        assert o1.dtype == torch.float16
+        assert o1.requires_grad
+        grad_proj = torch.randn_like(o1)
+        (o1 * grad_proj).sum().backward()

 def test_linear8bitlt_fp32_bias():
 ...
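
For context, the added hunk drives a backward pass by projecting the layer output onto a random direction and reducing it to a scalar before calling backward(). Below is a minimal stand-alone sketch of that same pattern, using a plain torch.nn.Linear on CPU as a stand-in for the quantized CUDA MLP in the test (the names b1, o1, grad_proj mirror the test; the real test runs in fp16 on CUDA):

import torch

# Plain fp32 Linear on CPU as a stand-in for the int8 MLP exercised by the test.
mlp = torch.nn.Linear(32, 32)

b1 = torch.randn(16, 8, 32, requires_grad=True)  # batched input, like the test's b1
o1 = mlp(b1)
assert o1.requires_grad

# Project the output onto a random direction so a single scalar can be
# backpropagated; equivalent to o1.backward(grad_proj).
grad_proj = torch.randn_like(o1)
(o1 * grad_proj).sum().backward()

# Gradients reach both the input and the layer parameters.
assert b1.grad is not None and b1.grad.shape == b1.shape
assert mlp.weight.grad is not None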