Unverified Commit 29c36e9f authored by songyouwei's avatar songyouwei Committed by GitHub
Browse files

run_pplm.py bug fix (#4867)

`is_leaf` may become `False` after `.to(device=device)` function call.
parent 13aa1741
......@@ -148,6 +148,9 @@ def perturb_past(
for i in range(num_iterations):
print("Iteration ", i + 1)
curr_perturbation = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator]
# make sure p_.grad is not None
for p_ in curr_perturbation:
p_.retain_grad()
# Compute hidden using perturbed past
perturbed_past = list(map(add, past, curr_perturbation))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment