Commit 3b355d3f authored by yuguo960516

gpt2

parent fd158e88
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
from oneflow import nn
from libai.config import configurable
from libai.layers import (
Embedding,
LayerNorm,
Linear,
LMLogits,
ParallelCrossEntropyLoss,
TransformerLayer,
VocabEmbedding,
build_activation,
)
from libai.utils import distributed as dist
from .bert_model import BertEmbeddings, BertExtendedAttnMask, BertModel, BertPooler
from .utils import init_method_normal
class RobertaExtendedAttnMask(BertExtendedAttnMask):
"""
Same as BertExtendedAttnMask.
"""
class RobertaEmbeddings(BertEmbeddings):
"""
    Same as BertEmbeddings, with small tweaks to vocab_embeddings and position_embeddings
    to account for the padding index.
"""
def __init__(
self,
vocab_size,
hidden_size,
max_sequence_length,
embedding_dropout_prob,
num_tokentypes=0,
pad_token_id=1,
init_method=nn.init.xavier_normal_,
amp_enabled=False,
):
super().__init__(
vocab_size,
hidden_size,
max_sequence_length,
embedding_dropout_prob,
num_tokentypes=num_tokentypes,
init_method=init_method,
amp_enabled=amp_enabled,
)
self.pad_token_id = pad_token_id
self.vocab_embeddings = VocabEmbedding(
vocab_size,
hidden_size,
init_method=init_method,
amp_enabled=amp_enabled,
padding_idx=pad_token_id,
)
self.position_embeddings = Embedding(
max_sequence_length,
hidden_size,
init_method=init_method,
amp_enabled=amp_enabled,
padding_idx=pad_token_id,
)
if num_tokentypes > 0:
self.tokentype_embeddings = Embedding(
num_tokentypes, hidden_size, init_method=init_method, amp_enabled=amp_enabled
)
self.tokentype_ids = flow.zeros(
1,
max_sequence_length,
dtype=flow.long,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=dist.get_layer_placement(0),
)
else:
self.tokentype_embeddings = None
def forward(self, input_ids, tokentype_ids=None, position_ids=None):
seq_length = input_ids.size()[1]
word_embeddings = self.vocab_embeddings(input_ids)
if position_ids is None:
position_ids = self.create_position_ids_from_input_ids(input_ids, self.pad_token_id)
position_embeddings = self.position_embeddings(position_ids)
embeddings = word_embeddings + position_embeddings
if self.tokentype_embeddings is not None:
if tokentype_ids is None:
tokentype_ids = (
self.tokentype_ids[:, :seq_length]
.expand_as(input_ids)
.to_global(sbp=input_ids.sbp)
)
embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
embeddings = self.embedding_dropout(embeddings)
return embeddings
def create_position_ids_from_input_ids(self, input_ids, pad_token_id):
mask = input_ids.ne(pad_token_id).int()
position_ids = (flow.cumsum(mask, dim=1).type_as(mask)) * mask + pad_token_id
position_ids = position_ids.to_global(sbp=input_ids.sbp, placement=input_ids.placement)
return position_ids
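# Illustrative sketch (hypothetical helper, for illustration only): how
# create_position_ids_from_input_ids numbers tokens, shown with a local tensor and
# pad_token_id=1. Padding positions keep pad_token_id; real tokens count up from
# pad_token_id + 1.
def _roberta_position_ids_example(pad_token_id=1):
    input_ids = flow.tensor([[0, 5, 7, 1, 1]])  # the last two tokens are padding
    mask = input_ids.ne(pad_token_id).int()  # [[1, 1, 1, 0, 0]]
    return flow.cumsum(mask, dim=1).type_as(mask) * mask + pad_token_id  # [[2, 3, 4, 1, 1]]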
class RobertaPooler(BertPooler):
"""
Same as BertPooler.
"""
class RobertaLoss(nn.Module):
def __init__(self):
super().__init__()
self.lm_loss = ParallelCrossEntropyLoss()
def forward(self, lm_output, lm_labels, loss_mask):
lm_labels = lm_labels.to_global(placement=lm_output.placement)
loss_mask = loss_mask.to_global(placement=lm_output.placement)
lm_loss = self.lm_loss(lm_output, lm_labels)
loss_mask = loss_mask.float()
# Change loss_mask.sum() sbp sign from [P, B] -> [B, B]
# because (lm_loss * loss_mask) / loss_mask.sum() cannot accept P / P
denominator = loss_mask.sum().to_global(
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])
)
masked_lm_loss = flow.sum(lm_loss.view(-1) * loss_mask.view(-1)) / denominator
masked_lm_loss = masked_lm_loss.to_global(
sbp=dist.get_nd_sbp([flow.sbp.partial_sum, flow.sbp.broadcast])
)
loss_dict = {"lm_loss": masked_lm_loss}
return loss_dict
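# Illustrative sketch (hypothetical helper, for illustration only): the masked averaging
# performed by RobertaLoss, shown with local tensors and without the sbp/global handling.
def _masked_lm_loss_example():
    lm_loss = flow.tensor([0.5, 1.0, 2.0, 4.0])  # per-token losses
    loss_mask = flow.tensor([1.0, 1.0, 0.0, 0.0])  # only the first two tokens count
    return flow.sum(lm_loss * loss_mask) / loss_mask.sum()  # -> 0.75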
class RobertaModel(BertModel):
"""The bare Roberta Model transformer outputting raw hidden-states without
any specific head on top.
Args:
vocab_size (int):
The size of vocabulary file.
hidden_size (int):
The size of hidden states.
hidden_layers (int):
The number of ``TransformerLayer`` in encoder.
num_attention_heads (int):
The number of attention heads for each attention layer of ``TransformerLayer``.
intermediate_size (int):
The size of intermediate layer in feed-forward network for each
``TransformerLayer``.
hidden_dropout_prob (float, optional):
            The dropout ratio for the output of each TransformerLayer. Defaults to 0.0.
attention_probs_dropout_prob (float, optional):
The dropout ratio for the output of each attention layer in ``TransformerLayer``.
Defaults to 0.0.
max_position_embeddings (int):
Max sequence length of input, defines the shape of Position Embeddings
in ``RobertaEmbeddings``.
        num_tokentypes (int, optional):
            Number of token type (segment) indices. Defaults to 2.
add_pooling_layer (bool, optional):
            Whether or not to average or pool the sequence of hidden-states for the
            whole input sequence. Defaults to ``True``.
initializer_range (float, optional):
Sigma of the normal distribution in the initialization method. Defaults to 0.02.
        layernorm_eps (float, optional):
            The epsilon of the LayerNorm layer. Defaults to 1e-12.
pad_token_id (int, optional):
The token id used for padding. Defaults to 1.
bias_gelu_fusion (bool, optional):
Whether or not to fuse the computing of bias and gelu. Defaults to ``False``.
bias_dropout_fusion (bool, optional):
Whether or not to fuse the computing of dropout and bias. Defaults to ``False``.
scale_mask_softmax_fusion (bool, optional):
Whether to fuse the computing of mask and softmax in attention layers.
Defaults to ``False``.
apply_query_key_layer_scaling (bool, optional):
Whether or not to use layer index related scaling in computing attention scores.
            If ``True``, the scaling factor equals sqrt(d) * (layer_index + 1).
Defaults to ``True``.
apply_residual_post_layernorm (bool, optional):
            If set to ``True``, use the original BERT (RoBERTa) residual connection ordering;
            otherwise use the Megatron-BERT residual connection, which is more stable when
            scaling model size, introduced in https://arxiv.org/pdf/1909.08053.pdf.
            Defaults to ``False``.
amp_enabled (bool, optional):
            Whether or not to set fp16 for embedding weight in the RoBERTa model. Defaults to ``False``.
"""
@configurable
def __init__(
self,
vocab_size,
hidden_size,
hidden_layers,
num_attention_heads,
intermediate_size,
hidden_dropout_prob,
attention_probs_dropout_prob,
max_position_embeddings,
num_tokentypes=2,
add_pooling_layer=True,
initializer_range=0.02,
layernorm_eps=1e-12,
pad_token_id=1,
bias_gelu_fusion=True,
bias_dropout_fusion=True,
scale_mask_softmax_fusion=True,
apply_query_key_layer_scaling=True,
apply_residual_post_layernorm=False,
amp_enabled=False,
):
super().__init__(
vocab_size,
hidden_size,
hidden_layers,
num_attention_heads,
intermediate_size,
hidden_dropout_prob,
attention_probs_dropout_prob,
max_position_embeddings,
num_tokentypes=num_tokentypes,
add_pooling_layer=add_pooling_layer,
initializer_range=initializer_range,
layernorm_eps=layernorm_eps,
bias_gelu_fusion=bias_gelu_fusion,
bias_dropout_fusion=bias_dropout_fusion,
scale_mask_softmax_fusion=scale_mask_softmax_fusion,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
apply_residual_post_layernorm=apply_residual_post_layernorm,
amp_enabled=amp_enabled,
)
init_method = init_method_normal(initializer_range)
# Embeddings
self.embeddings = RobertaEmbeddings(
vocab_size,
hidden_size,
max_position_embeddings,
hidden_dropout_prob,
num_tokentypes,
pad_token_id,
init_method,
amp_enabled,
)
# Mask generation
self.extended_attn_mask = RobertaExtendedAttnMask()
self.pooler = RobertaPooler(hidden_size, init_method) if add_pooling_layer else None
@classmethod
def from_config(cls, cfg):
return {
"vocab_size": cfg.vocab_size,
"hidden_size": cfg.hidden_size,
"hidden_layers": cfg.hidden_layers,
"num_attention_heads": cfg.num_attention_heads,
"intermediate_size": cfg.intermediate_size,
"hidden_dropout_prob": cfg.hidden_dropout_prob,
"attention_probs_dropout_prob": cfg.attention_probs_dropout_prob,
"max_position_embeddings": cfg.max_position_embeddings,
"num_tokentypes": cfg.num_tokentypes,
"add_pooling_layer": cfg.add_pooling_layer,
"initializer_range": cfg.initializer_range,
"layernorm_eps": cfg.layernorm_eps,
"pad_token_id": cfg.pad_token_id,
"bias_gelu_fusion": cfg.bias_gelu_fusion,
"bias_dropout_fusion": cfg.bias_dropout_fusion,
"scale_mask_softmax_fusion": cfg.scale_mask_softmax_fusion,
"apply_query_key_layer_scaling": cfg.apply_query_key_layer_scaling,
"apply_residual_post_layernorm": cfg.apply_residual_post_layernorm,
"amp_enabled": cfg.amp_enabled,
}
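# Illustrative sketch (hypothetical helper, for illustration only): thanks to
# @configurable, RobertaModel can also be built with explicit keyword arguments.
# The tiny configuration below is made up, and it assumes LiBai's distributed/global
# environment is already available (e.g. the single-device default).
def _build_tiny_roberta_example():
    return RobertaModel(
        vocab_size=1000,
        hidden_size=64,
        hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=128,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=128,
    )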
class RobertaLMHead(nn.Module):
def __init__(self, vocab_size, hidden_size, init_method, layer_norm_eps):
super().__init__()
self.dense = Linear(
hidden_size,
hidden_size,
bias=True,
parallel="data",
init_method=init_method,
layer_idx=-1,
)
self.activation_func = build_activation("gelu")
self.layernorm = LayerNorm((hidden_size,), eps=layer_norm_eps, layer_idx=-1)
        # NOTE(xzp): LMLogits acts as the decoder (nn.Linear(hidden_size, vocab_size));
        # it shares roberta.word_embeddings.weight
self.lm_logits = LMLogits(vocab_size, bias=True)
def forward(self, hidden_states, word_embeddings_weight):
hidden_states = self.dense(hidden_states)
hidden_states = self.activation_func(hidden_states)
hidden_states = hidden_states.to_global(
sbp=dist.get_nd_sbp([flow.sbp.split(0), flow.sbp.broadcast])
)
hidden_states = self.layernorm(hidden_states)
hidden_states = self.lm_logits(hidden_states, word_embeddings_weight)
return hidden_states
class RobertaPreTrainedModel(nn.Module):
@staticmethod
def set_pipeline_stage_id(model):
dist_utils = dist.get_dist_util()
# Set pipeline parallelism stage_id
if hasattr(model.roberta.final_layernorm, "config"):
# Old API in OneFlow 0.8
for module_block in model.modules():
# module.origin can get the original module
if isinstance(module_block.origin, RobertaEmbeddings):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.origin, RobertaExtendedAttnMask):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.origin, TransformerLayer):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
                # `add_pooling_layer` defaults to False in RobertaForPreTraining
                # and RobertaForCausalLM, so the pooler may be absent.
elif isinstance(module_block.origin, RobertaPooler):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
elif isinstance(module_block.origin, RobertaLMHead):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
# Set the last layernorm stage id
model.roberta.final_layernorm.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
else:
for module_block in model.modules():
# module.origin can get the original module
if isinstance(module_block.to(nn.Module), RobertaEmbeddings):
module_block.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.to(nn.Module), RobertaExtendedAttnMask):
module_block.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.to(nn.Module), TransformerLayer):
module_block.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
                # `add_pooling_layer` defaults to False in RobertaForPreTraining
                # and RobertaForCausalLM, so the pooler may be absent.
elif isinstance(module_block.to(nn.Module), RobertaPooler):
module_block.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
elif isinstance(module_block.to(nn.Module), RobertaLMHead):
module_block.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
# Set the last layernorm stage id
model.roberta.final_layernorm.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
class RobertaForPreTraining(RobertaPreTrainedModel):
def __init__(self, cfg):
super().__init__()
cfg.add_pooling_layer = False
self.roberta = RobertaModel(cfg)
self.lm_head = RobertaLMHead(
cfg.vocab_size,
cfg.hidden_size,
init_method_normal(cfg.initializer_range),
cfg.layernorm_eps,
)
self.loss_fc = RobertaLoss()
def forward(
self,
input_ids,
attention_mask,
tokentype_ids=None,
lm_labels=None,
loss_mask=None,
):
"""
Args:
input_ids (flow.LongTensor): Indices of input sequence tokens in vocabulary.
attention_mask (flow.BoolTensor): Mask to avoid performing attention on
padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
tokentype_ids (flow.LongTensor, optional): Segment token indices to indicate first
and second portions of the inputs. Indices are selected in `[0, 1]`.
Defaults to None.
            lm_labels (flow.LongTensor, optional): Labels for computing the masked
language modeling loss. Indices should be in `[-1, 0, ..., config.vocab_size]`.
Defaults to None.
loss_mask (flow.BoolTensor, optional): Mask to avoid performing loss computing
on ignored tokens. Tokens with indices set to `-1` are ignored (masked), the
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Defaults to None.
"""
input_ids = input_ids.to_global(placement=dist.get_layer_placement(0))
attention_mask = attention_mask.to_global(placement=dist.get_layer_placement(0))
tokentype_ids = tokentype_ids.to_global(placement=dist.get_layer_placement(0))
outputs = self.roberta(input_ids, attention_mask, tokentype_ids=tokentype_ids)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output, self.roberta.word_embeddings_weight())
if lm_labels is not None:
return self.loss_fc(prediction_scores, lm_labels, loss_mask)
return {"prediction_scores": prediction_scores}
class RobertaForCausalLM(RobertaPreTrainedModel):
def __init__(self, cfg):
super().__init__()
cfg.add_pooling_layer = False
self.roberta = RobertaModel(cfg)
self.lm_head = RobertaLMHead(
cfg.vocab_size,
cfg.hidden_size,
init_method_normal(cfg.initializer_range),
cfg.layernorm_eps,
)
self.loss_fc = RobertaLoss()
def forward(
self,
input_ids,
attention_mask,
tokentype_ids=None,
position_ids=None,
labels=None,
loss_mask=None,
):
"""
Args:
input_ids (flow.LongTensor): Indices of input sequence tokens in vocabulary.
attention_mask (flow.BoolTensor): Mask to avoid performing attention on
padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
tokentype_ids (flow.LongTensor, optional): Segment token indices to indicate first
and second portions of the inputs. Indices are selected in `[0, 1]`.
Defaults to None.
position_ids (flow.LongTensor, optional): Indices of positions of each input sequence
tokens in the position embeddings. Defaults to None.
            labels (flow.LongTensor, optional): Labels for computing the left-to-right
                language modeling loss. Indices should be in `[-1, 0, ..., config.vocab_size]`.
Defaults to None.
loss_mask (flow.BoolTensor, optional): Mask to avoid performing loss computing
on ignored tokens. Tokens with indices set to `-1` are ignored (masked), the
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Defaults to None.
"""
outputs = self.roberta(input_ids, attention_mask, position_ids, tokentype_ids)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output, self.roberta.word_embeddings_weight())
if labels is not None:
# next-token prediction task, shift prediction_scores and labels by one.
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
shifted_prediction_scores = shifted_prediction_scores.to_global(
sbp=prediction_scores.sbp
)
shifted_labels = labels[:, 1:].contiguous()
shifted_labels = shifted_labels.to_global(sbp=shifted_labels.sbp)
lm_loss = self.loss_fc(shifted_prediction_scores, shifted_labels, loss_mask)
return {"lm_loss": lm_loss}
return {"prediction_scores": prediction_scores}
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
import oneflow.nn as nn
from flowvision.layers import trunc_normal_
from flowvision.models import to_2tuple
from libai.config.config import configurable
from libai.layers import MLP, DropPath, LayerNorm, Linear
from libai.utils import distributed as dist
def window_partition(x, window_size):
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
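# Illustrative sketch (hypothetical helper, for illustration only): window_partition and
# window_reverse are spatial inverses. For a (1, 8, 8, 3) feature map and window_size=4,
# partition yields (8/4)*(8/4) = 4 windows of shape (4, 4, 3).
def _window_roundtrip_example():
    x = flow.randn(1, 8, 8, 3)
    windows = window_partition(x, window_size=4)  # shape: (4, 4, 4, 3)
    restored = window_reverse(windows, 4, H=8, W=8)  # shape: (1, 8, 8, 3)
    return windows.shape, restored.shape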
class WindowAttention(nn.Module):
"""Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query,key,value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(
self,
dim,
window_size,
num_heads,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
fused_bias_add_dropout=False,
layer_idx=0,
):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
flow.zeros(
(2 * window_size[0] - 1) * (2 * window_size[1] - 1),
num_heads,
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
) # 2*Wh-1 * 2*Ww-1, nH
trunc_normal_(self.relative_position_bias_table, std=0.02)
# get pair-wise relative position index for each token inside the window
coords_h = flow.arange(self.window_size[0])
coords_w = flow.arange(self.window_size[1])
coords = flow.stack(flow.meshgrid(*[coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = flow.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] = (
relative_coords[:, :, 0] + self.window_size[0] - 1
) # shift to start from 0
relative_coords[:, :, 1] = relative_coords[:, :, 1] + self.window_size[1] - 1
relative_coords[:, :, 0] = relative_coords[:, :, 0] * (2 * self.window_size[1] - 1)
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer(
"relative_position_index",
relative_position_index.to_global(
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
),
)
self.qkv = Linear(dim, dim * 3, bias=qkv_bias, layer_idx=layer_idx)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = Linear(dim, dim, layer_idx=layer_idx)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
self.fused_bias_add_dropout = fused_bias_add_dropout
self.p = proj_drop
def forward(self, x, mask):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B_, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
# attn = flow.matmul(q, k.transpose(-2, -1))
attn = flow.matmul(q, k, transpose_b=True)
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1,
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
unsqueeze_relative_position_bias = relative_position_bias.unsqueeze(0)
attn = attn + unsqueeze_relative_position_bias
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = flow.matmul(attn, v).transpose(1, 2).reshape(B_, N, C)
if self.fused_bias_add_dropout:
x = flow._C.matmul(x, self.proj.weight, transpose_a=False, transpose_b=True)
x = flow._C.fused_bias_add_dropout(x, self.proj.bias, p=self.p, axis=2)
else:
x = self.proj(x)
x = self.proj_drop(x)
return x
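# Illustrative sketch (hypothetical helper, for illustration only): the relative position
# index for a 2x2 window, computed the same way as in WindowAttention.__init__. Each entry
# selects a row of the (2*Wh-1)*(2*Ww-1) relative position bias table.
def _relative_position_index_example(window_size=(2, 2)):
    coords_h = flow.arange(window_size[0])
    coords_w = flow.arange(window_size[1])
    coords_flatten = flow.flatten(flow.stack(flow.meshgrid(coords_h, coords_w)), 1)  # 2, Wh*Ww
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
    relative_coords[:, :, 0] = relative_coords[:, :, 0] + window_size[0] - 1  # shift to start from 0
    relative_coords[:, :, 1] = relative_coords[:, :, 1] + window_size[1] - 1
    relative_coords[:, :, 0] = relative_coords[:, :, 0] * (2 * window_size[1] - 1)
    return relative_coords.sum(-1)  # 4x4 index matrix with values in [0, 8]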
class SwinTransformerBlock(nn.Module):
"""Swin Transformer Block.
Args:
dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: libai.layers.LayerNorm
"""
def __init__(
self,
dim,
input_resolution,
num_heads,
window_size=7,
shift_size=0,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=LayerNorm,
layer_idx=0,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
self.layer_idx = layer_idx
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
self.norm1 = norm_layer(dim, layer_idx=layer_idx)
self.attn = WindowAttention(
dim,
window_size=to_2tuple(self.window_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
fused_bias_add_dropout=True,
layer_idx=layer_idx,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim, layer_idx=layer_idx)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = MLP(
hidden_size=dim,
ffn_hidden_size=mlp_hidden_dim,
output_dropout_prob=drop,
bias_gelu_fusion=True,
bias_dropout_fusion=True,
layer_idx=layer_idx,
)
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = flow.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
w_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt = cnt + 1
mask_windows = window_partition(
img_mask, self.window_size
) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
attn_mask == 0, float(0.0)
)
attn_mask = attn_mask.to_global(
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = flow.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(
shifted_x, self.window_size
) # nW*B, window_size, window_size, C
x_windows = x_windows.view(
-1, self.window_size * self.window_size, C
) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, self.attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = flow.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
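# Illustrative sketch (hypothetical helper, for illustration only): the cyclic shift used
# for SW-MSA is a plain roll of the (B, H, W, C) feature map; rolling back with the
# opposite shift restores the original layout exactly.
def _cyclic_shift_example(shift_size=2):
    x = flow.arange(16).reshape(1, 4, 4, 1)
    shifted = flow.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
    restored = flow.roll(shifted, shifts=(shift_size, shift_size), dims=(1, 2))
    return bool((restored == x).all())  # True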
class PatchMerging(nn.Module):
"""Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: libai.layers.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=LayerNorm, layer_idx=0):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = Linear(4 * dim, 2 * dim, bias=False, layer_idx=layer_idx)
self.norm = norm_layer(4 * dim, layer_idx=layer_idx)
self.layer_idx = layer_idx
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = flow.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
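# Illustrative sketch (hypothetical helper, for illustration only): the 2x2 neighborhood
# gather in PatchMerging.forward halves H and W and stacks 4*C channels before the Linear
# reduction, e.g. (1, 8, 8, 96) -> (1, 16, 384).
def _patch_merging_gather_example():
    B, H, W, C = 1, 8, 8, 96
    x = flow.randn(B, H, W, C)
    x0 = x[:, 0::2, 0::2, :]
    x1 = x[:, 1::2, 0::2, :]
    x2 = x[:, 0::2, 1::2, :]
    x3 = x[:, 1::2, 1::2, :]
    merged = flow.cat([x0, x1, x2, x3], -1).view(B, -1, 4 * C)
    return merged.shape  # (1, 16, 384)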
class PatchEmbed(nn.Module):
"""Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(
self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, layer_idx=0
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [
img_size[0] // patch_size[0],
img_size[1] // patch_size[1],
]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
).to_global(
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
if norm_layer is not None:
self.norm = norm_layer(embed_dim, layer_idx=layer_idx)
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert (
H == self.img_size[0] and W == self.img_size[1]
), f"Input image size ({H}*{W}) doesn't match model({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
class BasicLayer(nn.Module):
"""A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: libai.layers.LayerNorm
downsample (nn.Module | None, optional): Downsample at the end of the layer. Default: None
"""
def __init__(
self,
dim,
input_resolution,
depth,
num_heads,
window_size,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=LayerNorm,
downsample=None,
layer_id_offset=0,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.layer_id_offset = layer_id_offset
# build blocks
self.blocks = nn.ModuleList(
[
SwinTransformerBlock(
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
layer_idx=layer_id_offset + i,
)
for i in range(depth)
]
)
# patch merging layer
if downsample is not None:
self.downsample = downsample(
input_resolution,
dim=dim,
norm_layer=norm_layer,
layer_idx=layer_id_offset + depth - 1,
)
else:
self.downsample = None
def forward(self, x):
layer_idx = self.layer_id_offset
for i in range(len(self.blocks)):
x = x.to_global(placement=dist.get_layer_placement(layer_idx))
x = self.blocks[i](x)
layer_idx += 1
if self.downsample is not None:
x = self.downsample(x)
return x
class SwinTransformer(nn.Module):
"""Swin Transformer in LiBai.
    LiBai implementation of:
`Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
<https://arxiv.org/pdf/2103.14030>`_
Args:
img_size (int, tuple(int)): Input image size. Default 224
patch_size (int, tuple(int)): Patch size. Default: 4
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: libai.layers.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
loss_func (callable, optional): Loss function for computing the total loss
between logits and labels
"""
@configurable
def __init__(
self,
img_size=224,
patch_size=4,
in_chans=3,
num_classes=1000,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.1,
norm_layer=LayerNorm,
ape=False,
patch_norm=True,
loss_func=None,
**kwargs,
):
super().__init__()
self.num_classes = num_classes
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None,
layer_idx=0,
)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# absolute position embedding
if self.ape:
            self.absolute_pos_embed = nn.Parameter(
                flow.zeros(
                    1,
                    num_patches,
                    embed_dim,
                    placement=dist.get_layer_placement(0),
                    sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
                )
            )
trunc_normal_(self.absolute_pos_embed, std=0.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [
x.item() for x in flow.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
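        # e.g. drop_path_rate=0.1 with depths=[2, 2, 6, 2] gives 12 linearly spaced
        # rates from 0.0 to 0.1, so deeper blocks use a larger drop-path probability.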
# build layers
self.layers = nn.ModuleList()
layer_id_offset = 0
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2 ** i_layer),
input_resolution=(
patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer),
),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
layer_id_offset=layer_id_offset,
)
layer_id_offset += depths[i_layer]
self.layers.append(layer)
self.norm = norm_layer(self.num_features, layer_idx=-1)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = (
Linear(self.num_features, num_classes, layer_idx=-1)
if num_classes > 0
else nn.Identity()
)
# Loss func
self.loss_func = nn.CrossEntropyLoss() if loss_func is None else loss_func
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@classmethod
def from_config(cls, cfg):
return {
"img_size": cfg.img_size,
"patch_size": cfg.patch_size,
"in_chans": cfg.in_chans,
"num_classes": cfg.num_classes,
"embed_dim": cfg.embed_dim,
"depths": cfg.depths,
"num_heads": cfg.num_heads,
"window_size": cfg.window_size,
"mlp_ratio": cfg.mlp_ratio,
"qkv_bias": cfg.qkv_bias,
"qk_scale": cfg.qk_scale,
"drop_rate": cfg.drop_rate,
"drop_path_rate": cfg.drop_path_rate,
"ape": cfg.ape,
"patch_norm": cfg.patch_norm,
"loss_func": cfg.loss_func,
}
def forward_features(self, x):
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x)
x = self.norm(x) # B L C
x = self.avgpool(x.transpose(1, 2)) # B C 1
x = flow.flatten(x, 1)
return x
def forward(self, images, labels=None):
"""
Args:
images (flow.Tensor): training samples.
labels (flow.LongTensor, optional): training targets
Returns:
dict:
A dict containing :code:`loss_value` or :code:`logits`
depending on training or evaluation mode.
:code:`{"losses": loss_value}` when training,
:code:`{"prediction_scores": logits}` when evaluating.
"""
x = self.forward_features(images)
x = self.head(x)
if labels is not None and self.training:
losses = self.loss_func(x, labels)
return {"losses": losses}
else:
return {"prediction_scores": x}
@staticmethod
def set_pipeline_stage_id(model):
dist_utils = dist.get_dist_util()
# Set pipeline parallelism stage_id
if hasattr(model.patch_embed, "config"):
# Old API in OneFlow 0.8
model.patch_embed.config.set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
model.pos_drop.config.set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
for module_block in model.modules():
if isinstance(module_block.origin, SwinTransformerBlock):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
elif isinstance(module_block.origin, PatchMerging):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
model.norm.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.head.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.avgpool.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.loss_func.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
else:
model.patch_embed.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
model.pos_drop.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
for module_block in model.modules():
if isinstance(module_block.to(nn.Module), SwinTransformerBlock):
module_block.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
elif isinstance(module_block.to(nn.Module), PatchMerging):
module_block.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
model.norm.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.head.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.avgpool.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.loss_func.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
@staticmethod
def set_activation_checkpoint(model):
for module_block in model.modules():
if hasattr(module_block, "origin"):
# Old API in OneFlow 0.8
if isinstance(module_block.origin, SwinTransformerBlock):
module_block.config.activation_checkpointing = True
else:
if isinstance(module_block.to(nn.Module), SwinTransformerBlock):
module_block.to(flow.nn.graph.GraphModule).activation_checkpointing = True
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import oneflow as flow
import oneflow.nn as nn
import oneflow.nn.functional as F
from flowvision.layers import trunc_normal_
from flowvision.models import to_2tuple
from libai.config.config import configurable
from libai.layers import MLP, DropPath, LayerNorm, Linear
from libai.utils import distributed as dist
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
r"""Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value.
Default: True
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
"""
def __init__(
self,
dim,
window_size,
num_heads,
qkv_bias=True,
attn_drop=0.0,
proj_drop=0.0,
pretrained_window_size=[0, 0],
fused_bias_add_dropout=False,
layer_idx=0,
):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.pretrained_window_size = pretrained_window_size
self.fused_bias_add_dropout = fused_bias_add_dropout
self.num_heads = num_heads
self.layer_idx = layer_idx
self.p = proj_drop
self.logit_scale = nn.Parameter(
flow.log(
10
* flow.ones(
1,
num_heads,
1,
1,
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
),
requires_grad=True,
)
        # NOTE: meta network: a small MLP that generates the continuous relative position bias
self.cpb_mlp = nn.Sequential(
Linear(2, 512, bias=True, layer_idx=layer_idx),
nn.ReLU(inplace=True),
Linear(512, num_heads, bias=False, layer_idx=layer_idx),
).to_global(
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
# NOTE: get relative_coords_table
relative_coords_h = flow.arange(
-(self.window_size[0] - 1), self.window_size[0], dtype=flow.float32
)
relative_coords_w = flow.arange(
-(self.window_size[1] - 1), self.window_size[1], dtype=flow.float32
)
relative_coords_table = (
flow.stack(flow.meshgrid(*[relative_coords_h, relative_coords_w]))
.permute(1, 2, 0)
.contiguous()
.unsqueeze(0)
) # 1, 2*Wh-1, 2*Ww-1, 2
# NOTE: For any relative coordinate, constrain it to -8~8 (window size)
if pretrained_window_size[0] > 0:
relative_coords_table[:, :, :, 0] = relative_coords_table[:, :, :, 0] / (
pretrained_window_size[0] - 1
)
relative_coords_table[:, :, :, 1] = relative_coords_table[:, :, :, 1] / (
pretrained_window_size[1] - 1
)
else:
relative_coords_table[:, :, :, 0] = relative_coords_table[:, :, :, 0] / (
self.window_size[0] - 1
)
relative_coords_table[:, :, :, 1] = relative_coords_table[:, :, :, 1] / (
self.window_size[1] - 1
)
relative_coords_table = relative_coords_table * 8
        # NOTE: y = sign(x) * log2(|x| + 1) / log2(8)
relative_coords_table = (
flow.sign(relative_coords_table)
* flow.log2(flow.abs(relative_coords_table) + 1.0)
/ math.log2(8.0)
)
self.register_buffer(
"relative_coords_table",
relative_coords_table.to_global(
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
),
)
# NOTE: get pair-wise relative position index for each token inside the window
coords_h = flow.arange(self.window_size[0])
coords_w = flow.arange(self.window_size[1])
coords = flow.stack(flow.meshgrid(*[coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = flow.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] = (
relative_coords[:, :, 0] + self.window_size[0] - 1
) # shift to start from 0
relative_coords[:, :, 1] = relative_coords[:, :, 1] + self.window_size[1] - 1
relative_coords[:, :, 0] = relative_coords[:, :, 0] * (2 * self.window_size[1] - 1)
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer(
"relative_position_index",
relative_position_index.to_global(
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
),
)
self.qkv = Linear(dim, dim * 3, bias=False, layer_idx=layer_idx)
if qkv_bias:
self.q_bias = nn.Parameter(
flow.zeros(
dim,
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
)
self.v_bias = nn.Parameter(
flow.zeros(
dim,
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
)
else:
self.q_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = Linear(dim, dim, layer_idx=layer_idx)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
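            # k gets no learnable bias in Swin V2: a zero tensor (requires_grad=False)
            # is concatenated between q_bias and v_bias to build the full qkv bias.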
qkv_bias = flow.concat(
[
self.q_bias,
flow.zeros(
self.v_bias.shape,
requires_grad=False,
placement=dist.get_layer_placement(
self.layer_idx, device_type=self.v_bias.placement.type
),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
),
self.v_bias,
],
dim=0,
)
qkv = self.qkv(x) + qkv_bias.unsqueeze(0).unsqueeze(0)
qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
# NOTE: cosine attention
attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
# NOTE: a learnable scalar
logit_scale = flow.clamp(self.logit_scale, min=-1e6, max=math.log(1.0 / 0.01)).exp()
attn = attn * logit_scale
# NOTE: use relative_coords_table and meta network to generate relative_position_bias
relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(
-1, self.num_heads
)
relative_position_bias = relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
# NOTE: constrained to a range of -16~16
relative_position_bias = 16 * flow.sigmoid(relative_position_bias).unsqueeze(0)
attn = attn + relative_position_bias
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
if self.fused_bias_add_dropout:
x = flow._C.matmul(x, self.proj.weight, transpose_a=False, transpose_b=True)
x = flow._C.fused_bias_add_dropout(x, self.proj.bias, p=self.p, axis=2)
else:
x = self.proj(x)
x = self.proj_drop(x)
return x
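# Illustrative sketch (hypothetical helper, for illustration only): the log-spaced mapping
# applied to relative_coords_table above. Offsets are first scaled to [-8, 8] and then
# compressed with y = sign(x) * log2(|x| + 1) / log2(8), so +-1 maps to about +-0.33 and
# +-8 to about +-1.06.
def _log_spaced_coords_example():
    coords = flow.tensor([-8.0, -1.0, 0.0, 1.0, 8.0])
    return flow.sign(coords) * flow.log2(flow.abs(coords) + 1.0) / math.log2(8.0)
    # -> approximately [-1.057, -0.333, 0.000, 0.333, 1.057]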
class SwinTransformerBlock(nn.Module):
r"""Swin Transformer Block.
Args:
dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
pretrained_window_size (int): Window size in pre-training.
"""
def __init__(
self,
dim,
input_resolution,
num_heads,
window_size=7,
shift_size=0,
mlp_ratio=4.0,
qkv_bias=True,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=LayerNorm,
pretrained_window_size=0,
layer_idx=0,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
self.layer_idx = layer_idx
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
self.norm1 = norm_layer(dim, layer_idx=layer_idx)
self.attn = WindowAttention(
dim,
window_size=to_2tuple(self.window_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
pretrained_window_size=to_2tuple(pretrained_window_size),
fused_bias_add_dropout=True,
layer_idx=layer_idx,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim, layer_idx=layer_idx)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = MLP(
hidden_size=dim,
ffn_hidden_size=mlp_hidden_dim,
output_dropout_prob=drop,
bias_gelu_fusion=True,
bias_dropout_fusion=True,
layer_idx=layer_idx,
)
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = flow.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
w_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt = cnt + 1
mask_windows = window_partition(
img_mask, self.window_size
) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = (
attn_mask.masked_fill(attn_mask != 0, float(-100.0))
.masked_fill(attn_mask == 0, float(0.0))
.to_global(
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
)
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = x
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = flow.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(
shifted_x, self.window_size
) # nW*B, window_size, window_size, C
x_windows = x_windows.view(
-1, self.window_size * self.window_size, C
) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = flow.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# NOTE: res-post-norm
x = shortcut + self.drop_path(self.norm1(x))
# NOTE: res-post-norm
x = x + self.drop_path(self.norm2(self.mlp(x)))
return x
class PatchMerging(nn.Module):
"""Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: libai.layers.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=LayerNorm, layer_idx=0):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = Linear(4 * dim, 2 * dim, bias=False, layer_idx=layer_idx)
        # NOTE: Swin V2 normalizes 2 * dim (after reduction); Swin V1 normalizes 4 * dim (before reduction)
self.norm = norm_layer(2 * dim, layer_idx=layer_idx)
self.layer_idx = layer_idx
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = flow.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
        # NOTE: post-norm (reduction first, then norm), a change in Swin V2 compared to Swin V1
x = self.reduction(x)
x = self.norm(x)
return x
class BasicLayer(nn.Module):
"""A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer.
Default: None
pretrained_window_size (int): Local window size in pre-training.
"""
def __init__(
self,
dim,
input_resolution,
depth,
num_heads,
window_size,
mlp_ratio=4.0,
qkv_bias=True,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=LayerNorm,
downsample=None,
pretrained_window_size=0,
layer_id_offset=0,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.layer_id_offset = layer_id_offset
# build blocks
self.blocks = nn.ModuleList(
[
SwinTransformerBlock(
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
pretrained_window_size=pretrained_window_size,
layer_idx=layer_id_offset + i,
)
for i in range(depth)
]
)
# patch merging layer
if downsample is not None:
self.downsample = downsample(
input_resolution,
dim=dim,
norm_layer=norm_layer,
layer_idx=layer_id_offset + depth - 1,
)
else:
self.downsample = None
def forward(self, x):
layer_idx = self.layer_id_offset
for blk in self.blocks:
x = x.to_global(
placement=dist.get_layer_placement(layer_idx, device_type=x.placement.type)
)
x = blk(x)
layer_idx += 1
if self.downsample is not None:
x = self.downsample(x)
return x
def _init_respostnorm(self):
for blk in self.blocks:
nn.init.constant_(blk.norm1.bias, 0)
nn.init.constant_(blk.norm1.weight, 0)
nn.init.constant_(blk.norm2.bias, 0)
nn.init.constant_(blk.norm2.weight, 0)
class PatchEmbed(nn.Module):
r"""Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(
self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, layer_idx=0
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
).to_global(
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
if norm_layer is not None:
            self.norm = norm_layer(embed_dim, layer_idx=layer_idx)
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert (
H == self.img_size[0] and W == self.img_size[1]
), f"Input image size ({H}*{W}) doesn't match model \
({self.img_size[0]}*{self.img_size[1]})."
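        # With the defaults (img_size=224, patch_size=4, embed_dim=96) the projection
        # maps (B, 3, 224, 224) -> (B, 96, 56, 56); flatten + transpose then yields
        # (B, 56 * 56, 96) = (B, 3136, 96) patch tokens.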
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
class SwinTransformerV2(nn.Module):
r"""Swin Transformer
Args:
img_size (int | tuple(int)): Input image size. Default 224
patch_size (int | tuple(int)): Patch size. Default: 4
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer.
"""
@configurable
def __init__(
self,
img_size=224,
patch_size=4,
in_chans=3,
num_classes=1000,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4.0,
qkv_bias=True,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.1,
norm_layer=LayerNorm,
ape=False,
patch_norm=True,
pretrained_window_sizes=[0, 0, 0, 0],
loss_func=None,
):
super().__init__()
self.num_classes = num_classes
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None,
layer_idx=0,
)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(
flow.zeros(
1,
num_patches,
embed_dim,
placement=dist.get_layer_placement(0),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
)
trunc_normal_(self.absolute_pos_embed, std=0.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [
x.item() for x in flow.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
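        # e.g. with depths=[2, 2, 6, 2] and drop_path_rate=0.1, dpr holds 12 values
        # increasing linearly from 0 to 0.1, and each BasicLayer below receives its
        # own contiguous slice dpr[sum(depths[:i]) : sum(depths[:i + 1])].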
# build layers
self.layers = nn.ModuleList()
layer_id_offset = 0
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2 ** i_layer),
input_resolution=(
patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer),
),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
pretrained_window_size=pretrained_window_sizes[i_layer],
layer_id_offset=layer_id_offset,
)
layer_id_offset += depths[i_layer]
self.layers.append(layer)
self.norm = norm_layer(self.num_features, layer_idx=-1)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = (
Linear(self.num_features, num_classes, layer_idx=-1)
if num_classes > 0
else nn.Identity()
)
self.loss_func = nn.CrossEntropyLoss() if loss_func is None else loss_func
self.apply(self._init_weights)
for bly in self.layers:
bly._init_respostnorm()
def _init_weights(self, m):
if isinstance(m, Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@classmethod
def from_config(cls, cfg):
return {
"img_size": cfg.img_size,
"patch_size": cfg.patch_size,
"in_chans": cfg.in_chans,
"num_classes": cfg.num_classes,
"embed_dim": cfg.embed_dim,
"depths": cfg.depths,
"num_heads": cfg.num_heads,
"window_size": cfg.window_size,
"mlp_ratio": cfg.mlp_ratio,
"qkv_bias": cfg.qkv_bias,
"drop_rate": cfg.drop_rate,
"drop_path_rate": cfg.drop_path_rate,
"ape": cfg.ape,
"patch_norm": cfg.patch_norm,
"pretrained_window_sizes": cfg.pretrained_window_sizes,
"loss_func": cfg.loss_func,
}
def forward_features(self, x):
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x)
x = self.norm(x) # B L C
x = self.avgpool(x.transpose(1, 2)) # B C 1
x = flow.flatten(x, 1)
return x
def forward(self, images, labels=None):
"""
Args:
images (flow.Tensor): training samples.
labels (flow.LongTensor, optional): training targets
Returns:
dict:
A dict containing :code:`loss_value` or :code:`logits`
depending on training or evaluation mode.
:code:`{"losses": loss_value}` when training,
:code:`{"prediction_scores": logits}` when evaluating.
"""
x = self.forward_features(images)
x = self.head(x)
if labels is not None and self.training:
losses = self.loss_func(x, labels)
return {"losses": losses}
else:
return {"prediction_scores": x}
@staticmethod
def set_pipeline_stage_id(model):
dist_utils = dist.get_dist_util()
# Set pipeline parallelism stage_id
if hasattr(model.patch_embed, "config"):
# Old API in OneFlow 0.8
model.patch_embed.config.set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
model.pos_drop.config.set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
for module_block in model.modules():
if isinstance(module_block.origin, SwinTransformerBlock):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
elif isinstance(module_block.origin, PatchMerging):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
model.norm.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.head.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.avgpool.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.loss_func.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
else:
model.patch_embed.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
model.pos_drop.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
for module_block in model.modules():
if isinstance(module_block.to(nn.Module), SwinTransformerBlock):
module_block.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
elif isinstance(module_block.to(nn.Module), PatchMerging):
module_block.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
model.norm.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.head.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.avgpool.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.loss_func.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
@staticmethod
def set_activation_checkpoint(model):
for module_block in model.modules():
if hasattr(module_block, "origin"):
# Old API in OneFlow 0.8
if isinstance(module_block.origin, SwinTransformerBlock):
module_block.config.activation_checkpointing = True
else:
if isinstance(module_block.to(nn.Module), SwinTransformerBlock):
module_block.to(flow.nn.graph.GraphModule).activation_checkpointing = True
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
import oneflow.nn as nn
from libai.config import configurable
from libai.layers import (
Embedding,
LayerNorm,
LMLogits,
ParallelCrossEntropyLoss,
TransformerLayer,
VocabEmbedding,
)
from libai.layers.attention import AttnMaskType
from libai.models.utils import init_method_normal, scaled_init_method_normal
from libai.utils import distributed as dist
class ExtendedMask(flow.nn.Module):
def forward(self, attention_mask):
return attention_mask.unsqueeze(1)
class T5Embedding(flow.nn.Module):
def __init__(
self,
hidden_size,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
init_method=flow.nn.init.xavier_normal_,
amp_enabled=False,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.vocab_size = vocab_size
self.word_embeddings = VocabEmbedding(
num_embeddings=vocab_size,
embedding_dim=hidden_size,
init_method=init_method,
amp_enabled=amp_enabled,
)
self.position_embeddings = Embedding(
num_embeddings=max_sequence_length,
embedding_dim=hidden_size,
init_method=init_method,
amp_enabled=amp_enabled,
)
self.position_ids = flow.arange(
max_sequence_length,
dtype=flow.long,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=dist.get_layer_placement(0),
).unsqueeze(0)
self.embedding_dropout = flow.nn.Dropout(embedding_dropout_prob)
def forward(self, input_ids, past_length=0):
seq_length = input_ids.size()[1]
position_ids = self.position_ids[:, past_length : past_length + seq_length]
position_ids = position_ids.expand_as(input_ids).to_global(sbp=input_ids.sbp)
word_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = word_embeddings + position_embeddings
embeddings = self.embedding_dropout(embeddings)
return embeddings
class T5Model(flow.nn.Module):
"""T5 Model that outputs logits.
Args:
vocab_size (int): The size of vocabulary file.
hidden_size (int): The size of hidden states.
hidden_layers (int): The number of ``TransformerLayer`` in the encoder and decoder.
num_attention_heads (int):
The number of attention heads for each attention layer of ``TransformerLayer``.
intermediate_size (int):
The size of intermediate layer in feed-forward network for each ``TransformerLayer``.
embedding_dropout_prob (float): The dropout ratio for the output of T5Embedding Layer.
        hidden_dropout_prob (float): The dropout ratio for the output of each ``TransformerLayer``.
attention_probs_dropout_prob (float):
The dropout ratio for the output of each attention layer in ``TransformerLayer``.
max_position_embeddings (int):
            Max sequence length of input, which defines the shape of the position embeddings
            in ``T5Embedding``.
initializer_range (float, optional):
Sigma of the normal distribution in the initialization method. Defaults to 0.02.
layernorm_eps (float, optional): The epsilon of LayerNorm layer. Defaults to 1e-12.
bias_gelu_fusion (bool, optional):
Whether or not to fuse the computing of bias and gelu. Defaults to ``False``.
bias_dropout_fusion (bool, optional):
Whether or not to fuse the computing of dropout and bias. Defaults to ``False``.
scale_mask_softmax_fusion (bool, optional):
Whether to fuse the computing of mask and softmax in attention layers.
Defaults to ``False``.
apply_query_key_layer_scaling (bool, optional):
            Whether or not to use layer-index-related scaling when computing attention scores.
            If ``True``, the scaling factor equals sqrt(d) * (layer_index + 1).
Defaults to ``True``.
apply_residual_post_layernorm (bool, optional):
            If set to ``True``, use the original BERT residual connection ordering; otherwise
            use the Megatron-BERT residual connection, which is more stable when scaling up
            model size and was introduced in https://arxiv.org/pdf/1909.08053.pdf.
Default: ``False``.
amp_enabled (bool, optional):
Whether or not to set fp16 for embedding weight in T5 model. Defaults to ``False``.
"""
@configurable
def __init__(
self,
vocab_size,
hidden_size,
hidden_layers,
num_attention_heads,
intermediate_size,
embedding_dropout_prob,
hidden_dropout_prob,
attention_probs_dropout_prob,
max_position_embeddings,
initializer_range=0.02,
layernorm_eps=1e-12,
bias_gelu_fusion=False,
bias_dropout_fusion=False,
scale_mask_softmax_fusion=False,
apply_query_key_layer_scaling=True,
apply_residual_post_layernorm=False,
amp_enabled=False,
) -> None:
super().__init__()
init_method = init_method_normal(initializer_range)
scaled_init_method = scaled_init_method_normal(initializer_range, hidden_layers)
self.embedding = T5Embedding(
hidden_size=hidden_size,
vocab_size=vocab_size,
max_sequence_length=max_position_embeddings,
embedding_dropout_prob=embedding_dropout_prob,
init_method=init_method,
amp_enabled=amp_enabled,
)
self.extended_attn_mask = ExtendedMask()
encoder_layers = flow.nn.ModuleList(
[
TransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=intermediate_size,
num_attention_heads=num_attention_heads,
is_decoder=False,
attention_dropout_prob=attention_probs_dropout_prob,
output_dropout_prob=hidden_dropout_prob,
layernorm_epsilon=layernorm_eps,
init_method=init_method,
output_layer_init_method=scaled_init_method,
bias_gelu_fusion=bias_gelu_fusion,
bias_dropout_fusion=bias_dropout_fusion,
scale_mask_softmax_fusion=scale_mask_softmax_fusion,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
apply_residual_post_layernorm=apply_residual_post_layernorm,
attn_mask_type=AttnMaskType.padding,
layer_idx=i,
)
for i in range(hidden_layers)
]
)
encoder_final_layernorm = LayerNorm(
(hidden_size,),
eps=layernorm_eps,
layer_idx=hidden_layers - 1,
)
self.encoder = flow.nn.Sequential()
self.encoder.add_module("layers", encoder_layers)
self.encoder.add_module("final_layernorm", encoder_final_layernorm)
decoder_layers = flow.nn.ModuleList(
[
TransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=intermediate_size,
num_attention_heads=num_attention_heads,
is_decoder=True,
attention_dropout_prob=attention_probs_dropout_prob,
output_dropout_prob=hidden_dropout_prob,
layernorm_epsilon=layernorm_eps,
init_method=init_method,
output_layer_init_method=scaled_init_method,
bias_gelu_fusion=bias_gelu_fusion,
bias_dropout_fusion=bias_dropout_fusion,
scale_mask_softmax_fusion=scale_mask_softmax_fusion,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
attn_mask_type=AttnMaskType.padding,
layer_idx=i,
)
for i in range(hidden_layers, 2 * hidden_layers)
]
)
decoder_final_layernorm = LayerNorm(
(hidden_size,),
eps=layernorm_eps,
layer_idx=2 * hidden_layers - 1,
)
self.decoder = flow.nn.Sequential()
self.decoder.add_module("layers", decoder_layers)
self.decoder.add_module("final_layernorm", decoder_final_layernorm)
self.past_key_values = [None] * len(self.decoder.layers)
self.encoder_states = None
self.past_length = 0
self.lm_head = LMLogits(vocab_size, bias=True)
@classmethod
def from_config(cls, cfg):
return {
"vocab_size": cfg.vocab_size,
"hidden_size": cfg.hidden_size,
"hidden_layers": cfg.hidden_layers,
"num_attention_heads": cfg.num_attention_heads,
"intermediate_size": cfg.intermediate_size,
"embedding_dropout_prob": cfg.embedding_dropout_prob,
"hidden_dropout_prob": cfg.hidden_dropout_prob,
"attention_probs_dropout_prob": cfg.attention_probs_dropout_prob,
"max_position_embeddings": cfg.max_position_embeddings,
"initializer_range": cfg.initializer_range,
"layernorm_eps": cfg.layernorm_eps,
"bias_gelu_fusion": cfg.bias_gelu_fusion,
"bias_dropout_fusion": cfg.bias_dropout_fusion,
"scale_mask_softmax_fusion": cfg.scale_mask_softmax_fusion,
"apply_query_key_layer_scaling": cfg.apply_query_key_layer_scaling,
"apply_residual_post_layernorm": cfg.apply_residual_post_layernorm,
"amp_enabled": cfg.amp_enabled,
}
def forward(
self,
encoder_input_ids,
decoder_input_ids,
encoder_attn_mask,
decoder_attn_mask,
encoder_decoder_attn_mask,
use_cache=False,
):
"""
Args:
encoder_input_ids (flow.LongTensor):
Indices of input sequence tokens in vocabulary for encoder.
decoder_input_ids (flow.LongTensor):
Indices of input sequence tokens in vocabulary for decoder.
encoder_attn_mask (flow.BoolTensor):
Mask for encoder to avoid performing attention on
padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
decoder_attn_mask (flow.BoolTensor):
Mask for decoder to avoid performing attention on subsequent token indices.
Mask values have the same meaning as encoder_attn_mask.
encoder_decoder_attn_mask (flow.BoolTensor):
Mask for decoder to avoid performing attention on encoder padded token indices.
Mask values have the same meaning as encoder_attn_mask.
use_cache (bool, optional):
                Set it to ``True`` when the model is in the inference phase and
                used for incremental decoding. Defaults to ``False``.
Returns:
flow.Tensor: logits
"""
encoder_input_ids = encoder_input_ids.to_global(placement=dist.get_layer_placement(0))
decoder_input_ids = decoder_input_ids.to_global(placement=dist.get_layer_placement(0))
encoder_attn_mask = encoder_attn_mask.to_global(placement=dist.get_layer_placement(0))
decoder_attn_mask = decoder_attn_mask.to_global(placement=dist.get_layer_placement(0))
encoder_decoder_attn_mask = encoder_decoder_attn_mask.to_global(
placement=dist.get_layer_placement(0)
)
if use_cache and self.encoder_states is not None:
encoder_states = self.encoder_states
else:
self.set_cache(encoder_states=None, past_key_values=None)
encoder_attn_mask = self.extended_attn_mask(encoder_attn_mask)
enc_embedding_output = self.embedding(encoder_input_ids)
enc_hidden_states = enc_embedding_output
for layer in self.encoder.layers:
enc_hidden_states = layer(enc_hidden_states, encoder_attn_mask)
encoder_states = self.encoder.final_layernorm(enc_hidden_states)
decoder_attn_mask = self.extended_attn_mask(decoder_attn_mask)
encoder_decoder_attn_mask = self.extended_attn_mask(encoder_decoder_attn_mask)
dec_embedding_output = self.embedding(decoder_input_ids, self.past_length)
dec_hidden_states = dec_embedding_output
if use_cache:
presents = []
for layer, past_key_value in zip(self.decoder.layers, self.past_key_values):
dec_hidden_states = layer(
dec_hidden_states,
decoder_attn_mask,
encoder_states,
encoder_decoder_attn_mask,
past_key_value=past_key_value,
use_cache=use_cache,
)
if use_cache:
dec_hidden_states, present = dec_hidden_states
presents.append(present)
if use_cache:
self.set_cache(encoder_states, past_key_values=presents)
decoder_states = self.decoder.final_layernorm(dec_hidden_states)
logits = self.lm_head(decoder_states, self.embedding.word_embeddings.weight)
return logits
def set_cache(self, encoder_states, past_key_values):
self.encoder_states = encoder_states
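        # The cache stores one (key, value) pair per decoder layer; assuming the cached
        # key has shape (batch, num_heads, past_seq_len, head_size), dim 2 gives the
        # number of tokens that have already been decoded.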
self.past_length = 0 if past_key_values is None else past_key_values[0][0].shape[2]
if past_key_values is None:
past_key_values = [None] * len(self.decoder.layers)
assert len(past_key_values) == len(self.decoder.layers), (
f"past_key_values's length {len(past_key_values)} doesn't match "
f"decoder num_layers' length {self.decoder.layers}"
)
self.past_key_values = past_key_values
class T5Loss(flow.nn.Module):
def __init__(self) -> None:
super().__init__()
self.lm_loss = ParallelCrossEntropyLoss()
def forward(self, logits, lm_labels, loss_mask):
lm_loss = self.lm_loss(logits, lm_labels)
loss_mask = loss_mask.to_global(placement=lm_loss.placement)
loss_mask = loss_mask.float()
denominator = loss_mask.sum().to_global(
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])
)
lm_loss = flow._C.amp_white_identity(lm_loss)
lm_loss = flow._C.amp_black_identity(lm_loss)
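        # Masked mean: sum the per-token loss over the positions where loss_mask == 1
        # and divide by the number of such positions.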
masked_lm_loss = flow.sum(lm_loss.view(-1) * loss_mask.view(-1)) / denominator
masked_lm_loss = masked_lm_loss.to_global(
sbp=dist.get_nd_sbp([flow.sbp.partial_sum, flow.sbp.broadcast])
)
return {"masked_lm_loss": masked_lm_loss}
class T5ForPreTraining(flow.nn.Module):
"""
T5 Model with classification head on top.
"""
def __init__(self, cfg) -> None:
super().__init__()
self.t5_model = T5Model(cfg)
self.loss_func = T5Loss()
def set_cache(self, encoder_states, past_key_values):
self.t5_model.set_cache(encoder_states, past_key_values)
def forward(
self,
encoder_input_ids,
decoder_input_ids,
encoder_attn_mask,
decoder_attn_mask,
encoder_decoder_attn_mask,
lm_labels=None,
loss_mask=None,
use_cache=False,
):
"""
Args:
encoder_input_ids (flow.LongTensor):
Indices of input sequence tokens in vocabulary for encoder.
decoder_input_ids (flow.LongTensor):
Indices of input sequence tokens in vocabulary for decoder.
encoder_attn_mask (flow.BoolTensor):
Mask for encoder to avoid performing attention on
padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
decoder_attn_mask (flow.BoolTensor):
Mask for decoder to avoid performing attention on subsequent token indices.
Mask values have the same meaning as encoder_attn_mask.
encoder_decoder_attn_mask (flow.BoolTensor):
Mask for decoder to avoid performing attention on encoder padded token indices.
Mask values have the same meaning as encoder_attn_mask.
lm_labels (flow.LongTensor, optional): Labels for computing the masked
language modeling loss. Indices should be in `[-1, 0, ..., config.vocab_size]`.
None for evaluating.
loss_mask (flow.BoolTensor, optional):
Mask to avoid performing loss computing on ignored tokens.
Tokens with indices set to `-1` are ignored (masked), the loss is only computed
for the tokens with labels in `[0, ..., config.vocab_size]`.
None for evaluating.
use_cache (bool, optional):
                Set it to ``True`` when the model is in the inference phase and
                used for incremental decoding. Defaults to ``False``.
Returns:
dict:
A dict containing :code:`loss_value` or :code:`logits`
depending on training or evaluation mode.
:code:`{"masked_lm_loss": loss_value}` when training,
:code:`{"prediction_scores": logits}` when evaluating.
"""
logits = self.t5_model(
encoder_input_ids,
decoder_input_ids,
encoder_attn_mask,
decoder_attn_mask,
encoder_decoder_attn_mask,
use_cache=use_cache,
)
if lm_labels is not None:
lm_loss = self.loss_func(logits, lm_labels, loss_mask)
return lm_loss
else:
return {
"prediction_scores": logits,
}
@staticmethod
def set_pipeline_stage_id(model):
dist_utils = dist.get_dist_util()
# Set pipeline parallelism stage_id
if hasattr(model.t5_model.encoder.final_layernorm, "config"):
# Old API in OneFlow 0.8
for module_block in model.modules():
if isinstance(module_block.origin, T5Embedding):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.origin, ExtendedMask):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.origin, TransformerLayer):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
elif isinstance(module_block.origin, LMLogits):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
elif isinstance(module_block.origin, T5Loss):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.t5_model.encoder.final_layernorm.config.set_stage(
dist_utils.get_layer_stage_id(model.t5_model.encoder.final_layernorm.layer_idx),
dist.get_layer_placement(model.t5_model.encoder.final_layernorm.layer_idx),
)
model.t5_model.decoder.final_layernorm.config.set_stage(
dist_utils.get_layer_stage_id(model.t5_model.decoder.final_layernorm.layer_idx),
dist.get_layer_placement(model.t5_model.decoder.final_layernorm.layer_idx),
)
else:
for module_block in model.modules():
if isinstance(module_block.to(nn.Module), T5Embedding):
module_block.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.to(nn.Module), ExtendedMask):
module_block.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.to(nn.Module), TransformerLayer):
module_block.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
elif isinstance(module_block.to(nn.Module), LMLogits):
module_block.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
elif isinstance(module_block.to(nn.Module), T5Loss):
module_block.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.t5_model.encoder.final_layernorm.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(model.t5_model.encoder.final_layernorm.layer_idx),
dist.get_layer_placement(model.t5_model.encoder.final_layernorm.layer_idx),
)
model.t5_model.decoder.final_layernorm.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(model.t5_model.decoder.final_layernorm.layer_idx),
dist.get_layer_placement(model.t5_model.decoder.final_layernorm.layer_idx),
)
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .graph_base import GraphBase
from .weight_init import init_method_normal, scaled_init_method_normal
from .model_utils.base_loader import ModelLoaderHuggerFace, ModelLoaderLiBai
from .model_utils.bert_loader import BertLoaderHuggerFace, BertLoaderLiBai
from .model_utils.roberta_loader import RobertaLoaderHuggerFace, RobertaLoaderLiBai
from .model_utils.gpt_loader import GPT2LoaderHuggerFace, GPT2LoaderLiBai
from .model_utils.swin_loader import SwinLoaderHuggerFace, SwinLoaderLiBai
from .model_utils.swinv2_loader import SwinV2LoaderHuggerFace, SwinV2LoaderLiBai
from .model_utils.vit_loader import ViTLoaderHuggerFace, ViTLoaderLiBai
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import oneflow as flow
from oneflow import nn
from libai.layers import TransformerLayer
from libai.utils import distributed as dist
logger = logging.getLogger(__name__)
class GraphBase(nn.Graph):
def __init__(
self,
model: nn.Module,
optimizer: flow.optim.Optimizer = None,
lr_scheduler: flow.optim.lr_scheduler = None,
fp16=False,
activation_checkpoint=False,
grad_acc_steps=1,
zero_optim=False,
zero_stage=0,
is_train=True,
auto_parallel_conf=None,
):
super().__init__()
self.model = model
self.is_train = is_train
if is_train:
self.add_optimizer(optimizer, lr_sch=lr_scheduler)
if fp16:
self.config.enable_amp(True)
grad_scaler = flow.amp.GradScaler(
init_scale=65536.0 * dist.get_data_parallel_size(),
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
)
self.set_grad_scaler(grad_scaler)
if grad_acc_steps > 1:
self.config.set_gradient_accumulation_steps(grad_acc_steps)
if activation_checkpoint:
self.set_activation_checkpoint()
if zero_optim:
self.config.enable_zero(True, stage=zero_stage)
self.set_pipeline_stage_id()
self.config.allow_fuse_add_to_output(True)
self.config.allow_fuse_model_update_ops(True)
self.config.allow_fuse_cast_scale(True)
# Enable cuda stream for computation and communication as the same stream.
# This will reduce memory when using model parallelism.
dist_util = dist.get_dist_util()
if dist_util.is_tensor_model_parallel() or dist_util.is_pipeline_model_parallel():
flow.boxing.nccl.enable_use_compute_stream(True)
# auto_parallel
if auto_parallel_conf is not None and auto_parallel_conf.enabled:
try:
self.config.enable_auto_parallel(True)
self.config.enable_auto_parallel_ignore_user_sbp_config(
auto_parallel_conf.enable_auto_parallel_ignore_user_sbp_config
)
self.config.set_auto_parallel_computation_cost_ratio(0.05)
self.config.set_auto_parallel_wait_time(1.65e4)
self.config.enable_auto_parallel_trunk_algo(auto_parallel_conf.trunk_algo)
self.config.enable_auto_parallel_sbp_collector(auto_parallel_conf.sbp_collector)
except RuntimeWarning:
import warnings
warnings.warn(
"The version of oneflow don't support auto_parallel.\n"
"Please reinstall the oneflow nightly:\n"
"python3 -m pip install --pre oneflow -f https://staging.oneflow.info/branch/master/[PLATFORM]" # noqa
)
def build(self, **kwargs):
if self.is_train:
logger.info(
"Start compling the train graph which may take some time. "
"Please wait for a moment ..."
)
loss_dict = self.model(**kwargs)
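            # Sum every returned entry whose key contains "loss" (e.g. "masked_lm_loss")
            # into a single scalar before calling backward().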
losses = sum(v for k, v in loss_dict.items() if "loss" in k)
losses.backward()
return loss_dict
else:
logger.info(
"Start compling the eval graph which may take some time. "
"Please wait for a moment ..."
)
return self.model(**kwargs)
def set_activation_checkpoint(self):
if hasattr(self.model, "origin"):
if hasattr(type(self.model.origin), "set_activation_checkpoint"):
type(self.model.origin).set_activation_checkpoint(self.model)
else:
for module_block in self.model.modules():
if isinstance(module_block.origin, TransformerLayer):
module_block.config.activation_checkpointing = True
else:
if hasattr(type(self.model.to(nn.Module)), "set_activation_checkpoint"):
type(self.model.to(nn.Module)).set_activation_checkpoint(self.model)
else:
for module_block in self.model.modules():
if isinstance(module_block.to(nn.Module), TransformerLayer):
module_block.to(nn.graph.GraphModule).activation_checkpointing = True
def set_pipeline_stage_id(self):
if hasattr(self.model, "origin"):
if hasattr(type(self.model.origin), "set_pipeline_stage_id"):
type(self.model.origin).set_pipeline_stage_id(self.model)
else:
if hasattr(type(self.model.to(nn.Module)), "set_pipeline_stage_id"):
type(self.model.to(nn.Module)).set_pipeline_stage_id(self.model)
## Introduction
Here are the weight loaders currently supported in LiBai. You can use them to load models trained with LiBai as well as pretrained models from [Hugging Face](https://huggingface.co/models).
## Weight Loader On LiBai
- [BERT Loader](./bert_loader.py)
- [RoBERTa Loader](./roberta_loader.py)
- [GPT2 Loader](./gpt_loader.py)
- [MT5 Loader](../../../../projects/MT5/utils/mt5_loader.py)
- [Swin Loader](./swin_loader.py)
- [SwinV2 Loader](./swinv2_loader.py)
- [ViT Loader](./vit_loader.py)
## How To Use
We can easily load a pretrained BERT as follows:
```python
import libai
from libai.models.utils import BertLoaderHuggerFace, BertLoaderLiBai
from configs.common.models.bert import cfg
# load huggingface weight
loader = BertLoaderHuggerFace(
model=libai.models.BertModel,
libai_cfg=cfg,
pretrained_model_path="path/to/huggingface_pretrained_model_directory",
hidden_dropout_prob=0,
apply_residual_post_layernorm=True
)
bert = loader.load()
# load libai weight
loader = BertLoaderLiBai(
model=libai.models.BertModel,
libai_cfg=cfg,
pretrained_model_path='path/to/libai_pretrained_model_directory',
hidden_dropout_prob=0,
)
bert = loader.load()
```
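The other loaders follow the same pattern. For example, here is a minimal sketch of loading GPT-2 weights from a Hugging Face checkpoint (the config import path and the `libai.models.GPTModel` class name below are assumptions; adjust them to your checkout):
```python
import libai
from libai.models.utils import GPT2LoaderHuggerFace

# Assumed config path: use the GPT-2 config shipped with your LiBai version.
from configs.common.models.gpt import cfg

# load huggingface GPT-2 weight
loader = GPT2LoaderHuggerFace(
    model=libai.models.GPTModel,  # assumed export name of the LiBai GPT model
    libai_cfg=cfg,
    pretrained_model_path="path/to/huggingface_pretrained_model_directory",
)
gpt2 = loader.load()
```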
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import copy
import logging
import os
import omegaconf
import oneflow as flow
from termcolor import colored
import libai.utils.distributed as dist
from libai.config import LazyCall
from libai.models.build import build_model
logger = logging.getLogger(__name__)
WEIGHTS_NAME_PT = "pytorch_model.bin"
CONFIG_NAME = "config.json"
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
"""load state dict into model
Args:
model_to_load (nn.Module): Model to be loaded.
state_dict (OrderedDict): State dict of pretrained model.
start_prefix (str): Start prefix.
Returns:
list: error message about loading.
"""
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
error_msgs = []
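    # Recursively call each submodule's _load_from_state_dict, extending the key
    # prefix with the child module's name at every level.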
def load(module, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
module._load_from_state_dict(*args)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")
load(model_to_load, prefix=start_prefix)
return error_msgs
class ModelLoader(object):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
"""Class used to load the [`transformers`](https://huggingface.co/models) pretrained model
or `OneFlow` pretrained model.
Args:
            model (libai.models): Model to be loaded in LiBai.
libai_cfg (dict): The config of model in LiBai, you can import it from
`libai.config.configs.common.models`.
pretrained_model_path (str): The directory path of pretrained model,
which contains model weights file and config file.
output_loading_info (`bool`, *optional*, defaults to `False`):
Whether to return a dictionary containing missing keys, unexpected keys
and error messages.
"""
self.model = model
self.libai_cfg = libai_cfg
self.pretrained_model_path = pretrained_model_path
self.kwargs = kwargs
self.output_loading_info = kwargs.pop("output_loading_info", False)
def _state_dict_to_global(self, flow_state_dict=None, mode="libai"):
"""Tensor in OneFlow state dict to global according to model's sbp and placement.
Args:
flow_state_dict (OrderedDict): State dict of OneFlow's pretrained model.
"""
assert mode in ["libai", "pytorch"], f"not support for mode {mode}"
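        # In "pytorch" mode only rank 0 holds the converted checkpoint, so the prefix
        # bookkeeping is computed there and broadcast to the other ranks; every rank
        # then materializes each tensor as a global tensor with the sbp and placement
        # of the matching model parameter.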
if mode == "libai" or dist.is_main_process():
prefix = self.base_model_prefix_2
# Checkpoint
has_prefix_module = any(
s.startswith(self.base_model_prefix_2) for s in flow_state_dict.keys()
)
# Module
expects_prefix_module = any(
s.startswith(prefix) for s in self.model.state_dict().keys()
)
start_prefix = "" if has_prefix_module else prefix + "."
loaded_keys = [start_prefix + key for key in flow_state_dict.keys()]
else:
prefix, has_prefix_module, expects_prefix_module, loaded_keys = [None] * 4
flow_state_dict = collections.OrderedDict()
prefix = dist.broadcast_py_object(prefix, src=0)
has_prefix_module = dist.broadcast_py_object(has_prefix_module, src=0)
expects_prefix_module = dist.broadcast_py_object(expects_prefix_module, src=0)
loaded_keys = dist.broadcast_py_object(loaded_keys, src=0)
# to global
for key, value in self.model.state_dict().items():
if not expects_prefix_module:
key = prefix + "." + key
if key in loaded_keys:
if not has_prefix_module:
key = ".".join(key.split(".")[1:])
if mode == "pytorch":
flow_state_dict[key] = flow.to_global(
flow_state_dict[key] if dist.is_main_process() else flow.Tensor(None),
sbp=flow.sbp.broadcast,
placement=flow.placement("cpu", ranks=[0]),
)
flow_state_dict[key] = flow.to_global(
flow_state_dict[key],
sbp=value.sbp,
placement=flow.placement("cpu", ranks=list(value.placement.ranks)),
)
return flow_state_dict
def _load_pretrained_model(
self,
model,
state_dict,
pretrained_model_path,
ignore_mismatched_sizes=False,
):
"""Load pretrained model.
Args:
model (libai.models): The model to be loaded.
            state_dict (OrderedDict): State dict of the pretrained model.
            pretrained_model_path (str): The path of the pretrained model.
ignore_mismatched_sizes (bool):
Whether or not to raise an error if some of the weights
from the checkpoint do not have the same size as the
weights of the model, defaults to `False`.
"""
model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys())
prefix = self.base_model_prefix_2
loaded_keys = state_dict.keys()
if len(prefix) > 0:
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
else:
has_prefix_module = False
expects_prefix_module = False
remove_prefix_from_model = not has_prefix_module and expects_prefix_module
add_prefix_to_model = has_prefix_module and not expects_prefix_module
if remove_prefix_from_model:
expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(prefix)]
expected_keys = [
".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys
]
elif add_prefix_to_model:
expected_keys = [".".join([prefix, s]) for s in expected_keys]
missing_keys = list(set(expected_keys) - set(loaded_keys))
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
start_prefix = ""
model_to_load = model
if (
len(self.base_model_prefix_2) > 0
and not hasattr(model, self.base_model_prefix_2)
and has_prefix_module
):
start_prefix = self.base_model_prefix_2 + "."
if (
len(self.base_model_prefix_2) > 0
and hasattr(model, self.base_model_prefix_2)
and not has_prefix_module
):
model_to_load = getattr(model, self.base_model_prefix_2)
if any(key in expected_keys_not_prefixed for key in loaded_keys):
raise ValueError("The state dict of the model you are loading is corrupted.")
def _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
ignore_mismatched_sizes,
):
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
if remove_prefix_from_model:
model_key = f"{prefix}.{checkpoint_key}"
elif add_prefix_to_model:
model_key = ".".join(checkpoint_key.split(".")[1:])
if (
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
mismatched_keys.append(
(
checkpoint_key,
state_dict[checkpoint_key].shape,
model_state_dict[model_key].shape,
)
)
del state_dict[checkpoint_key]
return mismatched_keys
if state_dict is not None:
mismatched_keys = _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
ignore_mismatched_sizes,
)
error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
if dist.get_local_rank() == 0:
if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
raise RuntimeError(
f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}"
)
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint at {pretrained_model_path} "
"were not used when "
f"initializing {model.__class__.__name__}:\n {unexpected_keys}\n"
)
else:
logger.info(
f"All model checkpoint weights were used when initializing "
f"{model.__class__.__name__}.\n"
)
if len(missing_keys) > 0:
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized "
f"from the model checkpoint at {pretrained_model_path}:\n "
f"{missing_keys} \n"
)
elif len(mismatched_keys) == 0:
logger.info(
f"All the weights of {model.__class__.__name__} were initialized "
f"from the model checkpoint at {pretrained_model_path}.\n"
)
if len(mismatched_keys) > 0:
mismatched_warning = "\n".join(
[
f"- {key}: found shape {shape1} in the checkpoint and {shape2}"
"in the model instantiated"
for key, shape1, shape2 in mismatched_keys
]
)
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized"
f"from the model checkpoint at {pretrained_model_path} "
f"and are newly initialized because the shapes did not"
f"match:\n{mismatched_warning}\n"
)
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
class ModelLoaderLiBai(ModelLoader):
"""Class used to load `OneFlow` pretrained model.
Args:
        model (libai.models): Model to be loaded in LiBai.
libai_cfg (dict): The config of model in LiBai, you can import it from
`libai.config.configs.common.models`.
pretrained_model_path (str): The directory path of pretrained model,
which contains model weights file and config file.
output_loading_info (`bool`, *optional*, defaults to `False`):
Whether to return a dictionary containing missing keys, unexpected keys
and error messages.
"""
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
self.base_model_prefix_2 = None # prefix in LiBai
def _load_flow_state_dict(self, state_dict_file):
# load oneflow_model
state_dict = flow.load(state_dict_file, global_src_rank=0)
return state_dict
def load(self):
"""Load model.
# For example:
# .. code-block:: python
>>> import libai
>>> from libai.config.configs.common.models.bert import cfg
        >>> from libai.models.utils import BertLoaderLiBai
        >>> loader = BertLoaderLiBai(
libai.models.BertModel,
cfg,
'path/bert-base-chinese'
)
        >>> bert = loader.load()
"""
if dist.is_main_process():
assert os.path.isdir(
self.pretrained_model_path
), f"{self.pretrained_model_path} must be a directory"
flow_state_dict = self._load_flow_state_dict(self.pretrained_model_path)
# Instance model
if isinstance(self.model, omegaconf.dictconfig.DictConfig):
self.model.cfg = self.libai_cfg
self.model = build_model(self.model)
else:
self.model = build_model(LazyCall(self.model)(cfg=self.libai_cfg))
# State_dict to global
self._state_dict_to_global(flow_state_dict, mode="libai")
# Load
(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
error_msgs,
) = self._load_pretrained_model(self.model, flow_state_dict, self.pretrained_model_path)
if self.output_loading_info:
loading_info = {
"missing_keys": missing_keys,
"unexpected_keys": unexpected_keys,
"mismatched_keys": mismatched_keys,
"error_msgs": error_msgs,
}
return model, loading_info
return model
class ModelLoaderHuggerFace(ModelLoader):
"""Class used to load the [`transformers`](https://huggingface.co/models)
pretrained model.
"""
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
self.base_model_prefix_1 = None # prefix in Transformers
self.base_model_prefix_2 = None # prefix in LiBai
self.origin_libai_cfg = copy.deepcopy(self.libai_cfg)
self.changed_keys = set() # Store the changed configuration
def _convert_tensor(self, tensor):
"""Convert PyTorch tensor to OneFlow tensor.
Args:
tensor (torch.Tensor): The source tensor.
Returns:
flow.Tensor: The target tensor.
"""
tensor = tensor.float()
return flow.Tensor(tensor.detach().cpu().numpy())
def _convert_tensors(self, torch_state_dict):
for k, v in torch_state_dict.items():
torch_state_dict[k] = self._convert_tensor(v)
return torch_state_dict
def _fix_key(self, state_dict):
"""Fix the key in state dict: Convert "gamma" to "weight" and "beta" to "bias".
Args:
state_dict (OrderedDict): state dict of pretrained model.
Returns:
OrderedDict: State dict after fix key.
"""
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if "gamma" in key:
new_key = key.replace("gamma", "weight")
if "beta" in key:
new_key = key.replace("beta", "bias")
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
return state_dict
def _fix_qkv_ordering(
self, qkv, head_size, num_heads, hidden_size=None, checkpoint_version=0.0
):
# TODO(xzp): Different versions checkpoint
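        # The incoming qkv tensor stacks the projections as [all q; all k; all v] along
        # dim 0; the reshape/permute below reorders the rows into a per-head grouping
        # [q_h; k_h; v_h] for h = 0 .. num_heads - 1, which is the fused qkv layout
        # expected by LiBai's attention layers.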
hidden_size = (head_size * num_heads) if hidden_size is None else hidden_size
num_of_qkv = qkv.shape[0] // (head_size * num_heads)
mode = "weight" if qkv.ndim > 1 else "bias"
if mode == "weight":
qkv = qkv.view([num_of_qkv, num_heads, head_size, hidden_size])
qkv = (
qkv.permute(1, 0, 2, 3)
.contiguous()
.view(num_of_qkv * head_size * num_heads, hidden_size)
)
elif mode == "bias":
qkv = qkv.view(num_of_qkv, num_heads, head_size)
qkv = qkv.permute(1, 0, 2).contiguous().view(-1)
return qkv
def _convert_state_dict(self, flow_state_dict, cfg):
"""A function used to convert the checkpoint file of Huggingface to LiBai.
Args:
            flow_state_dict (OrderedDict): The Huggingface checkpoint state dict, with
                tensors already converted to OneFlow.
cfg (dict): model's default config dict in LiBai.
Returns:
OrderedDict: flow state dict.
"""
raise NotImplementedError("_convert_state_dict not implemented")
def _load_config_from_json(self, config_file):
"""load config from `config.json`, and update default config.
Args:
config_file (str): Path of config file.
"""
raise NotImplementedError("_load_config_from_json not implemented")
def _load_torch_state_dict(self, state_dict_file):
try:
import torch
except ImportError:
raise ImportError("Load torch state dict need torch.")
# load pytorch_model.bin
state_dict = torch.load(state_dict_file, map_location="cpu")
return state_dict
def _update_cfg(self, keys_libai, value_target):
"""Update the libai_cfg according to target_cfg.
Args:
keys_libai (str): The key of libai_cfg.
value_target (int | float): The value of target_cfg.
"""
if keys_libai not in self.libai_cfg.keys():
return
if self.libai_cfg[keys_libai] != value_target:
self.libai_cfg[keys_libai] = value_target
def _update_cfg_log(self):
if dist.get_local_rank() == 0:
for key in sorted(self.libai_cfg):
if self.origin_libai_cfg[key] == self.libai_cfg[key]:
continue
self.changed_keys.add(key)
temp_key = colored(key, "yellow")
logger.info(
f"changed libai model cfg {temp_key} : "
f"{self.origin_libai_cfg[key]} -> {self.libai_cfg[key]} "
)
logger.warning(
"The following model configurations has been modified according "
"to `config.json` or kwargs: \n"
f"{self.changed_keys} \n"
)
if dist.get_pipeline_parallel_size() > 1:
logger.warning(
colored(
"If you use pipeline parallel, please "
"confirm the setting of `train.dist.pipeline_num_layers` \n",
"red",
)
)
def load(self):
"""Load model.
# For example:
# .. code-block:: python
>>> import libai
>>> from configs.common.models.bert import cfg
        >>> from libai.models.utils import BertLoaderHuggerFace
        >>> loader = BertLoaderHuggerFace(
libai.models.BertModel,
cfg,
'path/bert-base-chinese'
)
>>> bert = loader.load()
"""
if dist.is_main_process():
if os.path.isdir(self.pretrained_model_path):
# state_dict file pytorch
if os.path.isfile(os.path.join(self.pretrained_model_path, WEIGHTS_NAME_PT)):
model_file = os.path.join(self.pretrained_model_path, WEIGHTS_NAME_PT)
else:
raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME_PT} found"
f"in directory {self.pretrained_model_path}."
)
# config file
if os.path.isfile(os.path.join(self.pretrained_model_path, CONFIG_NAME)):
config_file = os.path.join(self.pretrained_model_path, CONFIG_NAME)
# Load config and update config.
self._load_config_from_json(config_file)
else:
import warnings
warnings.warn(
f"Error no file named {CONFIG_NAME} found in directory"
f"{self.pretrained_model_path}",
RuntimeWarning,
)
else:
raise EnvironmentError(f"{self.pretrained_model_path} is not a directory.")
logger.info("loading torch model...")
torch_state_dict = self._load_torch_state_dict(model_file)
torch_state_dict = self._fix_key(torch_state_dict)
logger.info("transfering torch model into oneflow model...")
flow_state_dict = self._convert_tensors(torch_state_dict)
flow_state_dict = self._convert_state_dict(torch_state_dict, self.libai_cfg)
else:
flow_state_dict = None
self.libai_cfg = dist.broadcast_py_object(self.libai_cfg, src=0)
# Instance model
logger.info("building LiBai model...")
if isinstance(self.model, omegaconf.dictconfig.DictConfig):
self.model.cfg = self.libai_cfg
self.model = build_model(self.model)
else:
self.model = build_model(LazyCall(self.model)(cfg=self.libai_cfg))
# State_dict to global
logger.info("transfering state_dict local to global...")
flow_state_dict = self._state_dict_to_global(flow_state_dict, mode="pytorch")
logger.info("loading model weights into LiBai...")
# Load
(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
error_msgs,
) = self._load_pretrained_model(self.model, flow_state_dict, self.pretrained_model_path)
if self.output_loading_info:
loading_info = {
"missing_keys": missing_keys,
"unexpected_keys": unexpected_keys,
"mismatched_keys": mismatched_keys,
"error_msgs": error_msgs,
}
return model, loading_info
return model