accelerator.py
# Copyright (c) Alibaba, Inc. and its affiliates.


def ta_accelerate(model,
                  fsdp_num,
                  layer_cls_name,
                  bf16=True,
                  fp16=False,
                  gradient_checkpointing=True,
                  fsdp_flatten_parameters=False):
    """ accelerate LLM training using TorchAcc(only available internally).
    """
    import torchacc as ta
    assert layer_cls_name is not None, 'layer_cls_name must be provided'
    assert not (bf16 and fp16), 'bf16 and fp16 are mutually exclusive'

    def get_ta_config():
        config = ta.Config()
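        # Mixed-precision compute (bf16 and fp16 are mutually exclusive).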
        config.compute.fp16 = fp16
        config.compute.bf16 = bf16

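        # Gradient (activation) checkpointing, applied per layer class;
        # gc_cls takes a set of layer class names.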
        config.memory.gc = gradient_checkpointing
        if config.memory.gc:
            config.memory.gc_cls = {layer_cls_name}

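        # Pure FSDP sharding across fsdp_num workers, wrapping each
        # transformer layer separately; dp.size = 1 disables extra data
        # parallelism on top of FSDP.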
        config.dist.fsdp.size = fsdp_num
        config.dist.fsdp.wrap_layer_cls = {layer_cls_name}
        config.dist.fsdp.flatten_parameters = fsdp_flatten_parameters
        config.dist.dp.size = 1

        return config

    ta_config = get_ta_config()
    model = ta.accelerate(model, ta_config)
    return model
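

if __name__ == '__main__':
    # Usage sketch (an illustrative assumption, not part of the original
    # module): shard a Hugging Face causal LM across 4 workers with TorchAcc
    # FSDP. The checkpoint name and the layer class 'LlamaDecoderLayer' are
    # hypothetical placeholders; substitute the transformer block class of
    # the model you actually load.
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')
    model = ta_accelerate(
        model,
        fsdp_num=4,
        layer_cls_name='LlamaDecoderLayer',
        bf16=True,
        gradient_checkpointing=True)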