distributed_fused_adam.py 2.91 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math
import torch

from megatron import mpu

# >>>
from lutil import pax, tp
# <<<

class DistributedFusedAdam(torch.optim.Optimizer):
    """Adam optimizer whose fp32 state is sharded across data-parallel ranks.

    Work in progress: ``initialize`` allocates this rank's contiguous shard of
    the flattened parameter space (fp32 main params, main grads, and Adam
    first/second moment buffers), but ``step`` is not implemented yet.
    """

    def __init__(self, params):
        """Register ``params`` (tensors or param-group dicts) with the optimizer.

        No hyperparameter defaults are registered yet. The per-rank shard
        buffers are allocated lazily by ``initialize`` (called from ``step``).
        """
        super().__init__(params, defaults={})

        self.initialized = False
        # Per-rank fp32 shards; populated by initialize().
        self.main_param_shard = None
        self.main_grad_shard = None
        self.adam_m_shard = None
        self.adam_v_shard = None

    @staticmethod
    def _shard_range(total_size, world_size, rank):
        """Return the ``(start, end)`` half-open element range owned by ``rank``.

        The flattened space of ``total_size`` elements is split into
        ``world_size`` contiguous shards of ``ceil(total_size / world_size)``
        elements each; trailing shards may be shorter or empty. Both bounds
        are clamped to ``total_size`` so a trailing rank gets an empty range
        instead of a negative size.
        """
        # Exact integer ceil-division (no float rounding for huge counts).
        max_shard_size = (total_size + world_size - 1) // world_size
        start = min(total_size, rank * max_shard_size)
        end = min(total_size, start + max_shard_size)
        return start, end

    def initialize(self):
        """Allocate this rank's optimizer-state shards; later calls are no-ops.

        Counts elements over every param of every param group, computes this
        data-parallel rank's contiguous shard, and allocates fp32 buffers of
        that size on the current CUDA device.
        """
        if self.initialized:
            return

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()
        total_param_size = sum(
            p.numel()
            for g in self.param_groups
            for p in g["params"]
        )
        shard_start_index, shard_end_index = self._shard_range(
            total_param_size, data_parallel_world_size, data_parallel_rank)
        shard_size = shard_end_index - shard_start_index

        device = torch.cuda.current_device()
        # NOTE(review): main params/grads are presumably overwritten before
        # first use, so uninitialized memory is acceptable here — confirm.
        self.main_param_shard = torch.empty(
            [shard_size], dtype=torch.float, device=device)
        self.main_grad_shard = torch.empty(
            [shard_size], dtype=torch.float, device=device)
        # Adam's first/second moment estimates must start at zero.
        self.adam_m_shard = torch.zeros(
            [shard_size], dtype=torch.float, device=device)
        self.adam_v_shard = torch.zeros(
            [shard_size], dtype=torch.float, device=device)

        # Mark initialized only after allocation succeeded, so a failed
        # attempt (e.g. CUDA OOM) can be retried.
        self.initialized = True

    def step(self, closure=None):
        """Perform one optimization step (not implemented yet).

        ``closure`` is accepted for signature compatibility with
        ``torch.optim.Optimizer.step``. Lazily allocates the shard buffers,
        then raises ``NotImplementedError``.
        """
        self.initialize()

        raise NotImplementedError(
            "DistributedFusedAdam.step() is not implemented yet.")

# >>>
# eof
# <<<