"""
Utilities for building the world device mesh and parallelizing a model with FSDP2,
Tensor Parallelism, and activation checkpointing. The Tensor Parallelism API requires
PyTorch 2.3.0+.
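
Typical call order (illustrative sketch; assumes the training script builds `train_ctx`,
`config`, and `model` elsewhere):

    world_mesh = build_world_mesh(train_ctx)
    model = parallelize_model_with_fsdp2(model, world_mesh, config,
                                         with_activation_checkpointing=True)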
"""
from allamo.logging import logger
from allamo.configuration import AllamoConfiguration
from allamo.torch_utils import (
    TORCH_DTYPE_MAP,
)
from allamo.training_context import TrainingContext

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
)
from torch.distributed.device_mesh import init_device_mesh, DeviceMesh
from torch.distributed._tensor import Shard, Replicate
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
    PrepareModuleInput,
    SequenceParallel,
)
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.utils.checkpoint import checkpoint

def build_world_mesh(train_ctx: TrainingContext, device_type: str = "cuda"):
    dims = (train_ctx.pp, train_ctx.dp, train_ctx.tp)
    dim_names = ("pp", "dp", "tp")
    device_mesh = init_device_mesh(device_type, dims, mesh_dim_names=dim_names)
    logger.info(f"{len(dims)}-D device mesh built: {dim_names} = {dims}")
    return device_mesh

def parallelize_model_with_fsdp2(model, world_mesh, config, with_activation_checkpointing):
    if world_mesh['tp'].size() > 1:
        apply_tensor_parallelism(model, world_mesh)
    
    if with_activation_checkpointing:
        apply_activation_checkpointing(model)
    
    apply_fsdp(model, world_mesh, config)
    
    if config.compile:
        logger.info("Compiling model")
        try:
            model = torch.compile(model, mode=config.compile_mode)
            logger.info("Model compiled and ready to use")
        except Exception as err:
            logger.warning(f"Unable to compile the model: {err}")
    return model

def apply_tensor_parallelism(model: nn.Module, world_mesh: DeviceMesh):
    logger.warning(
        "Tensor parallelism is in an early experimental stage. "
        "Strided sharding is required for 2D/3D DCP, but it is only available in nightly builds "
        "newer than 20240809 and in PyTorch version 2.5 or later."
    )
    parallelize_module(
        model,
        world_mesh["tp"],
        {
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            "norm": SequenceParallel(),
            "lm_head": ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Replicate(),
                use_local_output=True,
            ),
        },
    )
    
    for layer in model.layers:
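        # Per-block plan: Q/K/V and the FFN gate/up projections are column-sharded,
        # the attention output and FFN down projections are row-sharded, and the norms
        # run sequence-parallel on activations sharded along the sequence dimension.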
        layer_plan = {
            "attention_norm": SequenceParallel(),
            "attention": PrepareModuleInput(
                input_layouts=(Shard(1), None, None),
                desired_input_layouts=(Replicate(), None, None),
            ),
            "attention.q_proj": ColwiseParallel(),
            "attention.k_proj": ColwiseParallel(),
            "attention.v_proj": ColwiseParallel(),
            "attention.c_proj": RowwiseParallel(output_layouts=Shard(1)),
            "ffn_norm": SequenceParallel(),
            "feed_forward": PrepareModuleInput(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
            ),
            "feed_forward.gate_proj": ColwiseParallel(),
            "feed_forward.down_proj": RowwiseParallel(output_layouts=Shard(1)),
            "feed_forward.up_proj": ColwiseParallel(),
        }
        
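        # Attention heads are split across TP ranks, so the block's local head counts
        # must be divided by the TP size.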
        layer.attention.num_heads //= world_mesh["tp"].size()
        layer.attention.num_kv_heads //= world_mesh["tp"].size()
        
        parallelize_module(
            module=layer,
            device_mesh=world_mesh["tp"],
            parallelize_plan=layer_plan,
        )
    logger.info(f"Model parallelized with Tensor Parallelism (size: {world_mesh['tp'].size()})")

def apply_activation_checkpointing(model: nn.Module):
    for layer_id in range(len(model.layers)):
        model.layers[layer_id] = checkpoint_wrapper(
            model.layers[layer_id],
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
            checkpoint_fn=checkpoint,
            use_reentrant=False,
            preserve_rng_state=False,
        )

def apply_fsdp(model: nn.Module, world_mesh: DeviceMesh, config: AllamoConfiguration):
    fsdp_config = {"mesh": world_mesh["dp"]}
    if config.dtype != 'float32':
        fsdp_config["mp_policy"] = MixedPrecisionPolicy(
            param_dtype=TORCH_DTYPE_MAP[config.dtype],
            reduce_dtype=torch.float32
        )
    pp_enabled = world_mesh['pp'].size() > 1
    
    for layer_id, layer in enumerate(model.layers):
        if pp_enabled:
            # For PP, do not reshard after forward to avoid per-microbatch
            # all-gathers, which can be expensive and non-overlapped
            reshard_after_forward = False
        else:
            # As an optimization, do not reshard after forward for the last
            # layer since FSDP would prefetch it immediately
            reshard_after_forward = layer_id < len(model.layers) - 1
        fully_shard(
            layer,
            **fsdp_config,
            reshard_after_forward=reshard_after_forward,
        )
    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
    logger.info(f"Model parallelized with FSDP2: {model}\n")