Unverified Commit 3f93fd06 authored by Younes Belkada, committed by GitHub

Llama et al. / FSDP : Fix breaking change in 4.40 for FSDP (#31161)



* fix llama fsdp

* fixup

* adding FSDP tests for CPU offloading

* fixes

* fix tests

* fix tests

* add it for mixtral

* propagate the changes on other models

* Update src/transformers/models/phi/modeling_phi.py

* Delete utils/testing_scripts/fsdp_cpu_offloading.py

Remove script - FSDP + CPU offloading is tested in the test suite

* Delete utils/testing_scripts/dummy_fsdp_config.yml

* Update + add cache_positions docstring

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent ac52084b
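
The per-model diffs below all apply the same signature change: each decoder layer (and some attention classes) gains a catch-all `**kwargs` so that callers which inject extra keyword arguments into the forward call, such as FSDP and other wrapping utilities, no longer trigger a `TypeError`. A minimal, self-contained sketch of the pattern follows; the toy module and the injected kwarg name are hypothetical, not the actual transformers code:

```python
import torch
import torch.nn as nn


class ToyDecoderLayer(nn.Module):
    """Stand-in for a decoder layer; only the signature matters here."""

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        # Extra keyword arguments injected by wrappers (FSDP hooks, offloading
        # utilities, etc.) are accepted and simply ignored.
        return hidden_states


layer = ToyDecoderLayer()
x = torch.randn(1, 4, 8)

# A caller that forwards an unexpected kwarg no longer raises TypeError:
out = layer(x, attention_mask=None, offload_to_cpu=True)  # hypothetical injected kwarg
print(out.shape)  # torch.Size([1, 4, 8])
```
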
@@ -782,6 +782,7 @@ class FalconDecoderLayer(nn.Module):
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
+        **kwargs,
    ):
        residual = hidden_states
...
@@ -628,6 +628,7 @@ class GemmaDecoderLayer(nn.Module):
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
@@ -642,6 +643,11 @@ class GemmaDecoderLayer(nn.Module):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
        """
        residual = hidden_states
...
@@ -706,6 +706,7 @@ class GPTBigCodeBlock(nn.Module):
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
+        **kwargs,
    ) -> Union[
        Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ]:
...
@@ -301,6 +301,7 @@ class LlamaAttention(nn.Module):
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()
@@ -590,6 +591,7 @@ class LlamaSdpaAttention(LlamaAttention):
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -687,6 +689,7 @@ class LlamaDecoderLayer(nn.Module):
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
@@ -701,6 +704,11 @@ class LlamaDecoderLayer(nn.Module):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
        """
        residual = hidden_states
...
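
For context on the new `cache_position` docstring entries added here and in the other models: these indices are the absolute positions of the tokens fed in the current forward pass, offset by however many tokens the KV cache already holds. A rough sketch of how such a tensor can be constructed (illustrative values, not lifted from the modeling code):

```python
import torch

past_seen_tokens = 10  # tokens already stored in past_key_values (assumed value)
new_tokens = 3         # tokens passed to this forward call
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_tokens)
print(cache_position)  # tensor([10, 11, 12])
```
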
@@ -591,6 +591,7 @@ class MistralSdpaAttention(MistralAttention):
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -689,6 +690,7 @@ class MistralDecoderLayer(nn.Module):
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
@@ -703,8 +705,12 @@ class MistralDecoderLayer(nn.Module):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
...
@@ -888,6 +888,7 @@ class MixtralDecoderLayer(nn.Module):
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
@@ -906,6 +907,9 @@ class MixtralDecoderLayer(nn.Module):
                (see `past_key_values`).
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
        """
        residual = hidden_states
...
@@ -666,6 +666,7 @@ class OlmoDecoderLayer(nn.Module):
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
@@ -680,6 +681,11 @@ class OlmoDecoderLayer(nn.Module):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
        """
        residual = hidden_states
...
@@ -771,6 +771,7 @@ class PhiDecoderLayer(nn.Module):
        use_cache: Optional[bool] = False,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
@@ -790,6 +791,9 @@ class PhiDecoderLayer(nn.Module):
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
        """
        residual = hidden_states
...
@@ -830,6 +830,7 @@ class Phi3DecoderLayer(nn.Module):
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
@@ -847,6 +848,11 @@ class Phi3DecoderLayer(nn.Module):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
        """
        residual = hidden_states
...
@@ -734,6 +734,7 @@ class Qwen2DecoderLayer(nn.Module):
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
@@ -749,6 +750,9 @@ class Qwen2DecoderLayer(nn.Module):
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
        """
        residual = hidden_states
...
@@ -876,6 +876,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
@@ -894,6 +895,9 @@ class Qwen2MoeDecoderLayer(nn.Module):
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
        """
        residual = hidden_states
...
@@ -713,6 +713,7 @@ class Starcoder2DecoderLayer(nn.Module):
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
@@ -728,6 +729,9 @@ class Starcoder2DecoderLayer(nn.Module):
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
        """
        residual = hidden_states
...
@@ -14,6 +14,7 @@
import itertools
import os
+import subprocess
import unittest
from copy import deepcopy
from functools import partial
@@ -31,6 +32,7 @@ from transformers.testing_utils import (
    require_accelerate,
    require_fsdp,
    require_torch_accelerator,
+    require_torch_gpu,
    require_torch_multi_accelerator,
    slow,
    torch_device,
@@ -276,6 +278,20 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
            if "learning_rate" in log:
                self.assertAlmostEqual(log["learning_rate"], log1["learning_rate"], delta=1e-5)

+    @require_torch_multi_accelerator
+    @slow
+    @require_torch_gpu
+    @require_fsdp
+    def test_fsdp_cpu_offloading(self):
+        try:
+            subprocess.run(
+                "accelerate launch utils/testing_scripts/fsdp_cpu_offloading.py --config utils/testing_scripts/dummy_fsdp_config.yml",
+                shell=True,
+                check=True,
+            )
+        except:  # noqa
+            raise AssertionError("CPU offloading failed with FSDP!")

    def run_cmd_and_get_logs(self, use_accelerate, sharding_strategy, launcher, script, args, output_dir):
        if not use_accelerate:
            fsdp_args = [
...
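
The new test exercises FSDP with CPU parameter offloading through `accelerate launch`. For readers unfamiliar with the feature, here is a minimal single-process sketch using PyTorch's FSDP API directly; this is not what the test runs (the test launches a multi-GPU script via accelerate), it is only meant to show what "CPU offloading" configures, and CPU-only FSDP support with the gloo backend may vary by torch version:

```python
import os

import torch
import torch.nn as nn
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP

# Single-process process group so the sketch runs without a launcher.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
torch.distributed.init_process_group(backend="gloo", rank=0, world_size=1)

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 2))

# offload_params=True keeps the sharded parameters on CPU between uses,
# which is the FSDP feature the test above targets.
fsdp_model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))

out = fsdp_model(torch.randn(4, 16))
out.sum().backward()
print(out.shape)

torch.distributed.destroy_process_group()
```
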