"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "af6732225c0c789f18bdff812a0733c173480969"
Unverified Commit 91f3dfbf authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Adafactor] Fix adafactor (#14713)

* correct changes

* add comment
parent 86dd23bb
......@@ -503,9 +503,11 @@ class Adafactor(Optimizer):
@staticmethod
def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
    """Combine row and column second-moment statistics into one tensor.

    The row statistics are divided by their mean over the last dimension,
    passed through a reciprocal square root and given a trailing axis.
    The column statistics get an axis inserted at position -2 and are
    passed through a reciprocal square root as well. The two factors are
    then multiplied element-wise, relying on broadcasting.

    Using a broadcast ``torch.mul`` instead of ``torch.mm`` means the
    inputs may carry leading dimensions; ``torch.mm`` only accepts 2-D
    matrices.

    Args:
        exp_avg_sq_row: tensor of row statistics
            (presumably shaped ``(..., rows)`` -- confirm against caller).
        exp_avg_sq_col: tensor of column statistics
            (presumably shaped ``(..., cols)`` -- confirm against caller).

    Returns:
        The broadcast product of the two reciprocal-square-root factors.
    """
    # copy from fairseq's adafactor implementation:
    # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
    # NOTE(review): the link above points at the transformers repository,
    # not fairseq -- confirm the intended reference.
    row_mean = exp_avg_sq_row.mean(dim=-1, keepdim=True)
    # The in-place rsqrt_ acts on the temporary produced by the division,
    # so the caller's tensor is left untouched.
    r_factor = (exp_avg_sq_row / row_mean).rsqrt_().unsqueeze(-1)
    c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
    return torch.mul(r_factor, c_factor)
def step(self, closure=None):
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment