# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import oneflow as flow
from oneflow import nn
from libai.layers import TransformerLayer
from libai.utils import distributed as dist
logger = logging.getLogger(__name__)
class GraphBase(nn.Graph):
def __init__(
self,
model: nn.Module,
optimizer: flow.optim.Optimizer = None,
lr_scheduler: flow.optim.lr_scheduler = None,
fp16=False,
activation_checkpoint=False,
grad_acc_steps=1,
zero_optim=False,
zero_stage=0,
is_train=True,
auto_parallel_conf=None,
):
super().__init__()
self.model = model
self.is_train = is_train
if is_train:
self.add_optimizer(optimizer, lr_sch=lr_scheduler)
if fp16:
self.config.enable_amp(True)
grad_scaler = flow.amp.GradScaler(
init_scale=65536.0 * dist.get_data_parallel_size(),
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
)
self.set_grad_scaler(grad_scaler)
if grad_acc_steps > 1:
self.config.set_gradient_accumulation_steps(grad_acc_steps)
if activation_checkpoint:
self.set_activation_checkpoint()
if zero_optim:
self.config.enable_zero(True, stage=zero_stage)
self.set_pipeline_stage_id()
self.config.allow_fuse_add_to_output(True)
self.config.allow_fuse_model_update_ops(True)
self.config.allow_fuse_cast_scale(True)
# Use the same CUDA stream for computation and communication.
# This reduces memory usage when using model parallelism.
dist_util = dist.get_dist_util()
if dist_util.is_tensor_model_parallel() or dist_util.is_pipeline_model_parallel():
flow.boxing.nccl.enable_use_compute_stream(True)
# auto_parallel
if auto_parallel_conf is not None and auto_parallel_conf.enabled:
try:
self.config.enable_auto_parallel(True)
self.config.enable_auto_parallel_ignore_user_sbp_config(
auto_parallel_conf.enable_auto_parallel_ignore_user_sbp_config
)
self.config.set_auto_parallel_computation_cost_ratio(0.05)
self.config.set_auto_parallel_wait_time(1.65e4)
self.config.enable_auto_parallel_trunk_algo(auto_parallel_conf.trunk_algo)
self.config.enable_auto_parallel_sbp_collector(auto_parallel_conf.sbp_collector)
except RuntimeWarning:
import warnings
warnings.warn(
"The version of oneflow don't support auto_parallel.\n"
"Please reinstall the oneflow nightly:\n"
"python3 -m pip install --pre oneflow -f https://staging.oneflow.info/branch/master/[PLATFORM]" # noqa
)
def build(self, **kwargs):
if self.is_train:
logger.info(
"Start compling the train graph which may take some time. "
"Please wait for a moment ..."
)
loss_dict = self.model(**kwargs)
losses = sum(v for k, v in loss_dict.items() if "loss" in k)
losses.backward()
return loss_dict
else:
logger.info(
"Start compling the eval graph which may take some time. "
"Please wait for a moment ..."
)
return self.model(**kwargs)
def set_activation_checkpoint(self):
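# Prefer an activation-checkpoint hook defined on the model itself; otherwise fall back to
# enabling checkpointing on every TransformerLayer. The two outer branches handle different
# OneFlow nn.Graph module interfaces (`.origin` vs. `.to(nn.Module)` / `.to(nn.graph.GraphModule)`).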
if hasattr(self.model, "origin"):
if hasattr(type(self.model.origin), "set_activation_checkpoint"):
type(self.model.origin).set_activation_checkpoint(self.model)
else:
for module_block in self.model.modules():
if isinstance(module_block.origin, TransformerLayer):
module_block.config.activation_checkpointing = True
else:
if hasattr(type(self.model.to(nn.Module)), "set_activation_checkpoint"):
type(self.model.to(nn.Module)).set_activation_checkpoint(self.model)
else:
for module_block in self.model.modules():
if isinstance(module_block.to(nn.Module), TransformerLayer):
module_block.to(nn.graph.GraphModule).activation_checkpointing = True
def set_pipeline_stage_id(self):
if hasattr(self.model, "origin"):
if hasattr(type(self.model.origin), "set_pipeline_stage_id"):
type(self.model.origin).set_pipeline_stage_id(self.model)
else:
if hasattr(type(self.model.to(nn.Module)), "set_pipeline_stage_id"):
type(self.model.to(nn.Module)).set_pipeline_stage_id(self.model)
## Introduction
These are the weight loaders currently supported in LiBai. You can use them to load models trained with LiBai as well as pretrained models hosted on Hugging Face.
## Weight Loaders in LiBai
- [BERT Loader](./bert_loader.py)
- [RoBERTa Loader](./roberta_loader.py)
- [GPT2 Loader](./gpt_loader.py)
- [MT5 Loader](../../../../projects/MT5/utils/mt5_loader.py)
- [SWIN Loader](./swin_loader.py)
- [SWIN2 Loader](./swinv2_loader.py)
- [VIT Loader](./vit_loader.py)
## How To Use
You can easily load a pretrained BERT as follows:
```python
import libai
from libai.models.utils import BertLoaderHuggerFace, BertLoaderLiBai
from configs.common.models.bert import cfg
# load huggingface weight
loader = BertLoaderHuggerFace(
    model=libai.models.BertModel,
    libai_cfg=cfg,
    pretrained_model_path="path/to/huggingface_pretrained_model_directory",
    hidden_dropout_prob=0,
    apply_residual_post_layernorm=True,
)
bert = loader.load()
# load libai weight
loader = BertLoaderLiBai(
    model=libai.models.BertModel,
    libai_cfg=cfg,
    pretrained_model_path='path/to/libai_pretrained_model_directory',
    hidden_dropout_prob=0,
)
bert = loader.load()
```
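If a model you need is not covered by the loaders above, you can write your own by subclassing the base loaders in `base_loader.py` and overriding the conversion hooks. The snippet below is only a minimal sketch: the class name `ToyLoaderHuggerFace`, the `"toy"` prefixes, and the key mapping are hypothetical placeholders, and it assumes `ModelLoaderHuggerFace` is importable from `libai.models.utils`.
```python
import json

from libai.models.utils import ModelLoaderHuggerFace  # assumed export; defined in base_loader.py


class ToyLoaderHuggerFace(ModelLoaderHuggerFace):
    """Hypothetical loader sketch; adapt the prefixes and key mapping to your model."""

    def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
        super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
        self.base_model_prefix_1 = "toy"  # prefix used in the Transformers checkpoint
        self.base_model_prefix_2 = "toy"  # prefix used in the LiBai model

    def _convert_state_dict(self, flow_state_dict, cfg):
        # Rename checkpoint keys so they match the LiBai module names.
        oneflow_state_dict = flow_state_dict.copy()
        for key in list(oneflow_state_dict.keys()):
            if "word_embeddings" in key:
                new_key = key.replace("word_embeddings", "vocab_embeddings")
                oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
        return oneflow_state_dict

    def _load_config_from_json(self, config_file):
        # Keep libai_cfg consistent with the checkpoint's config.json.
        with open(config_file, mode="r", encoding="utf-8") as f:
            cfg_dict = json.load(f)
        self._update_cfg("hidden_size", cfg_dict["hidden_size"])
        # Let user kwargs override values from config.json.
        for k, v in self.kwargs.items():
            self._update_cfg(k, v)
        self._update_cfg_log()
```
`BertLoaderHuggerFace` in `bert_loader.py` is a complete example of this pattern.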
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import copy
import logging
import os
import omegaconf
import oneflow as flow
from termcolor import colored
import libai.utils.distributed as dist
from libai.config import LazyCall
from libai.models.build import build_model
logger = logging.getLogger(__name__)
WEIGHTS_NAME_PT = "pytorch_model.bin"
CONFIG_NAME = "config.json"
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
"""load state dict into model
Args:
model_to_load (nn.Module): Model to be loaded.
state_dict (OrderedDict): State dict of pretrained model.
start_prefix (str): Start prefix.
Returns:
list: Error messages produced while loading.
"""
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
error_msgs = []
def load(module, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
module._load_from_state_dict(*args)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")
load(model_to_load, prefix=start_prefix)
return error_msgs
class ModelLoader(object):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
"""Class used to load the [`transformers`](https://huggingface.co/models) pretrained model
or `OneFlow` pretrained model.
Args:
model (libai.models): Model to be loaded in LiBai.
libai_cfg (dict): The config of the model in LiBai; you can import it from
`libai.config.configs.common.models`.
pretrained_model_path (str): The directory path of pretrained model,
which contains model weights file and config file.
output_loading_info (`bool`, *optional*, defaults to `False`):
Whether to return a dictionary containing missing keys, unexpected keys
and error messages.
"""
self.model = model
self.libai_cfg = libai_cfg
self.pretrained_model_path = pretrained_model_path
self.kwargs = kwargs
self.output_loading_info = kwargs.pop("output_loading_info", False)
def _state_dict_to_global(self, flow_state_dict=None, mode="libai"):
"""Tensor in OneFlow state dict to global according to model's sbp and placement.
Args:
flow_state_dict (OrderedDict): State dict of OneFlow's pretrained model.
"""
assert mode in ["libai", "pytorch"], f"not support for mode {mode}"
if mode == "libai" or dist.is_main_process():
prefix = self.base_model_prefix_2
# Checkpoint
has_prefix_module = any(
s.startswith(self.base_model_prefix_2) for s in flow_state_dict.keys()
)
# Module
expects_prefix_module = any(
s.startswith(prefix) for s in self.model.state_dict().keys()
)
start_prefix = "" if has_prefix_module else prefix + "."
loaded_keys = [start_prefix + key for key in flow_state_dict.keys()]
else:
prefix, has_prefix_module, expects_prefix_module, loaded_keys = [None] * 4
flow_state_dict = collections.OrderedDict()
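# Broadcast the prefix bookkeeping from rank 0 so that every rank applies the same key mapping.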
prefix = dist.broadcast_py_object(prefix, src=0)
has_prefix_module = dist.broadcast_py_object(has_prefix_module, src=0)
expects_prefix_module = dist.broadcast_py_object(expects_prefix_module, src=0)
loaded_keys = dist.broadcast_py_object(loaded_keys, src=0)
# Convert each tensor to a global tensor matching the corresponding model parameter's sbp and placement.
for key, value in self.model.state_dict().items():
if not expects_prefix_module:
key = prefix + "." + key
if key in loaded_keys:
if not has_prefix_module:
key = ".".join(key.split(".")[1:])
if mode == "pytorch":
flow_state_dict[key] = flow.to_global(
flow_state_dict[key] if dist.is_main_process() else flow.Tensor(None),
sbp=flow.sbp.broadcast,
placement=flow.placement("cpu", ranks=[0]),
)
flow_state_dict[key] = flow.to_global(
flow_state_dict[key],
sbp=value.sbp,
placement=flow.placement("cpu", ranks=list(value.placement.ranks)),
)
return flow_state_dict
def _load_pretrained_model(
self,
model,
state_dict,
pretrained_model_path,
ignore_mismatched_sizes=False,
):
"""Load pretrained model.
Args:
model (libai.models): The model to be loaded.
state_dict (OrderedDict): State dict of the pretrained model.
pretrained_model_path (str): Path of the pretrained model directory.
ignore_mismatched_sizes (bool):
Whether or not to raise an error if some of the weights
from the checkpoint do not have the same size as the
weights of the model, defaults to `False`.
"""
model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys())
prefix = self.base_model_prefix_2
loaded_keys = state_dict.keys()
if len(prefix) > 0:
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
else:
has_prefix_module = False
expects_prefix_module = False
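# Align key prefixes between the checkpoint and the model:
# remove_prefix_from_model: the checkpoint keys lack the prefix while the model's keys carry it;
# add_prefix_to_model: the checkpoint keys carry the prefix while the model's keys lack it.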
remove_prefix_from_model = not has_prefix_module and expects_prefix_module
add_prefix_to_model = has_prefix_module and not expects_prefix_module
if remove_prefix_from_model:
expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(prefix)]
expected_keys = [
".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys
]
elif add_prefix_to_model:
expected_keys = [".".join([prefix, s]) for s in expected_keys]
missing_keys = list(set(expected_keys) - set(loaded_keys))
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
start_prefix = ""
model_to_load = model
if (
len(self.base_model_prefix_2) > 0
and not hasattr(model, self.base_model_prefix_2)
and has_prefix_module
):
start_prefix = self.base_model_prefix_2 + "."
if (
len(self.base_model_prefix_2) > 0
and hasattr(model, self.base_model_prefix_2)
and not has_prefix_module
):
model_to_load = getattr(model, self.base_model_prefix_2)
if any(key in expected_keys_not_prefixed for key in loaded_keys):
raise ValueError("The state dict of the model you are loading is corrupted.")
def _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
ignore_mismatched_sizes,
):
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
if remove_prefix_from_model:
model_key = f"{prefix}.{checkpoint_key}"
elif add_prefix_to_model:
model_key = ".".join(checkpoint_key.split(".")[1:])
if (
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
mismatched_keys.append(
(
checkpoint_key,
state_dict[checkpoint_key].shape,
model_state_dict[model_key].shape,
)
)
del state_dict[checkpoint_key]
return mismatched_keys
if state_dict is not None:
mismatched_keys = _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
ignore_mismatched_sizes,
)
error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
if dist.get_local_rank() == 0:
if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
raise RuntimeError(
f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}"
)
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint at {pretrained_model_path} "
"were not used when "
f"initializing {model.__class__.__name__}:\n {unexpected_keys}\n"
)
else:
logger.info(
f"All model checkpoint weights were used when initializing "
f"{model.__class__.__name__}.\n"
)
if len(missing_keys) > 0:
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized "
f"from the model checkpoint at {pretrained_model_path}:\n "
f"{missing_keys} \n"
)
elif len(mismatched_keys) == 0:
logger.info(
f"All the weights of {model.__class__.__name__} were initialized "
f"from the model checkpoint at {pretrained_model_path}.\n"
)
if len(mismatched_keys) > 0:
mismatched_warning = "\n".join(
[
f"- {key}: found shape {shape1} in the checkpoint and {shape2}"
"in the model instantiated"
for key, shape1, shape2 in mismatched_keys
]
)
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized"
f"from the model checkpoint at {pretrained_model_path} "
f"and are newly initialized because the shapes did not"
f"match:\n{mismatched_warning}\n"
)
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
class ModelLoaderLiBai(ModelLoader):
"""Class used to load `OneFlow` pretrained model.
Args:
model (libai.models): Model to be loaded in LiBai.
libai_cfg (dict): The config of the model in LiBai; you can import it from
`libai.config.configs.common.models`.
pretrained_model_path (str): The directory path of pretrained model,
which contains model weights file and config file.
output_loading_info (`bool`, *optional*, defaults to `False`):
Whether to return a dictionary containing missing keys, unexpected keys
and error messages.
"""
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
self.base_model_prefix_2 = None # prefix in LiBai
def _load_flow_state_dict(self, state_dict_file):
# load oneflow_model
state_dict = flow.load(state_dict_file, global_src_rank=0)
return state_dict
def load(self):
"""Load model.
# For example:
# .. code-block:: python
>>> import libai
>>> from libai.config.configs.common.models.bert import cfg
>>> from libai.models.utils import BertLoaderLiBai
>>> loader = BertLoaderLiBai(
libai.models.BertModel,
cfg,
'path/bert-base-chinese'
)
>>> bert = loader.load()
"""
if dist.is_main_process():
assert os.path.isdir(
self.pretrained_model_path
), f"{self.pretrained_model_path} must be a directory"
flow_state_dict = self._load_flow_state_dict(self.pretrained_model_path)
# Instantiate the model
if isinstance(self.model, omegaconf.dictconfig.DictConfig):
self.model.cfg = self.libai_cfg
self.model = build_model(self.model)
else:
self.model = build_model(LazyCall(self.model)(cfg=self.libai_cfg))
# State_dict to global
self._state_dict_to_global(flow_state_dict, mode="libai")
# Load
(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
error_msgs,
) = self._load_pretrained_model(self.model, flow_state_dict, self.pretrained_model_path)
if self.output_loading_info:
loading_info = {
"missing_keys": missing_keys,
"unexpected_keys": unexpected_keys,
"mismatched_keys": mismatched_keys,
"error_msgs": error_msgs,
}
return model, loading_info
return model
class ModelLoaderHuggerFace(ModelLoader):
"""Class used to load the [`transformers`](https://huggingface.co/models)
pretrained model.
"""
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
self.base_model_prefix_1 = None # prefix in Transformers
self.base_model_prefix_2 = None # prefix in LiBai
self.origin_libai_cfg = copy.deepcopy(self.libai_cfg)
self.changed_keys = set() # Store the changed configuration
def _convert_tensor(self, tensor):
"""Convert PyTorch tensor to OneFlow tensor.
Args:
tensor (torch.Tensor): The source tensor.
Returns:
flow.Tensor: The target tensor.
"""
tensor = tensor.float()
return flow.Tensor(tensor.detach().cpu().numpy())
def _convert_tensors(self, torch_state_dict):
for k, v in torch_state_dict.items():
torch_state_dict[k] = self._convert_tensor(v)
return torch_state_dict
def _fix_key(self, state_dict):
"""Fix the key in state dict: Convert "gamma" to "weight" and "beta" to "bias".
Args:
state_dict (OrderedDict): state dict of pretrained model.
Returns:
OrderedDict: State dict after fix key.
"""
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if "gamma" in key:
new_key = key.replace("gamma", "weight")
if "beta" in key:
new_key = key.replace("beta", "bias")
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
return state_dict
def _fix_qkv_ordering(
self, qkv, head_size, num_heads, hidden_size=None, checkpoint_version=0.0
):
# TODO(xzp): Different versions checkpoint
hidden_size = (head_size * num_heads) if hidden_size is None else hidden_size
num_of_qkv = qkv.shape[0] // (head_size * num_heads)
mode = "weight" if qkv.ndim > 1 else "bias"
if mode == "weight":
qkv = qkv.view([num_of_qkv, num_heads, head_size, hidden_size])
qkv = (
qkv.permute(1, 0, 2, 3)
.contiguous()
.view(num_of_qkv * head_size * num_heads, hidden_size)
)
elif mode == "bias":
qkv = qkv.view(num_of_qkv, num_heads, head_size)
qkv = qkv.permute(1, 0, 2).contiguous().view(-1)
return qkv
def _convert_state_dict(self, flow_state_dict, cfg):
"""A function used to convert the checkpoint file of Huggingface to LiBai.
Args:
torch_state_dict (OrderedDict): torch state dict.
cfg (dict): model's default config dict in LiBai.
Returns:
OrderedDict: flow state dict.
"""
raise NotImplementedError("_convert_state_dict not implemented")
def _load_config_from_json(self, config_file):
"""load config from `config.json`, and update default config.
Args:
config_file (str): Path of config file.
"""
raise NotImplementedError("_load_config_from_json not implemented")
def _load_torch_state_dict(self, state_dict_file):
try:
import torch
except ImportError:
raise ImportError("Load torch state dict need torch.")
# load pytorch_model.bin
state_dict = torch.load(state_dict_file, map_location="cpu")
return state_dict
def _update_cfg(self, keys_libai, value_target):
"""Update the libai_cfg according to target_cfg.
Args:
keys_libai (str): The key of libai_cfg.
value_target (int | float): The value of target_cfg.
"""
if keys_libai not in self.libai_cfg.keys():
return
if self.libai_cfg[keys_libai] != value_target:
self.libai_cfg[keys_libai] = value_target
def _update_cfg_log(self):
if dist.get_local_rank() == 0:
for key in sorted(self.libai_cfg):
if self.origin_libai_cfg[key] == self.libai_cfg[key]:
continue
self.changed_keys.add(key)
temp_key = colored(key, "yellow")
logger.info(
f"changed libai model cfg {temp_key} : "
f"{self.origin_libai_cfg[key]} -> {self.libai_cfg[key]} "
)
logger.warning(
"The following model configurations has been modified according "
"to `config.json` or kwargs: \n"
f"{self.changed_keys} \n"
)
if dist.get_pipeline_parallel_size() > 1:
logger.warning(
colored(
"If you use pipeline parallel, please "
"confirm the setting of `train.dist.pipeline_num_layers` \n",
"red",
)
)
def load(self):
"""Load model.
# For example:
# .. code-block:: python
>>> import libai
>>> from configs.common.models.bert import cfg
>>> from libai.models.utils import BertLoaderHuggerFace
>>> loader = BertLoaderHuggerFace(
libai.models.BertModel,
cfg,
'path/bert-base-chinese'
)
>>> bert = loader.load()
"""
if dist.is_main_process():
if os.path.isdir(self.pretrained_model_path):
# PyTorch state_dict file
if os.path.isfile(os.path.join(self.pretrained_model_path, WEIGHTS_NAME_PT)):
model_file = os.path.join(self.pretrained_model_path, WEIGHTS_NAME_PT)
else:
raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME_PT} found"
f"in directory {self.pretrained_model_path}."
)
# config file
if os.path.isfile(os.path.join(self.pretrained_model_path, CONFIG_NAME)):
config_file = os.path.join(self.pretrained_model_path, CONFIG_NAME)
# Load config and update config.
self._load_config_from_json(config_file)
else:
import warnings
warnings.warn(
f"Error no file named {CONFIG_NAME} found in directory"
f"{self.pretrained_model_path}",
RuntimeWarning,
)
else:
raise EnvironmentError(f"{self.pretrained_model_path} is not a directory.")
logger.info("loading torch model...")
torch_state_dict = self._load_torch_state_dict(model_file)
torch_state_dict = self._fix_key(torch_state_dict)
logger.info("transfering torch model into oneflow model...")
flow_state_dict = self._convert_tensors(torch_state_dict)
flow_state_dict = self._convert_state_dict(torch_state_dict, self.libai_cfg)
else:
flow_state_dict = None
self.libai_cfg = dist.broadcast_py_object(self.libai_cfg, src=0)
# Instantiate the model
logger.info("building LiBai model...")
if isinstance(self.model, omegaconf.dictconfig.DictConfig):
self.model.cfg = self.libai_cfg
self.model = build_model(self.model)
else:
self.model = build_model(LazyCall(self.model)(cfg=self.libai_cfg))
# State_dict to global
logger.info("transfering state_dict local to global...")
flow_state_dict = self._state_dict_to_global(flow_state_dict, mode="pytorch")
logger.info("loading model weights into LiBai...")
# Load
(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
error_msgs,
) = self._load_pretrained_model(self.model, flow_state_dict, self.pretrained_model_path)
if self.output_loading_info:
loading_info = {
"missing_keys": missing_keys,
"unexpected_keys": unexpected_keys,
"mismatched_keys": mismatched_keys,
"error_msgs": error_msgs,
}
return model, loading_info
return model
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import oneflow as flow
from .base_loader import ModelLoaderHuggerFace, ModelLoaderLiBai
class BertLoaderHuggerFace(ModelLoaderHuggerFace):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
"""NOTE: base_model_prefix_1 is BERT's prefix in Transformers.
base_model_prefix_2 is BERT's prefix in LiBai."""
self.base_model_prefix_1 = "bert"
self.base_model_prefix_2 = "bert"
def _convert_state_dict(self, flow_state_dict, cfg):
"""Convert state_dict's keys to match model.
Args:
flow_state_dict (OrderedDict): model state dict.
cfg (dict): model's default config dict in LiBai.
Returns:
OrderedDict: flow state dict.
"""
# The converted checkpoint.
oneflow_state_dict = flow_state_dict.copy()
# Get configs
num_heads = cfg.get("num_attention_heads")
hidden_size = cfg.get("hidden_size")
layers = cfg.get("hidden_layers")
head_size = int(hidden_size / num_heads)
# prefix
has_prefix = any(s.startswith(self.base_model_prefix_1) for s in oneflow_state_dict)
prefix = "bert." if has_prefix else ""
index_idx = 3 if has_prefix else 2
qkv_idx = 6 if has_prefix else 5
old_keys = oneflow_state_dict.keys()
for key in list(old_keys):
# Convert bert's embedding layers
if "embeddings" in key:
if "word_embeddings" in key:
new_key = key.replace("word_embeddings", "vocab_embeddings")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "token_type_embeddings" in key:
new_key = key.replace("token_type_embeddings", "tokentype_embeddings")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "LayerNorm.weight" in key:
new_key = prefix + "encoders.0.input_layernorm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "LayerNorm.bias" in key:
new_key = prefix + "encoders.0.input_layernorm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
else:
oneflow_state_dict[key] = oneflow_state_dict[key]
# Convert bert's attention layers
elif "attention" in key:
if "self" in key:
index = key.split(".")[index_idx]
if (
prefix + "encoders." + index + ".self_attention.query_key_value.weight"
in oneflow_state_dict.keys()
):
continue
q_w = key.replace(key.split(".")[qkv_idx], "query").replace(
key.split(".")[qkv_idx + 1], "weight"
)
k_w = q_w.replace("query", "key")
v_w = q_w.replace("query", "value")
q_b = q_w.replace("weight", "bias")
k_b = k_w.replace("weight", "bias")
v_b = v_w.replace("weight", "bias")
qkv_w = flow.cat(
(
oneflow_state_dict.pop(q_w),
oneflow_state_dict.pop(k_w),
oneflow_state_dict.pop(v_w),
),
dim=0,
)
qkv_b = flow.cat(
(
oneflow_state_dict.pop(q_b),
oneflow_state_dict.pop(k_b),
oneflow_state_dict.pop(v_b),
),
dim=-1,
)
qkv_w = self._fix_qkv_ordering(qkv_w, head_size, num_heads)
qkv_b = self._fix_qkv_ordering(qkv_b, head_size, num_heads)
new_key = (
prefix + "encoders." + index + ".self_attention.query_key_value.weight"
)
oneflow_state_dict[new_key] = qkv_w
new_key = prefix + "encoders." + index + ".self_attention.query_key_value.bias"
oneflow_state_dict[new_key] = qkv_b
elif "output" in key:
index = key.split(".")[index_idx]
if "dense" in key:
if "weight" in key:
new_key = prefix + "encoders." + index + ".self_attention.dense.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = prefix + "encoders." + index + ".self_attention.dense.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "LayerNorm" in key:
if "weight" in key:
new_key = (
prefix + "encoders." + index + ".post_attention_layernorm.weight"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = (
prefix + "encoders." + index + ".post_attention_layernorm.bias"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
# Convert bert's intermediate layers
elif "intermediate" in key:
index = key.split(".")[index_idx]
if (
prefix + "encoders." + index + ".mlp.dense_h_to_4h.weight"
in oneflow_state_dict.keys()
):
continue
if "weight" in key:
w = key
b = key.replace("weight", "bias")
new_key = prefix + "encoders." + index + ".mlp.dense_h_to_4h.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
# Convert bert's output layers
elif "output" in key:
index = key.split(".")[index_idx]
if "dense.weight" in key:
if (
prefix + "encoders." + index + ".mlp.dense_4h_to_h.weight"
in oneflow_state_dict.keys()
):
continue
w = key
b = w.replace("weight", "bias")
new_key = prefix + "encoders." + index + ".mlp.dense_4h_to_h.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
elif "LayerNorm.weight" in key:
if (
prefix + "encoders." + str(int(index) + 1) + ".input_layernorm.weight"
in oneflow_state_dict.keys()
):
continue
w = key
b = w.replace("weight", "bias")
if index == str(layers - 1):
new_key = prefix + "final_layernorm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
continue
new_key = prefix + "encoders." + str(int(index) + 1) + ".input_layernorm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
# Convert bert's pooler layers
elif "pooler" in key:
if "weight" in key:
new_key = prefix + "pooler.dense.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = prefix + "pooler.dense.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
# Convert cls_head layers
elif "cls" in key:
if "predictions.bias" in key:
new_key = "cls_head.lm_logits.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "dense.weight" in key:
new_key = "cls_head.predictions.dense.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "dense.bias" in key:
new_key = "cls_head.predictions.dense.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "LayerNorm.weight" in key:
new_key = "cls_head.predictions.layernorm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "LayerNorm.bias" in key:
new_key = "cls_head.predictions.layernorm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "seq_relationship" in key:
new_key = key.replace("cls", "cls_head")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
else:
oneflow_state_dict[key] = oneflow_state_dict.pop(key)
return oneflow_state_dict
def _load_config_from_json(self, config_file):
"""load config from `config.json`, and update default config.
Args:
config_file (str): Path of config file.
"""
with open(config_file, mode="r", encoding="utf-8") as f:
cfg_dict = json.load(f)
# update libai_cfg by config.json
self._update_cfg("vocab_size", cfg_dict["vocab_size"])
self._update_cfg("hidden_size", cfg_dict["hidden_size"])
self._update_cfg("hidden_layers", cfg_dict["num_hidden_layers"])
self._update_cfg("num_attention_heads", cfg_dict["num_attention_heads"])
self._update_cfg("intermediate_size", cfg_dict["intermediate_size"])
self._update_cfg("hidden_dropout_prob", cfg_dict["hidden_dropout_prob"])
self._update_cfg("attention_probs_dropout_prob", cfg_dict["attention_probs_dropout_prob"])
self._update_cfg("max_position_embeddings", cfg_dict["max_position_embeddings"])
self._update_cfg("num_tokentypes", cfg_dict["type_vocab_size"])
self._update_cfg("initializer_range", cfg_dict["initializer_range"])
self._update_cfg("layernorm_eps", cfg_dict["layer_norm_eps"])
# update libai_cfg by kwargs
for k, v in self.kwargs.items():
self._update_cfg(k, v)
# use original BERT residual connection ordering
self.libai_cfg.apply_residual_post_layernorm = True
self._update_cfg_log()
class BertLoaderLiBai(ModelLoaderLiBai):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
self.base_model_prefix_2 = "bert"
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from .base_loader import ModelLoaderHuggerFace, ModelLoaderLiBai
class GPT2LoaderHuggerFace(ModelLoaderHuggerFace):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
"""NOTE: base_model_prefix_1 is GPT's prefix in Transformers.
base_model_prefix_2 is GPT's prefix in LiBai."""
self.base_model_prefix_1 = "transformer"
self.base_model_prefix_2 = "GPT_model"
def _convert_state_dict(self, flow_state_dict, cfg):
"""Convert state_dict's keys to match model.
Args:
flow_state_dict (OrderedDict): model state dict.
cfg (dict): model's default config dict in LiBai.
Returns:
OrderedDict: flow state dict.
"""
# The converted checkpoint.
oneflow_state_dict = flow_state_dict.copy()
old_keys = list(oneflow_state_dict.keys())
# Get configs
num_heads = cfg.get("num_attention_heads")
hidden_size = cfg.get("hidden_size")
head_size = int(hidden_size / num_heads)
# prefix
has_prefix = any(s.startswith(self.base_model_prefix_1) for s in oneflow_state_dict)
prefix1 = self.base_model_prefix_1 + "." if has_prefix else ""
prefix2 = "GPT_model.transformer."
layer_idx = 2 if has_prefix else 1
# Convert Embedding layers.
new_key = "GPT_model.embeddings.token_embeddings.weight"
old_keys.remove(prefix1 + "wte.weight")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(prefix1 + "wte.weight")
new_key = "GPT_model.embeddings.position_embeddings.weight"
old_keys.remove(prefix1 + "wpe.weight")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(prefix1 + "wpe.weight")
for key in old_keys:
keys = key.split(".")
if layer_idx >= len(keys):
continue
layer = keys[layer_idx]
# Convert transformer layers.
if "h." in key:
if "ln_1" in key:
if "weight" in key:
new_key = prefix2 + "layers." + layer + ".input_layernorm.weight"
else:
new_key = prefix2 + "layers." + layer + ".input_layernorm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "ln_2" in key:
if "weight" in key:
new_key = prefix2 + "layers." + layer + ".post_attention_layernorm.weight"
else:
new_key = prefix2 + "layers." + layer + ".post_attention_layernorm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "attn" in key:
if "c_attn" in key:
if "weight" in key:
new_key = (
prefix2
+ "layers."
+ layer
+ ".self_attention.query_key_value.weight"
)
else:
new_key = (
prefix2 + "layers." + layer + ".self_attention.query_key_value.bias"
)
qkv = oneflow_state_dict.pop(key)
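# HuggingFace GPT-2 stores these projections as Conv1D with weight shape
# [in_features, out_features]; transpose to [out_features, in_features] for Linear layers.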
if qkv.ndim > 1:
qkv = qkv.transpose(1, 0)
qkv = self._fix_qkv_ordering(qkv, head_size, num_heads)
oneflow_state_dict[new_key] = qkv
elif "c_proj" in key:
if "weight" in key:
new_key = prefix2 + "layers." + layer + ".self_attention.dense.weight"
elif "bias" in key:
new_key = prefix2 + "layers." + layer + ".self_attention.dense.bias"
value = oneflow_state_dict.pop(key)
if value.ndim > 1:
value = value.transpose(1, 0)
oneflow_state_dict[new_key] = value
elif "mlp" in key:
if "c_fc" in key:
if "weight" in key:
new_key = prefix2 + "layers." + layer + ".mlp.dense_h_to_4h.weight"
elif "bias" in key:
new_key = prefix2 + "layers." + layer + ".mlp.dense_h_to_4h.bias"
value = oneflow_state_dict.pop(key)
if value.ndim > 1:
value = value.transpose(1, 0)
oneflow_state_dict[new_key] = value
elif "c_proj" in key:
if "weight" in key:
new_key = prefix2 + "layers." + layer + ".mlp.dense_4h_to_h.weight"
elif "bias" in key:
new_key = prefix2 + "layers." + layer + ".mlp.dense_4h_to_h.bias"
value = oneflow_state_dict.pop(key)
if value.ndim > 1:
value = value.transpose(1, 0)
oneflow_state_dict[new_key] = value
elif "ln_f" in key:
if "weight" in key:
new_key = prefix2 + "layernorm_f.weight"
elif "bias" in key:
new_key = prefix2 + "layernorm_f.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
return oneflow_state_dict
def _load_config_from_json(self, config_file):
"""load config from `config.json`, and update default config.
Args:
config_file (str): Path of config file.
"""
with open(config_file, mode="r", encoding="utf-8") as f:
cfg_dict = json.load(f)
# update libai_cfg by config.json
self._update_cfg("hidden_layers", cfg_dict["n_layer"])
self._update_cfg("hidden_size", cfg_dict["n_embd"])
self._update_cfg("num_attention_heads", cfg_dict["n_head"])
self._update_cfg("max_seq_length", cfg_dict["n_positions"])
self._update_cfg("embedding_dropout_prob", cfg_dict["embd_pdrop"])
self._update_cfg("attention_dropout_prob", cfg_dict["attn_pdrop"])
self._update_cfg("output_dropout_prob", cfg_dict["resid_pdrop"])
self._update_cfg("layernorm_epsilon", cfg_dict["layer_norm_epsilon"])
self._update_cfg("vocab_size", cfg_dict["vocab_size"])
self._update_cfg("initializer_range", cfg_dict["initializer_range"])
self._update_cfg(
"ffn_hidden_size",
cfg_dict.get("n_inner")
if cfg_dict.get("n_inner") is not None
else 4 * self.libai_cfg["hidden_size"],
)
# update libai_cfg by kwargs
for k, v in self.kwargs.items():
self._update_cfg(k, v)
self._update_cfg_log()
class GPT2LoaderLiBai(ModelLoaderLiBai):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
self.base_model_prefix_2 = "GPT_model"
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
from .bert_loader import BertLoaderHuggerFace, BertLoaderLiBai
class RobertaLoaderHuggerFace(BertLoaderHuggerFace):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
"""NOTE: base_model_prefix_1 is RoBERTa's prefix in Transformers,
base_model_prefix_2 is RoBERTa's prefix in LiBai."""
self.base_model_prefix_1 = "roberta"
self.base_model_prefix_2 = "roberta"
def _convert_state_dict(self, flow_state_dict, cfg):
"""Convert state_dict's keys to match model.
Args:
flow_state_dict (OrderedDict): model state dict.
cfg (dict): model's default config dict in LiBai.
Returns:
OrderedDict: flow state dict.
"""
# The converted checkpoint.
oneflow_state_dict = flow_state_dict.copy()
# Get configs
num_heads = cfg.get("num_attention_heads")
hidden_size = cfg.get("hidden_size")
layers = cfg.get("hidden_layers")
head_size = int(hidden_size / num_heads)
# prefix
has_prefix = any(s.startswith(self.base_model_prefix_1) for s in oneflow_state_dict)
prefix = "roberta." if has_prefix else ""
index_idx = 3 if has_prefix else 2
qkv_idx = 6 if has_prefix else 5
old_keys = oneflow_state_dict.keys()
for key in list(old_keys):
# Convert roberta's embedding layers
if "embeddings" in key:
if "word_embeddings" in key:
new_key = key.replace("word_embeddings", "vocab_embeddings")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "token_type_embeddings" in key:
new_key = key.replace("token_type_embeddings", "tokentype_embeddings")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "LayerNorm.weight" in key:
new_key = prefix + "encoders.0.input_layernorm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "LayerNorm.bias" in key:
new_key = prefix + "encoders.0.input_layernorm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
else:
oneflow_state_dict[key] = oneflow_state_dict[key]
# Convert roberta's attention layers
elif "attention" in key:
if "self" in key:
index = key.split(".")[index_idx]
if (
prefix + "encoders." + index + ".self_attention.query_key_value.weight"
in oneflow_state_dict.keys()
):
continue
q_w = key.replace(key.split(".")[qkv_idx], "query").replace(
key.split(".")[qkv_idx + 1], "weight"
)
k_w = q_w.replace("query", "key")
v_w = q_w.replace("query", "value")
q_b = q_w.replace("weight", "bias")
k_b = k_w.replace("weight", "bias")
v_b = v_w.replace("weight", "bias")
qkv_w = flow.cat(
(
oneflow_state_dict.pop(q_w),
oneflow_state_dict.pop(k_w),
oneflow_state_dict.pop(v_w),
),
dim=0,
)
qkv_b = flow.cat(
(
oneflow_state_dict.pop(q_b),
oneflow_state_dict.pop(k_b),
oneflow_state_dict.pop(v_b),
),
dim=-1,
)
qkv_w = self._fix_qkv_ordering(qkv_w, head_size, num_heads)
qkv_b = self._fix_qkv_ordering(qkv_b, head_size, num_heads)
new_key = (
prefix + "encoders." + index + ".self_attention.query_key_value.weight"
)
oneflow_state_dict[new_key] = qkv_w
new_key = prefix + "encoders." + index + ".self_attention.query_key_value.bias"
oneflow_state_dict[new_key] = qkv_b
elif "output" in key:
index = key.split(".")[index_idx]
if "dense" in key:
if "weight" in key:
new_key = prefix + "encoders." + index + ".self_attention.dense.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = prefix + "encoders." + index + ".self_attention.dense.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "LayerNorm" in key:
if "weight" in key:
new_key = (
prefix + "encoders." + index + ".post_attention_layernorm.weight"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = (
prefix + "encoders." + index + ".post_attention_layernorm.bias"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
# Convert roberta's intermediate layers
elif "intermediate" in key:
index = key.split(".")[index_idx]
if (
prefix + "encoders." + index + ".mlp.dense_h_to_4h.weight"
in oneflow_state_dict.keys()
):
continue
if "weight" in key:
w = key
b = key.replace("weight", "bias")
new_key = prefix + "encoders." + index + ".mlp.dense_h_to_4h.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
# Convert roberta's output layers
elif "output" in key:
index = key.split(".")[index_idx]
if "dense.weight" in key:
if (
prefix + "encoders." + index + ".mlp.dense_4h_to_h.weight"
in oneflow_state_dict.keys()
):
continue
w = key
b = w.replace("weight", "bias")
new_key = prefix + "encoders." + index + ".mlp.dense_4h_to_h.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
elif "LayerNorm.weight" in key:
if (
prefix + "encoders." + str(int(index) + 1) + ".input_layernorm.weight"
in oneflow_state_dict.keys()
):
continue
w = key
b = w.replace("weight", "bias")
if index == str(layers - 1):
new_key = prefix + "final_layernorm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
continue
new_key = prefix + "encoders." + str(int(index) + 1) + ".input_layernorm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
# Convert roberta's pooler layers
elif "pooler" in key:
if "weight" in key:
new_key = prefix + "pooler.dense.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = prefix + "pooler.dense.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
# Convert lm_head layers
elif "lm_head" in key:
if "layer_norm.weight" in key:
new_key = "lm_head.layernorm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "layer_norm.bias" in key:
new_key = "lm_head.layernorm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "seq_relationship" in key:
new_key = key.replace("cls", "cls_head")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "lm_head.bias" in key:
new_key = "lm_head.lm_logits.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
else:
oneflow_state_dict[key] = oneflow_state_dict.pop(key)
else:
oneflow_state_dict[key] = oneflow_state_dict.pop(key)
return oneflow_state_dict
class RobertaLoaderLiBai(BertLoaderLiBai):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
self.base_model_prefix_2 = "roberta"
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import oneflow as flow
from .base_loader import ModelLoaderHuggerFace, ModelLoaderLiBai
class SwinLoaderHuggerFace(ModelLoaderHuggerFace):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
"""NOTE: base_model_prefix_1 is SWIN's prefix in Transformers.
base_model_prefix_2 is SWIN's prefix in LiBai."""
self.base_model_prefix_1 = "swin"
self.base_model_prefix_2 = ""
def _convert_state_dict(self, flow_state_dict, cfg=None):
"""Convert state_dict's keys to match model.
Args:
flow_state_dict (OrderedDict): model state dict.
cfg (dict): model's default config dict.
Returns:
OrderedDict: flow state dict.
"""
# The converted checkpoint.
oneflow_state_dict = flow_state_dict.copy()
# prefix
has_prefix = any(s.startswith(self.base_model_prefix_1) for s in oneflow_state_dict)
index_idx_1 = 3 if has_prefix else 2
index_idx_2 = 5 if has_prefix else 4
old_keys = oneflow_state_dict.keys()
for key in list(old_keys):
# Convert swin's embedding layers
if "embeddings" in key:
if "patch_embeddings.projection" in key:
if "weight" in key:
new_key = "patch_embed.proj.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "patch_embed.proj.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "norm" in key:
if "weight" in key:
new_key = "patch_embed.norm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "patch_embed.norm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
# Convert swin's layernorm layers
elif "layernorm_before" in key:
index_layer = key.split(".")[index_idx_1]
index_block = key.split(".")[index_idx_2]
if "weight" in key:
new_key = "layers." + index_layer + ".blocks." + index_block + ".norm1.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "layers." + index_layer + ".blocks." + index_block + ".norm1.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "layernorm_after" in key:
index_layer = key.split(".")[index_idx_1]
index_block = key.split(".")[index_idx_2]
if "weight" in key:
new_key = "layers." + index_layer + ".blocks." + index_block + ".norm2.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "layers." + index_layer + ".blocks." + index_block + ".norm2.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
# Convert swin's attention layers
elif "attention" in key:
index_layer = key.split(".")[index_idx_1]
index_block = key.split(".")[index_idx_2]
if "self" in key:
if (
"relative_position_bias_table" in key
): # convert relative_position_bias_table but not index
new_key = (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".attn.relative_position_bias_table"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "relative_position_index" in key:
new_key = (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".attn.relative_position_index"
)
oneflow_state_dict.pop(key)
else:
if (
"layers." + index_layer + ".blocks." + index_block + ".attn.qkv.weight"
in oneflow_state_dict.keys()
):
continue
q_w = key
k_w = q_w.replace("query", "key")
v_w = q_w.replace("query", "value")
q_b = q_w.replace("weight", "bias")
k_b = k_w.replace("weight", "bias")
v_b = v_w.replace("weight", "bias")
qkv_w = flow.cat(
(
oneflow_state_dict.pop(q_w),
oneflow_state_dict.pop(k_w),
oneflow_state_dict.pop(v_w),
),
dim=0,
)
qkv_b = flow.cat(
(
oneflow_state_dict.pop(q_b),
oneflow_state_dict.pop(k_b),
oneflow_state_dict.pop(v_b),
),
dim=-1,
)
new_key = (
"layers." + index_layer + ".blocks." + index_block + ".attn.qkv.weight"
)
oneflow_state_dict[new_key] = qkv_w
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = qkv_b
elif "output" in key:
if "dense" in key:
if "weight" in key:
new_key = (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".attn.proj.weight"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
if "bias" in key:
new_key = (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".attn.proj.bias"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "intermediate" in key:
index_layer = key.split(".")[index_idx_1]
index_block = key.split(".")[index_idx_2]
if "weight" in key:
if (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".mlp.dense_h_to_4h.weight"
in oneflow_state_dict.keys()
):
continue
w = key
b = key.replace("weight", "bias")
new_key = (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".mlp.dense_h_to_4h.weight"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
elif "output" in key:
index_layer = key.split(".")[index_idx_1]
index_block = key.split(".")[index_idx_2]
if "dense.weight" in key:
if (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".mlp.dense_4h_to_h.weight"
in oneflow_state_dict.keys()
):
continue
w = key
b = w.replace("weight", "bias")
new_key = (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".mlp.dense_4h_to_h.weight"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
elif "downsample" in key:
index_layer = key.split(".")[index_idx_1]
if "reduction.weight" in key:
new_key = "layers." + index_layer + ".downsample.reduction.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "norm" in key:
if (
"layers." + index_layer + ".downsample.norm.weight"
in oneflow_state_dict.keys()
):
continue
w = key
b = w.replace("weight", "bias")
new_key = "layers." + index_layer + ".downsample.norm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
elif "layernorm" in key:
if "weight" in key:
new_key = "norm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "norm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "classifier" in key:
if "weight" in key:
new_key = "head.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "head.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
else:
oneflow_state_dict[key] = oneflow_state_dict.pop(key)
return oneflow_state_dict
def _load_config_from_json(self, config_file):
"""load config from `config.json`, and update default config.
Args:
config_file (str): Path of config file.
"""
with open(config_file, mode="r", encoding="utf-8") as f:
cfg_dict = json.load(f)
# update libai_cfg by config.json
self._update_cfg("img_size", cfg_dict["image_size"])
self._update_cfg("patch_size", cfg_dict["patch_size"])
self._update_cfg("embed_dim", cfg_dict["embed_dim"])
self._update_cfg("depths", cfg_dict["depths"])
self._update_cfg("num_heads", cfg_dict["num_heads"])
self._update_cfg("window_size", cfg_dict["window_size"])
self._update_cfg("mlp_ratio", cfg_dict["mlp_ratio"])
self._update_cfg("qkv_bias", cfg_dict["qkv_bias"])
self._update_cfg("drop_path_rate", cfg_dict["drop_path_rate"])
# update libai_cfg by kwargs
for k, v in self.kwargs.items():
self._update_cfg(k, v)
self._update_cfg_log()
class SwinLoaderLiBai(ModelLoaderLiBai):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
self.base_model_prefix_2 = ""
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import oneflow as flow
from .base_loader import ModelLoaderHuggerFace, ModelLoaderLiBai
class SwinV2LoaderHuggerFace(ModelLoaderHuggerFace):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
"""NOTE: base_model_prefix_1 is SWINV2's prefix in Transformers.
base_model_prefix_2 is SWINV2's prefix in LiBai."""
self.base_model_prefix_1 = "swinv2"
self.base_model_prefix_2 = ""
def _convert_state_dict(self, flow_state_dict, cfg=None):
"""Convert state_dict's keys to match model.
Args:
flow_state_dict (OrderedDict): model state dict.
cfg (dict): model's default config dict.
Returns:
OrderedDict: flow state dict.
"""
# The converted checkpoint.
oneflow_state_dict = flow_state_dict.copy()
# prefix
has_prefix = any(s.startswith(self.base_model_prefix_1) for s in oneflow_state_dict)
index_idx_1 = 3 if has_prefix else 2
index_idx_2 = 5 if has_prefix else 4
old_keys = oneflow_state_dict.keys()
for key in list(old_keys):
# Convert swinv2's embedding layers
if "embeddings" in key:
if "patch_embeddings.projection" in key:
if "weight" in key:
new_key = "patch_embed.proj.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
if "bias" in key:
new_key = "patch_embed.proj.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "norm" in key:
if "weight" in key:
new_key = "patch_embed.norm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
if "bias" in key:
new_key = "patch_embed.norm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
# Convert swinv2's layernorm layers
elif "layernorm_before" in key:
index_layer = key.split(".")[index_idx_1]
index_block = key.split(".")[index_idx_2]
if "weight" in key:
new_key = "layers." + index_layer + ".blocks." + index_block + ".norm1.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "layers." + index_layer + ".blocks." + index_block + ".norm1.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "layernorm_after" in key:
index_layer = key.split(".")[index_idx_1]
index_block = key.split(".")[index_idx_2]
if "weight" in key:
new_key = "layers." + index_layer + ".blocks." + index_block + ".norm2.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "layers." + index_layer + ".blocks." + index_block + ".norm2.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
# Convert swinv2's attention layers
elif "attention" in key:
index_layer = key.split(".")[index_idx_1]
index_block = key.split(".")[index_idx_2]
if "self" in key:
if (
"relative_position_bias_table" in key
): # convert relative_position_bias_table but not index
new_key = (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".attn.relative_position_bias_table"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "relative_position_index" in key:
new_key = (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".attn.relative_position_index"
)
oneflow_state_dict.pop(key)
elif "continuous_position_bias_mlp" in key:
if (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".attn.cpb_mlp"
+ ".0.weight"
) in oneflow_state_dict.keys():
continue
new_key = (
"layers." + index_layer + ".blocks." + index_block + ".attn.cpb_mlp"
)
m_1_w = key
m_1_b = key.replace(".0.weight", ".0.bias")
m_2_w = key.replace(".0.weight", ".2.weight")
oneflow_state_dict[new_key + ".0.weight"] = oneflow_state_dict.pop(m_1_w)
oneflow_state_dict[new_key + ".0.bias"] = oneflow_state_dict.pop(m_1_b)
oneflow_state_dict[new_key + ".2.weight"] = oneflow_state_dict.pop(m_2_w)
elif "logit_scale" in key:
new_key = (
"layers." + index_layer + ".blocks." + index_block + ".attn.logit_scale"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)[None, ...]
else:
if (
"layers." + index_layer + ".blocks." + index_block + ".attn.qkv.weight"
in oneflow_state_dict.keys()
):
continue
q_w = key
k_w = q_w.replace("query", "key")
v_w = q_w.replace("query", "value")
q_b = q_w.replace("weight", "bias")
v_b = v_w.replace("weight", "bias")
qkv_w = flow.cat(
(
oneflow_state_dict.pop(q_w),
oneflow_state_dict.pop(k_w),
oneflow_state_dict.pop(v_w),
),
dim=0,
)
new_key = (
"layers." + index_layer + ".blocks." + index_block + ".attn.qkv.weight"
)
oneflow_state_dict[new_key] = qkv_w
new_key = (
"layers." + index_layer + ".blocks." + index_block + ".attn.q_bias"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(q_b)
new_key = new_key.replace("q_bias", "v_bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(v_b)
elif "output" in key:
if "dense" in key:
if "weight" in key:
new_key = (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".attn.proj.weight"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
if "bias" in key:
new_key = (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".attn.proj.bias"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "intermediate" in key:
index_layer = key.split(".")[index_idx_1]
index_block = key.split(".")[index_idx_2]
if "weight" in key:
if (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".mlp.dense_h_to_4h.weight"
in oneflow_state_dict.keys()
):
continue
w = key
b = key.replace("weight", "bias")
new_key = (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".mlp.dense_h_to_4h.weight"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
elif "output" in key:
index_layer = key.split(".")[index_idx_1]
index_block = key.split(".")[index_idx_2]
if "dense.weight" in key:
if (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".mlp.dense_4h_to_h.weight"
in oneflow_state_dict.keys()
):
continue
w = key
b = w.replace("weight", "bias")
new_key = (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".mlp.dense_4h_to_h.weight"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
elif "downsample" in key:
index_layer = key.split(".")[index_idx_1]
if "reduction.weight" in key:
new_key = "layers." + index_layer + ".downsample.reduction.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "norm" in key:
if (
"layers." + index_layer + ".downsample.norm.weight"
in oneflow_state_dict.keys()
):
continue
w = key
b = w.replace("weight", "bias")
new_key = "layers." + index_layer + ".downsample.norm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
elif "layernorm" in key:
if "weight" in key:
new_key = "norm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "norm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "classifier" in key:
if "weight" in key:
new_key = "head.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "head.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
else:
oneflow_state_dict[key] = oneflow_state_dict.pop(key)
return oneflow_state_dict
def _load_config_from_json(self, config_file):
"""load config from `config.json`, and update default config.
Args:
config_file (str): Path of config file.
"""
with open(config_file, mode="r", encoding="utf-8") as f:
cfg_dict = json.load(f)
# update libai_cfg by config.json
self._update_cfg("img_size", cfg_dict["image_size"])
self._update_cfg("patch_size", cfg_dict["patch_size"])
self._update_cfg("embed_dim", cfg_dict["embed_dim"])
self._update_cfg("depths", cfg_dict["depths"])
self._update_cfg("num_heads", cfg_dict["num_heads"])
self._update_cfg("window_size", cfg_dict["window_size"])
self._update_cfg("mlp_ratio", cfg_dict["mlp_ratio"])
self._update_cfg("qkv_bias", cfg_dict["qkv_bias"])
self._update_cfg("drop_path_rate", cfg_dict["drop_path_rate"])
self._update_cfg("pretrained_window_sizes", cfg_dict["pretrained_window_sizes"])
# update libai_cfg by kwargs
for k, v in self.kwargs.items():
self._update_cfg(k, v)
self._update_cfg_log()
class SwinV2LoaderLiBai(ModelLoaderLiBai):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
self.base_model_prefix_2 = ""