Unverified Commit 0231ce83 authored by Eldar Kurtić's avatar Eldar Kurtić Committed by GitHub
Browse files

Revert back to torch.equal over torch.allclose from #28819 (#29086)


Signed-off-by: default avatarEldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
parent 516c3f78
......@@ -1055,11 +1055,11 @@ class EagleProposer:
elif (
isinstance(target_embed_tokens.weight, torch.Tensor)
and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
and torch.allclose(
# TODO: Offload to CPU for comparison to avoid extra GPU memory
# usage in CI testing environments with limited GPU memory
and torch.equal(
target_embed_tokens.weight.cpu(),
self.model.model.embed_tokens.weight.cpu(),
rtol=1e-5,
atol=1e-7,
)
):
share_embeddings = True
......@@ -1105,8 +1105,11 @@ class EagleProposer:
hasattr(target_language_model, "lm_head")
and isinstance(target_language_model.lm_head.weight, torch.Tensor)
and isinstance(self.model.lm_head.weight, torch.Tensor)
# TODO: Offload to CPU for comparison to avoid extra GPU memory
# usage in CI testing environments with limited GPU memory
and torch.equal(
target_language_model.lm_head.weight, self.model.lm_head.weight
target_language_model.lm_head.weight.cpu(),
self.model.lm_head.weight.cpu(),
)
):
share_lm_head = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment