import argparse
import shutil
import torch
import torch.distributed as dist

from allamo.checkpoint.checkpoint_manager import CheckpointManager
from allamo.configuration import AllamoConfiguration
from allamo.logging import configure_logger, logger
from allamo.model.model import AllamoTransformer
from allamo.torch_utils import init_torch
from allamo.train_utils import (
    get_config_checkpoint_path,
    create_model_config,
)
from allamo.training_context import TrainingContext
from allamo.parallelisms.fsdp2_utils import build_world_mesh, parallelize_model_with_fsdp2


# Re-shards a distributed checkpoint (DCP) for a new tensor-parallel degree: the model is
# rebuilt from the config checkpoint, the source model checkpoint is loaded into it, the
# model is wrapped with FSDP2 over the new device mesh, and the result is saved to the
# target directory.
def parallelize_dcp(args):
    train_ctx = TrainingContext(
        tp = args.tp_size,
    )
    if train_ctx.master_process:
        configure_logger()
    
    config = AllamoConfiguration(
        load_configuration = False,
        device=args.device,
        backend=args.backend,
        tensor_parallel_degree = args.tp_size,
        out_dir = args.dst,
        compile = False,
    )
    init_torch(train_ctx, config)
    logger.info(f"Torch initialized")
    
    config_ckpt_file_path = get_config_checkpoint_path(args.checkpoint_name, args.src)
    checkpoint_manager = CheckpointManager(config, train_ctx, None)
    checkpoint_manager.load_config_checkpoint(config_ckpt_file_path)
    logger.info("Config checkpoint loaded")
    
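    # Build the device mesh for the requested TP degree and construct the model on the
    # meta device, so no parameter memory is allocated until it is explicitly materialized.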
    world_mesh = build_world_mesh(train_ctx, args.device)
    model_config = create_model_config(config)
    with torch.device('meta'):
        model = AllamoTransformer(model_config)
    logger.info("Model initialized on meta device")
    
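    # Materialize the parameters on the target device and give them fresh values before
    # they are overwritten by the source distributed checkpoint below.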
    model.to_empty(device=args.device)
    model.init_model_weights()
    logger.info(f"Model weights initialized on {args.device} device")
    
    checkpoint_manager.load_distributed_model_checkpoint_from(model, args.src, args.checkpoint_name)
    logger.info("Model checkpoint loaded")
    
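    # Re-shard the model with FSDP2 over the new world mesh.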
    model = parallelize_model_with_fsdp2(model, world_mesh, config, False)
    
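    # Write the re-parallelized distributed checkpoint and copy the config checkpoint
    # alongside it in the target directory.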
    checkpoint_manager.save_distributed_model_checkpoint_to(model, args.dst, args.checkpoint_name)
    shutil.copy(config_ckpt_file_path, get_config_checkpoint_path(args.checkpoint_name, args.dst))
    
    dist.barrier()
    dist.destroy_process_group()

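# Example launch (a sketch, not verified against allamo's init_torch: it assumes the script
# is started with torchrun and that the number of processes matches the requested TP degree;
# the paths and names below are placeholders):
#   torchrun --nproc_per_node=4 parallelize_dcp.py -s /path/to/src -d /path/to/dst -n ckpt_name -t 4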
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--src', type=str, required=True, help="Path to source checkpoint directory")
    parser.add_argument('-d', '--dst', type=str, required=True, help="Path to target checkpoint directory")
    parser.add_argument('-n', '--checkpoint_name', type=str, required=True, help="Checkpoint name")
    parser.add_argument('-t', '--tp_size', type=int, required=True, help="Tensor parallel degree")
    parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cpu', help="Device type to run conversion on")
    parser.add_argument('--backend', type=str, choices=['gloo', 'mpi', 'nccl'], default='gloo', help="Specifies one of three built-in backends")
    
    args = parser.parse_args()
    
    parallelize_dcp(args)
    logger.info(f"Distributed checkpoint parallelized with TP={args.tp_size}")