Unverified Commit b340d907 authored by Jiewen Tan's avatar Jiewen Tan Committed by GitHub

[PyTorch/XLA] Fix extra TPU compilations introduced by recent changes (#29158)

* tmp

* Remove debug step

* Fix a typo

* Move to is_torch_xla_available
parent 1e21c4fb
@@ -1364,7 +1364,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 hard_check_only=False,
                 check_device_map=check_device_map,
             )
-        elif requested_attn_implementation in [None, "sdpa"]:
+        elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
             # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
             config = cls._check_and_enable_sdpa(
                 config,
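The change above guards the SDPA branch so that it is skipped when torch_xla is in use, falling back to the default (eager) attention path and avoiding the extra TPU recompilations the SDPA path triggered. The sketch below mirrors that dispatch logic in isolation; the helper names and return values are simplified stand-ins for the transformers internals, not the actual implementation.

```python
# Hedged sketch of the attention-implementation dispatch touched by this
# commit. `select_attn_implementation` is a hypothetical stand-in for the
# elif chain in PreTrainedModel; only the XLA guard mirrors the real change.

def is_torch_xla_available():
    # Stand-in: report whether torch_xla can be imported at all.
    try:
        import torch_xla  # noqa: F401
        return True
    except ImportError:
        return False


def select_attn_implementation(requested, xla_available):
    """Simplified mirror of the dispatch: explicit flash-attention wins,
    SDPA is chosen only off-XLA, otherwise fall back to eager."""
    if requested == "flash_attention_2":
        return "flash_attention_2"
    # use_flash_attention_2 takes priority over SDPA, hence SDPA in this elif.
    elif requested in [None, "sdpa"] and not xla_available:
        return "sdpa"
    # On XLA (e.g. TPU), use eager attention to avoid extra compilations.
    return "eager"


print(select_attn_implementation(None, xla_available=True))   # TPU case
print(select_attn_implementation(None, xla_available=False))  # GPU/CPU case
```

With the guard in place, a TPU run with no explicit `attn_implementation` now resolves to eager attention instead of SDPA, which is what removes the extra compilations.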