fsdp_utils.py 2.91 KB
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import functools
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy

from allamo.logging import logger
from allamo.configuration import AllamoConfiguration
from allamo.model.model import SelfAttentionBlock
from allamo.torch_utils import (
    TORCH_DTYPE_MAP,
)

# Config-file sharding strategy names -> FSDP ``ShardingStrategy`` members.
# Every key is spelled exactly like the enum member it selects, so the value is
# looked up by name; insertion order follows the tuple below.
FSDP_SHARDING_STRATEGY_MAP = {
    strategy_name: ShardingStrategy[strategy_name]
    for strategy_name in (
        'FULL_SHARD',
        'HYBRID_SHARD',
        '_HYBRID_SHARD_ZERO2',
        'SHARD_GRAD_OP',
        'NO_SHARD',
    )
}

def enable_activation_checkpointing(model):
    """Apply non-reentrant activation checkpointing to every ``SelfAttentionBlock`` in ``model``.

    ``apply_activation_checkpointing`` swaps each matching submodule for a
    checkpoint-wrapped version in place; this function returns ``None``.

    Args:
        model: module whose submodules are scanned for ``SelfAttentionBlock``
            instances (in this file it is called on the FSDP-wrapped model).
    """
    # ``offload_to_cpu`` is intentionally not passed: ``False`` was its default in
    # the PyTorch versions that accepted it, and newer ``checkpoint_wrapper``
    # versions forward unknown kwargs to the checkpoint function, where they can
    # reach the wrapped module's ``forward`` and raise a TypeError.
    non_reentrant_wrapper = functools.partial(
        checkpoint_wrapper,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )

    def check_fn(submodule):
        # Only transformer blocks are checkpointed; all other modules keep their activations.
        return isinstance(submodule, SelfAttentionBlock)

    apply_activation_checkpointing(model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn)
    logger.info("Activation checkpointing applied to the model")

def parallelize_model_with_fsdp1(model, config: AllamoConfiguration, with_activation_checkpointing: bool = False):
    """Wrap ``model`` with FSDP1, then optionally add activation checkpointing and ``torch.compile``.

    Each ``SelfAttentionBlock`` is auto-wrapped as its own FSDP unit. Parameters,
    gradient reduction and buffers all use the dtype named by ``config.dtype``.

    Args:
        model: the unwrapped model to parallelize.
        config: training configuration; reads ``dtype``, ``fsdp_sharding_strategy``,
            ``compile`` and ``compile_mode``.
        with_activation_checkpointing: when True, checkpoint every
            ``SelfAttentionBlock`` after the FSDP wrap.

    Returns:
        The FSDP-wrapped model, or the result of ``torch.compile`` on it when
        ``config.compile`` is set and the ``torch.compile`` call did not raise.

    Raises:
        KeyError: if ``config.dtype`` is not a key of ``TORCH_DTYPE_MAP`` or
            ``config.fsdp_sharding_strategy`` is not a key of
            ``FSDP_SHARDING_STRATEGY_MAP``.
    """
    logger.info("Configuring model with FSDP1")
    ptdtype = TORCH_DTYPE_MAP[config.dtype]
    # Every SelfAttentionBlock instance becomes a separate FSDP unit.
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            SelfAttentionBlock,
        },
    )
    strategy_name = config.fsdp_sharding_strategy
    try:
        sharding_strategy = FSDP_SHARDING_STRATEGY_MAP[strategy_name]
    except KeyError:
        # Same exception type as before, but name the bad setting and the valid choices.
        raise KeyError(
            f"Unsupported fsdp_sharding_strategy {strategy_name!r}; "
            f"expected one of {sorted(FSDP_SHARDING_STRATEGY_MAP)}"
        ) from None
    fsdp_config = dict(
        auto_wrap_policy=auto_wrap_policy,
        sharding_strategy=sharding_strategy,
        device_id=torch.cuda.current_device(),
        # NOTE(review): reduce_dtype equals the compute dtype, so gradients are
        # reduced in low precision when config.dtype is bf16/fp16 — confirm this
        # is intended rather than a float32 reduction.
        mixed_precision=MixedPrecision(
            param_dtype=ptdtype,
            reduce_dtype=ptdtype,
            buffer_dtype=ptdtype,
        ),
        limit_all_gathers=True,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # will use slightly more memory vs. no prefetch
        use_orig_params=True,  # required to use torch.compile()
    )

    model = FSDP(model, **fsdp_config)
    logger.info(f"Model configured with FSDP1 and {sharding_strategy=}")

    if with_activation_checkpointing:
        enable_activation_checkpointing(model)

    logger.info(f"Model after parallelization {model=}\n")

    if config.compile:
        logger.info("Compiling model")
        # torch.compile compiles lazily on the first call, so this try only guards
        # the setup; graph-capture/compile errors surface on the first forward pass.
        # Failure here is deliberately non-fatal: training continues uncompiled.
        try:
            model = torch.compile(model, mode=config.compile_mode)
            logger.info("Model compiled and ready to use")
        except Exception as err:
            logger.warning(f"Unable to compile the model: {err}")

    return model