Unverified Commit ecde7af9 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Fix import that was moved in Transformers 5.2.0 (#36120)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 8df52335
...@@ -516,8 +516,11 @@ class Base( ...@@ -516,8 +516,11 @@ class Base(
) )
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.check_version("5.0.0", "Eagle3 support") self.check_version("5.2.0", "Eagle3 support")
from transformers.utils.generic import OutputRecorder from transformers.utils.output_capturing import (
OutputRecorder,
maybe_install_capturing_hooks,
)
# The default value in PreTrainedModel is None # The default value in PreTrainedModel is None
if self.model._can_record_outputs is None: if self.model._can_record_outputs is None:
...@@ -532,6 +535,9 @@ class Base( ...@@ -532,6 +535,9 @@ class Base(
self.model._can_record_outputs[layer_key] = aux_hidden_state_i self.model._can_record_outputs[layer_key] = aux_hidden_state_i
self._output_aux_hidden_states_kwargs[f"output_{layer_key}"] = True self._output_aux_hidden_states_kwargs[f"output_{layer_key}"] = True
# Ensure that the capture hooks are installed before dynamo traces the model
maybe_install_capturing_hooks(self.model)
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = self.text_config.num_hidden_layers num_layers = self.text_config.num_hidden_layers
return (2, num_layers // 2, num_layers - 3) return (2, num_layers // 2, num_layers - 3)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment