context_manager.py 1013 Bytes
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from contextlib import contextmanager

from nanotron.parallel.pipeline_parallel.block import PipelineBlock
from nanotron.parallel.pipeline_parallel.state import PipelineBatchState
from torch import nn as torch_nn


@contextmanager
def attach_pipeline_state_to_model(model: torch_nn.Module, pipeline_state: PipelineBatchState):
    """Attach the pipeline state to all the PipelineBlocks within `model` for the duration of the context.

    Every module yielded by ``model.modules()`` that is a ``PipelineBlock`` gets
    ``pipeline_state`` attached via ``set_pipeline_state``. When the context exits
    (normally or through an exception), each touched block is reset to the pipeline
    state it had before.

    Args:
        model: Module whose modules are scanned for ``PipelineBlock`` instances.
        pipeline_state: State to attach to each ``PipelineBlock`` while the context is active.

    Raises:
        AssertionError: If a ``PipelineBlock`` already has a pipeline state attached.
            Blocks modified before the failure are restored before the error propagates.
    """
    # (previous_state, block) pairs for every block we have touched, so they can be restored.
    old_pipeline_states = []

    try:
        # Set new. This runs inside the `try` so that if attaching fails partway through
        # (e.g. the assertion below fires on a later block), the blocks already modified
        # are rolled back by the `finally` instead of keeping `pipeline_state` attached.
        for module in model.modules():
            if not isinstance(module, PipelineBlock):
                continue

            old_pipeline_state = module.pipeline_state
            assert old_pipeline_state is None, "We never replace an old pipeline engine, we just set one when there's none"

            # Record before mutating so that a failure inside `set_pipeline_state` is also undone.
            old_pipeline_states.append((old_pipeline_state, module))
            module.set_pipeline_state(pipeline_state)

        yield
    finally:
        # Restore every block we touched to the state it had before entering the context.
        for old_pipeline_state, module in old_pipeline_states:
            module.set_pipeline_state(old_pipeline_state)