# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ----------- mock torch: this must run before anything imports torch -----------
import oneflow as flow

flow.mock_torch.enable()
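# Hedged sketch (illustrative, not exercised by this file): with the mock
# enabled, `import torch` resolves to oneflow, so downstream libraries such
# as transformers run on OneFlow unchanged:
#
#     import torch          # resolves to oneflow while the mock is active
#     t = torch.ones(2, 2)  # t is a oneflow.Tensor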

from oneflow import Tensor, nn  # noqa
from transformers import modeling_utils  # noqa
from transformers.modeling_utils import _load_state_dict_into_model  # noqa


# ---------------- mock _load_state_dict_into_model ------------------
def new_load(model_to_load, state_dict, start_prefix):
    # Rename old-format checkpoint keys if needed: "gamma" -> "weight", "beta" -> "bias"
    old_keys = []
    new_keys = []
    for key in state_dict.keys():
        new_key = None
        if "gamma" in key:
            new_key = key.replace("gamma", "weight")
        if "beta" in key:
            new_key = key.replace("beta", "bias")
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)
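    # e.g. a legacy key "ln_f.gamma" would be stored back as "ln_f.weight"
    # (hypothetical key, shown only to illustrate the rename)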

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    error_msgs = []

    # TODO: also honor start_prefix when matching keys here
    # For global (distributed) parameters, broadcast the checkpoint tensor to
    # every rank on CPU first, then redistribute it to the parameter's own
    # SBP signature and placement.
    for k, v in model_to_load.state_dict().items():
        if k in state_dict and v.is_global:
            state_dict[k] = state_dict[k].to_global(
                sbp=flow.sbp.broadcast, placement=flow.env.all_device_placement("cpu")
            )
            state_dict[k] = state_dict[k].to_global(
                sbp=v.sbp,
                placement=flow.placement("cpu", ranks=list(v.placement.ranks)),
            )
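    # Hedged sketch of the two-step conversion above (assumes a 2-rank CPU
    # job and a split-parallel parameter; names are illustrative):
    #
    #     t = state_dict["w"]                      # local CPU tensor
    #     t = t.to_global(sbp=flow.sbp.broadcast,
    #                     placement=flow.env.all_device_placement("cpu"))
    #     t = t.to_global(sbp=flow.sbp.split(0),
    #                     placement=flow.placement("cpu", ranks=[0, 1]))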

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def load(module: nn.Module, state_dict, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
        # Parameters of module and children will start with prefix.
        # We can exit early if there are none in this state_dict
        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
            module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, state_dict, prefix + name + ".")

    load(model_to_load, state_dict, prefix=start_prefix)
    # Delete `state_dict` so it could be collected by GC earlier.
    # Note that `state_dict` is a copy of the argument, so it's safe to delete it.
    del state_dict
    return error_msgs


modeling_utils._load_state_dict_into_model = new_load
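# From here on, transformers routes checkpoint loading through new_load().
# Hedged usage sketch (model name is illustrative):
#
#     from transformers import AutoModel
#     model = AutoModel.from_pretrained("bigscience/bloom-560m")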


# -----------------mock tensor.new_ones() -------------
def flow_ones(self, *args, **kwargs):
    # Default to the source tensor's device/dtype, but let explicit keyword
    # arguments win (torch's new_ones allows overriding both; passing them
    # twice would otherwise raise a TypeError).
    kwargs.setdefault("device", self.device)
    kwargs.setdefault("dtype", self.dtype)
    return flow.ones(*args, **kwargs)


Tensor.new_ones = flow_ones
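# Usage sketch (illustrative):
#
#     x = flow.zeros(3, dtype=flow.float16)
#     y = x.new_ones(2, 3)   # float16 ones on x's device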


# -----------------mock tensor.new() ------------------
def flow_zeros(self, *args, **kwargs):
    # Same defaulting as new_ones: inherit device/dtype unless overridden.
    kwargs.setdefault("device", self.device)
    kwargs.setdefault("dtype", self.dtype)
    return flow.zeros(*args, **kwargs)


Tensor.new = flow_zeros
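# Note: torch's Tensor.new() returns uninitialized memory; zeros are used
# here instead, which is safe for callers that fully overwrite the result.
# Usage sketch (illustrative):
#
#     buf = x.new(4, 4)   # zero-filled, same device/dtype as x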

# ------------------mock nn.functional.softmax---------
# Keep a handle to the original softmax so the wrapper can delegate to it.
temp_func = nn.functional.softmax


def flow_softmax(*args, **kwargs):
    # transformers sometimes calls softmax(x, dim=..., dtype=...); mirror
    # torch's semantics by casting the input first, then delegating.
    if "dtype" in kwargs:
        _tensor = args[0].to(dtype=kwargs.pop("dtype"))
        return temp_func(_tensor, *args[1:], **kwargs)
    else:
        return temp_func(*args, **kwargs)


nn.functional.softmax = flow_softmax
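# Usage sketch (illustrative): matches torch.nn.functional.softmax semantics
# when a dtype is requested.
#
#     scores = flow.randn(2, 4)
#     probs = nn.functional.softmax(scores, dim=-1, dtype=flow.float32)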