Unverified commit 35c570c8 authored by Marc Sun, committed by GitHub

fix encoder hook (#25735)

* fix encoder hook

* style
parent dd8b7d28
@@ -33,7 +33,7 @@ from ..models.auto import (
     MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
     MODEL_FOR_VISION_2_SEQ_MAPPING,
 )
-from ..utils import ExplicitEnum, ModelOutput, logging
+from ..utils import ExplicitEnum, ModelOutput, is_accelerate_available, logging
 from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
 from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
 from .configuration_utils import GenerationConfig
@@ -80,6 +80,9 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
 
+if is_accelerate_available():
+    from accelerate.hooks import AlignDevicesHook, add_hook_to_module
+
 
 @dataclass
 class GreedySearchDecoderOnlyOutput(ModelOutput):
@@ -631,8 +634,11 @@ class GenerationMixin:
         encoder = self.get_encoder()
         # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device
         # as the inputs.
-        if hasattr(encoder, "_hf_hook"):
-            encoder._hf_hook.io_same_device = True
+        if hasattr(self, "hf_device_map"):
+            if hasattr(encoder, "_hf_hook"):
+                encoder._hf_hook.io_same_device = True
+            else:
+                add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True))
 
         # 2. Prepare encoder args and encoder kwargs from model kwargs.
         irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
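For context, a minimal sketch of the setup this fix targets: an encoder-decoder model dispatched with Accelerate's device map, where generate() now makes sure the encoder returns outputs on the same device as its inputs even when the encoder submodule carries no _hf_hook of its own. The checkpoint name and prompt below are illustrative assumptions, not taken from the commit.

```python
# Hedged sketch, not part of the commit: illustrates the code path the fix
# touches. "t5-small" and the prompt are placeholder choices.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
# device_map="auto" enables Accelerate big model inference and sets
# `hf_device_map` on the model, which is the attribute the new check looks for.
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", device_map="auto")

inputs = tokenizer(
    "translate English to German: Hello, world!", return_tensors="pt"
).to(model.device)

# Inside generate(), the encoder runs first. With this fix, when the model has
# an hf_device_map but its encoder has no _hf_hook, generate() attaches
# AlignDevicesHook(io_same_device=True) so the encoder's outputs land on the
# same device as `inputs`, avoiding cross-device tensor errors in the decoder.
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

Previously only the case where the encoder already carried a hook was handled; the else branch added here covers encoders that the dispatch step left without one.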