Unverified Commit b340d907 authored by Jiewen Tan's avatar Jiewen Tan Committed by GitHub

[PyTorch/XLA] Fix extra TPU compilations introduced by recent changes (#29158)

* tmp

* Remove debug step

* Fix a typo

* Move to is_torch_xla_available
parent 1e21c4fb
@@ -1364,7 +1364,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 hard_check_only=False,
                 check_device_map=check_device_map,
             )
-        elif requested_attn_implementation in [None, "sdpa"]:
+        elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
             # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
             config = cls._check_and_enable_sdpa(
                 config,
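The diff gates the SDPA probe on XLA availability: on TPU, attempting to check and enable SDPA can trigger extra graph compilations, so the SDPA branch is skipped when torch_xla is present. A minimal sketch of that dispatch pattern, with the availability flag passed in explicitly (the function and config shape here are illustrative stand-ins, not the actual transformers internals):

```python
def select_attn_implementation(requested, xla_available, config):
    """Sketch of the attention-backend dispatch gated on XLA.

    `requested` is the user-requested implementation (or None),
    `xla_available` stands in for is_torch_xla_available(), and
    `config` is a plain dict standing in for the model config.
    """
    if requested == "flash_attention_2":
        config["attn_implementation"] = "flash_attention_2"
    elif requested in [None, "sdpa"] and not xla_available:
        # use_flash_attention_2 takes priority over SDPA, hence this elif.
        # Skipped under XLA to avoid extra TPU compilations.
        config["attn_implementation"] = "sdpa"
    else:
        # Fall back to the eager attention path.
        config["attn_implementation"] = "eager"
    return config


print(select_attn_implementation("sdpa", False, {}))  # SDPA allowed off-XLA
print(select_attn_implementation("sdpa", True, {}))   # falls back under XLA
```

The key design point is that `None` and `"sdpa"` share a branch (SDPA is the default when nothing is requested), and the XLA check short-circuits both to the eager fallback.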