"vscode:/vscode.git/clone" did not exist on "4fca1a1bd25a1b0d3b49f3fa832425cef5a612fb"
Unverified Commit 72fc97a0 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Bugfix] Fix torch dynamo fixes caused by `replace_parameters` (#8748)

parent 2529d09b
...@@ -21,13 +21,17 @@ def replace_parameter(mod: torch.nn.Module, name: str, ...@@ -21,13 +21,17 @@ def replace_parameter(mod: torch.nn.Module, name: str,
new: Union[torch.Tensor, torch.nn.Parameter]) -> None: new: Union[torch.Tensor, torch.nn.Parameter]) -> None:
old = getattr(mod, name) old = getattr(mod, name)
if old.dtype == new.dtype and \ if type(old) is type(new) and old.dtype == new.dtype and \
old.untyped_storage().nbytes() == new.untyped_storage().nbytes(): old.untyped_storage().nbytes() == new.untyped_storage().nbytes():
# If we can just update in-place to avoid re-registering # If we can just update in-place to avoid re-registering
# can be faster if the underlying storage is the same # can be faster if the underlying storage is the same
update_tensor_inplace(old, new) update_tensor_inplace(old, new)
else: else:
# Fallback re-register parameter # Fallback re-register parameter, convert to Parameter if necessary
# this not only ensures we don't register a tensor as a parameter, but
# also ensures that all parameter subclasses get re-registered as
# parameters for `torch.compile` compatibility
if not isinstance(new, torch.nn.Parameter): if not isinstance(new, torch.nn.Parameter):
new = torch.nn.Parameter(new) new = torch.nn.Parameter(new, requires_grad=False)
mod.register_parameter(name, torch.nn.Parameter(new)) mod.register_parameter(name,
torch.nn.Parameter(new, requires_grad=False))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment