# fsdp.py
from typing import Any
from typing import Callable

from torch.distributed.fsdp import FullyShardedDataParallel


class _FSDPForwardRedirection:
    """Redirect a callable through ``FullyShardedDataParallel.forward``.

    Modified based on
    https://github.com/Lightning-AI/pytorch-lightning/blob/d3f9c83d6efa4f1def36aa6c199600946cdb9117/src/lightning/pytorch/strategies/strategy.py#L601-L648
    Redirect a method call through FullyShardedDataParallel.forward so that the FSDP module's root pre-forward and
    post-forward can be properly executed around the method call.
    This is needed in cases where we call a submodule of a FSDP module. For instance, when we want to call only
    the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving
    GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`)
    will not work because the first `nn.Embedding` layer is not independently wrapped as a FSDP module (because of
    the transformer-based wrapping policy), and not calling it through FSDP root module forward will not all-gather
    its parameter, thus resulting in "RuntimeError: 'weight' must be 2-D" error. Similarly, if we want to call just
    the `lm_head` part of a model, we need this trick too to properly get its params all-gathered.
    """

    def __call__(
        self,
        wrapper_module: FullyShardedDataParallel,
        method: Callable,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Reroute a call to `method` through the `wrapper_module`'s `forward` method.

        The `forward` of the module wrapped by `wrapper_module` is temporarily replaced with a shim that
        restores the real `forward` and then calls `method`. The real `forward` is always restored before
        this call returns, including when `wrapper_module` raises before the shim is ever reached.

        Args:
            wrapper_module: The FSDP module. Its `_fsdp_wrapped_module` is the module whose `forward` gets
                temporarily patched.
            method: The callable to run in place of the wrapped module's `forward`. It receives whatever
                arguments the patched `forward` is called with.
            *args: The positional arguments to call `wrapper_module` with. They are meant to reach `method`
                through the patched `forward`.
            **kwargs: The keyword arguments to call `wrapper_module` with. They are meant to reach `method`
                through the patched `forward`.

        Returns:
            Whatever `wrapper_module(*args, **kwargs)` returns.

        Raises:
            AssertionError: If `wrapper_module` is not a `FullyShardedDataParallel` instance.
        """
        assert isinstance(wrapper_module, FullyShardedDataParallel), (
            f"wrapper_module must be a FullyShardedDataParallel instance, got {type(wrapper_module).__name__}"
        )
        original_module = wrapper_module._fsdp_wrapped_module
        original_forward = original_module.forward

        def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
            # Unpatch ourselves immediately before calling `method`
            # because `method` itself may want to call the real `forward`
            original_module.forward = original_forward  # type: ignore[method-assign]
            # Call the actual target of the redirection
            return method(*_args, **_kwargs)

        # Patch the original_module's forward so we can redirect the arguments back to `method`
        original_module.forward = wrapped_forward  # type: ignore[method-assign]
        try:
            return wrapper_module(*args, **kwargs)
        finally:
            # If the wrapper raised before reaching `wrapped_forward`, the patch would otherwise stay in
            # place and hijack the next regular forward pass. Restoring again is harmless when the shim
            # has already unpatched itself.
            original_module.forward = original_forward  # type: ignore[method-assign]