"...resnet50_tensorflow.git" did not exist on "d6c6e44128237ebcccdfeb92b5898f45f8901f31"
Unverified Commit 60825f2c authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix device placement for model-parallelism in generate for encoder/decoders (#24025)

* Fix device placement for model-parallelism in generate for encoder/decoders

* Remove debug statements
parent 02d255db
...@@ -616,6 +616,10 @@ class GenerationMixin: ...@@ -616,6 +616,10 @@ class GenerationMixin:
) -> Dict[str, Any]: ) -> Dict[str, Any]:
# 1. get encoder # 1. get encoder
encoder = self.get_encoder() encoder = self.get_encoder()
# Compatibility with Accelerate big model inference: the encoder must return its outputs on the same device
# as the inputs.
if hasattr(encoder, "_hf_hook"):
encoder._hf_hook.io_same_device = True
# 2. Prepare encoder args and encoder kwargs from model kwargs. # 2. Prepare encoder args and encoder kwargs from model kwargs.
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment