Unverified Commit ed5ae4aa authored by Kyuyeun Kim's avatar Kyuyeun Kim Committed by GitHub
Browse files

[Bugfix] Fix _synced_weight_loader (#24565)


Signed-off-by: default avatarKyuyeun Kim <kyuyeunk@google.com>
parent 0fc36463
......@@ -52,10 +52,11 @@ def set_weight_attrs(
def _make_synced_weight_loader(original_weight_loader):
def _synced_weight_loader(param, *args, **kwargs):
original_weight_loader(param, *args, **kwargs)
out = original_weight_loader(param, *args, **kwargs)
# torch._sync doesn't support, is not needed for CPU tensors.
if param.device != torch.device("cpu"):
torch._sync(param)
return out
return _synced_weight_loader
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment