# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
from oneflow import nn
from libai.layers import Linear
from libai.utils import distributed as dist
class LMLogits(nn.Module):
def __init__(self, vocab_size, hidden_size=None, bias=False, model_type="t5", layer_idx=-1):
super().__init__()
self.model_type = model_type
if model_type == "t5":
self.bias = (
nn.Parameter(
flow.zeros(
(vocab_size,),
dtype=flow.float32,
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.split(0)]),
)
)
if bias
else None
)
elif model_type == "mt5":
self.linear = Linear(hidden_size, vocab_size, bias=False, layer_idx=layer_idx)
def forward(self, input, word_embeddings=None):
if self.model_type == "t5":
w = word_embeddings.to_global(placement=input.placement)
input = input.to_global(grad_sbp=input.sbp)
logits = flow._C.matmul(input, w, transpose_b=True)
if self.bias is not None:
logits = logits + self.bias
else:
logits = self.linear(input)
return logits
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
from libai.layers import ParallelCrossEntropyLoss
from libai.utils import distributed as dist
class MT5Loss(flow.nn.Module):
def __init__(self) -> None:
super().__init__()
self.lm_loss = ParallelCrossEntropyLoss()
def forward(self, logits, lm_labels, loss_mask):
lm_labels = lm_labels.to_global(placement=logits.placement)
lm_loss = self.lm_loss(logits, lm_labels)
loss_mask = loss_mask.to_global(placement=lm_loss.placement)
loss_mask = loss_mask.float()
denominator = loss_mask.sum().to_global(
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])
)
lm_loss = flow._C.amp_white_identity(lm_loss)
lm_loss = flow._C.amp_black_identity(lm_loss)
masked_lm_loss = flow.sum(lm_loss.view(-1) * loss_mask.view(-1)) / denominator
masked_lm_loss = masked_lm_loss.to_global(
sbp=dist.get_nd_sbp([flow.sbp.partial_sum, flow.sbp.broadcast])
)
if self.training:
# token throughput
done_tokens = (
flow.zeros(
1,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=lm_labels.placement,
)
+ logits.shape[0] * logits.shape[1]
)
# correct token
correct_tokens = flow.sum(
(
logits.to_global(
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=lm_labels.placement,
)
.argmax(dim=-1)
.eq(lm_labels)
).float()
)
return {
"mlm_loss": masked_lm_loss,
"done_tokens": done_tokens,
"correct_tokens": correct_tokens,
"denominator": denominator,
}
else:
return {
"mlm_loss": masked_lm_loss,
}
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
from libai.utils import distributed as dist
class ExtendedMask(flow.nn.Module):
def forward(self, x, input_tensor=None, is_decoder=False):
if x.dim() == 3:
extended_mask = x[:, None, :, :]
elif x.dim() == 2:
if is_decoder:
extended_mask = self.create_extended_mask_for_decoder(x, input_tensor)
else:
extended_mask = x[:, None, None, :]
return extended_mask
def create_extended_mask_for_decoder(self, x, input_tensor):
batch_size, seq_len = input_tensor.size()
seq_ids = flow.arange(
seq_len,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=x.placement,
)
causal_mask = (
seq_ids[None, None, :].repeat(batch_size, seq_len, 1) <= seq_ids[None, :, None]
)
causal_mask = causal_mask.to(x.dtype)
causal_mask = causal_mask.to_global(sbp=x.sbp)
if causal_mask.shape[1] < x.shape[1]:
prefix_seq_len = x.shape[1] - causal_mask.shape[1]
ones = flow.ones(
(batch_size, seq_len, prefix_seq_len),
dtype=causal_mask.dtype,
sbp=causal_mask.sbp,
placement=causal_mask.placement,
)
causal_mask = flow.cat(
[
ones,
causal_mask,
],
dim=-1,
)
extended_mask = causal_mask[:, None, :, :] * x[:, None, None, :]
return extended_mask
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
from oneflow import nn
from libai.layers import Linear, build_activation
class T5MLP(nn.Module):
def __init__(
self,
hidden_size,
ffn_hidden_size,
output_dropout_prob=0.0,
init_method=nn.init.xavier_normal_,
output_layer_init_method=None,
*,
layer_idx=0,
):
super().__init__()
self.output_dropout_prob = output_dropout_prob
if output_layer_init_method is None:
output_layer_init_method = init_method
self.dense_h_to_4h = Linear(
hidden_size,
ffn_hidden_size,
bias=False,
parallel="col",
skip_bias_add=False,
init_method=init_method,
layer_idx=layer_idx,
)
self.activation_func = build_activation("relu")
self.dense_4h_to_h = Linear(
ffn_hidden_size,
hidden_size,
bias=False,
parallel="row",
skip_bias_add=False,
init_method=output_layer_init_method,
layer_idx=layer_idx,
)
self.dropout = nn.Dropout(self.output_dropout_prob)
def forward(self, hidden_states):
intermediate = self.dense_h_to_4h(hidden_states)
intermediate = self.activation_func(intermediate)
output = self.dense_4h_to_h(intermediate)
output = self.dropout(output)
return output
class MT5MLP(nn.Module):
def __init__(
self,
hidden_size,
ffn_hidden_size,
output_dropout_prob=0.0,
init_method=nn.init.xavier_normal_,
output_layer_init_method=None,
*,
layer_idx=0,
):
super().__init__()
self.output_dropout_prob = output_dropout_prob
if output_layer_init_method is None:
output_layer_init_method = init_method
self.wi_0 = Linear(
hidden_size,
ffn_hidden_size,
bias=False,
parallel="col",
skip_bias_add=False,
init_method=init_method,
layer_idx=layer_idx,
)
self.wi_1 = Linear(
hidden_size,
ffn_hidden_size,
bias=False,
parallel="col",
skip_bias_add=False,
init_method=init_method,
layer_idx=layer_idx,
)
self.wo = Linear(
ffn_hidden_size,
hidden_size,
bias=False,
parallel="row",
skip_bias_add=False,
init_method=output_layer_init_method,
layer_idx=layer_idx,
)
self.dropout = nn.Dropout(self.output_dropout_prob)
def forward(self, hidden_states):
wi_0_out = self.wi_0(hidden_states)
hidden_linear = self.wi_1(hidden_states)
hidden_states = flow._C.fused_fast_gelu_mul(wi_0_out, hidden_linear)
output = self.wo(hidden_states)
output = self.dropout(output)
return output
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow.nn as nn
from libai.layers.droppath import DropPath
from libai.layers.layer_norm import RMSLayerNorm as LayerNorm
from libai.utils import distributed as dist
from projects.MT5.layers.attention_layer import MultiheadAttention
from projects.MT5.layers.mlp_layer import MT5MLP, T5MLP
class TransformerLayer(nn.Module):
"""A single transformer layer.
Transformer layer takes input with size [bsz, seq_length, hidden size] and returns an
output of the same size.
The input and output have the same sbp signature, (S(0), B).
Arguments:
hidden_size: size of hidden state.
ffn_hidden_size: size of the feed-forward neural network.
num_attention_heads: number of attention heads.
is_decoder: whether this is a transformer encoder layer or a transformer
decoder layer. Default: ``False``.
attention_dropout_prob: dropout probability of attention weights.
output_dropout_prob: dropout probability of output.
layernorm_epsilon: epsilon used in layernorm layer. Default: `1e-5`.
init_method: method to initialize the input layer weights.
output_layer_init_method: method to initialize the output layer weights.
If None, use `init_method`.
layer_idx: the layer index, which determines the placement.
"""
def __init__(
self,
hidden_size,
ffn_hidden_size,
num_attention_heads,
head_size,
relative_attention_num_buckets,
is_decoder=False,
attention_dropout_prob=0.0,
output_dropout_prob=0.0,
drop_path_prob=0.0,
layernorm_epsilon=1e-5,
init_method=nn.init.xavier_normal_,
output_layer_init_method=None,
padding_idx=None,
*,
layer_idx=0,
model_type="t5",
has_relative_attention_bias=False
):
super().__init__()
self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size
self.num_attention_heads = num_attention_heads
self.head_size = head_size
self.attention_dropout_prob = attention_dropout_prob
self.output_dropout_prob = output_dropout_prob
self.layernorm_epsilon = layernorm_epsilon
self.layer_idx = layer_idx
self.is_decoder = is_decoder
self.init_method = init_method
if output_layer_init_method is None:
output_layer_init_method = init_method
self.output_layer_init_method = output_layer_init_method
self.drop_path = DropPath(drop_path_prob) if drop_path_prob > 0.0 else nn.Identity()
self.input_layernorm = LayerNorm(
self.hidden_size, eps=self.layernorm_epsilon, layer_idx=self.layer_idx
)
self.self_attention = self.build_attention(
is_cross_attention=False,
relative_attention_num_buckets=relative_attention_num_buckets,
padding_idx=padding_idx,
has_relative_attention_bias=has_relative_attention_bias,
is_decoder=self.is_decoder,
)
self.post_attention_layernorm = LayerNorm(
self.hidden_size, eps=self.layernorm_epsilon, layer_idx=self.layer_idx
)
if self.is_decoder:
self.cross_attention = self.build_attention(
is_cross_attention=True,
relative_attention_num_buckets=relative_attention_num_buckets,
padding_idx=padding_idx,
is_decoder=self.is_decoder,
)
self.post_cross_attention_layernorm = LayerNorm(
self.hidden_size, eps=self.layernorm_epsilon, layer_idx=self.layer_idx
)
if model_type == "mt5":
self.mlp = MT5MLP(
self.hidden_size,
self.ffn_hidden_size,
self.output_dropout_prob,
self.init_method,
output_layer_init_method=self.output_layer_init_method,
layer_idx=self.layer_idx,
)
elif model_type == "t5":
self.mlp = T5MLP(
self.hidden_size,
self.ffn_hidden_size,
self.output_dropout_prob,
self.init_method,
output_layer_init_method=self.output_layer_init_method,
layer_idx=self.layer_idx,
)
def forward(
self,
hidden_states,
attention_mask=None,
encoder_states=None,
encoder_attention_mask=None,
past_key_value=None,
use_cache=False,
position_bias=None,
encoder_decoder_position_bias=None,
):
"""
Args:
hidden_states: shape is (batch_size, seq_length, hidden_size),
sbp signature is (S(0), B).
attention_mask: the combination of key padding mask and causal mask of hidden states
with shape (batch_size, 1, seq_length, seq_length) and the sbp
signature is (S(0), B),
encoder_states: encoder output with shape (batch_size, seq_length, hidden_size)
and the sbp signature is (S(0), B), which will be used in cross attention.
encoder_attention_mask: key padding mask of encoder states with shape
(batch_size, 1, seq_length, seq_length) and the sbp signature is (S(0), B).
past_key_value: tuple of key and value, each with shape
(seq_length, bsz, num_heads, head_size). For a decoder layer,
past_key_value contains the states from both self attention
and cross attention.
use_cache: it will be set to `True` when the model is in the inference phase and
used for incremental decoding.
"""
hidden_states = hidden_states.to_global(placement=dist.get_layer_placement(self.layer_idx))
if attention_mask is not None:
attention_mask = attention_mask.to_global(
placement=dist.get_layer_placement(self.layer_idx)
)
if past_key_value is not None:
if self.is_decoder:
assert len(past_key_value) == 4
self_attn_past_key_value = past_key_value[:2]
cross_attn_past_key_value = past_key_value[2:]
else:
self_attn_past_key_value = past_key_value
cross_attn_past_key_value = None
else:
self_attn_past_key_value, cross_attn_past_key_value = None, None
layernorm_output = self.input_layernorm(hidden_states)
attention_output, position_bias = self.self_attention(
layernorm_output,
attention_mask=attention_mask,
past_key_value=self_attn_past_key_value,
position_bias=position_bias,
use_cache=use_cache,
)
attention_output = self.drop_path(attention_output)
if use_cache:
attention_output, presents = attention_output
else:
presents = None
hidden_states = hidden_states + attention_output
layernorm_output = self.post_attention_layernorm(hidden_states)
if self.is_decoder:
if presents is not None:
query_length = presents[0].shape[2]
else:
query_length = None
attention_output, encoder_decoder_position_bias = self.cross_attention(
layernorm_output,
encoder_states,
attention_mask=encoder_attention_mask,
past_key_value=cross_attn_past_key_value,
position_bias=encoder_decoder_position_bias,
use_cache=use_cache,
query_length=query_length,
)
if use_cache:
attention_output, decoder_presents = attention_output
presents = presents + decoder_presents
attention_output = self.drop_path(attention_output)
hidden_states = hidden_states + attention_output
layernorm_output = self.post_cross_attention_layernorm(hidden_states)
mlp_output = self.mlp(layernorm_output)
mlp_output = self.drop_path(mlp_output)
output = hidden_states + mlp_output
if use_cache:
output = (output, presents)
output = (output,) + (position_bias,)
if self.is_decoder:
output = output + (encoder_decoder_position_bias,)
return output
def build_attention(
self,
is_cross_attention=False,
relative_attention_num_buckets=None,
padding_idx=None,
has_relative_attention_bias=False,
is_decoder=False,
):
return MultiheadAttention(
self.hidden_size,
self.num_attention_heads,
head_size=self.head_size,
relative_attention_num_buckets=relative_attention_num_buckets,
is_cross_attention=is_cross_attention,
attention_dropout_prob=self.attention_dropout_prob,
output_dropout_prob=self.output_dropout_prob,
init_method=self.init_method,
output_layer_init_method=self.output_layer_init_method,
padding_idx=padding_idx,
layer_idx=self.layer_idx,
has_relative_attention_bias=has_relative_attention_bias,
is_decoder=is_decoder,
)
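# Illustrative construction of a single decoder layer (a sketch; the argument
# values below are placeholders, not taken from any particular config):
#
#   layer = TransformerLayer(
#       hidden_size=512,
#       ffn_hidden_size=1024,
#       num_attention_heads=6,
#       head_size=64,
#       relative_attention_num_buckets=32,
#       is_decoder=True,
#       layer_idx=0,
#       model_type="mt5",
#       has_relative_attention_bias=True,  # only the first layer holds the bias table
#   )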
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
import oneflow.nn as nn
from libai.config import configurable
from libai.inference.generator.generation_utils import Generator
from libai.layers import Linear, LMLogits, RMSLayerNorm
from libai.models.utils import init_method_normal, scaled_init_method_normal
from libai.utils import distributed as dist
from projects.MT5.layers.embed_layer import MT5Embedding
from projects.MT5.layers.loss_layer import MT5Loss
from projects.MT5.layers.mask_layer import ExtendedMask
from projects.MT5.layers.transformer_layer import TransformerLayer
from projects.MT5.utils.mt5_loader import T5LoaderHuggerFace
class MT5Model(flow.nn.Module, Generator):
@configurable
def __init__(
self,
vocab_size,
hidden_size,
hidden_layers,
num_attention_heads,
head_size,
intermediate_size,
embedding_dropout_prob,
hidden_dropout_prob,
attention_probs_dropout_prob,
relative_attention_num_buckets,
padding_idx=None,
initializer_range=0.02,
layernorm_eps=1e-12,
amp_enabled=False,
model_type="mt5",
cfg=None,
) -> None:
super().__init__()
self.cfg = cfg
self.model_type = model_type
init_method = init_method_normal(initializer_range)
scaled_init_method = scaled_init_method_normal(initializer_range, hidden_layers)
self.embedding = MT5Embedding(
hidden_size=hidden_size,
vocab_size=vocab_size,
embedding_dropout_prob=embedding_dropout_prob,
init_method=init_method,
amp_enabled=amp_enabled,
)
self.extended_attn_mask = ExtendedMask()
encoder_layers = flow.nn.ModuleList(
[
TransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=intermediate_size,
num_attention_heads=num_attention_heads,
head_size=head_size,
relative_attention_num_buckets=relative_attention_num_buckets,
is_decoder=False,
attention_dropout_prob=attention_probs_dropout_prob,
output_dropout_prob=hidden_dropout_prob,
layernorm_epsilon=layernorm_eps,
init_method=init_method,
output_layer_init_method=scaled_init_method,
padding_idx=padding_idx,
layer_idx=i,
model_type=model_type,
has_relative_attention_bias=bool(i == 0),
)
for i in range(hidden_layers)
]
)
encoder_final_layernorm = RMSLayerNorm(
(hidden_size,),
eps=layernorm_eps,
layer_idx=hidden_layers - 1,
)
self.encoder = flow.nn.Sequential()
self.encoder.add_module("layers", encoder_layers)
self.encoder.add_module("final_layernorm", encoder_final_layernorm)
decoder_layers = flow.nn.ModuleList(
[
TransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=intermediate_size,
num_attention_heads=num_attention_heads,
head_size=head_size,
relative_attention_num_buckets=relative_attention_num_buckets,
is_decoder=True,
attention_dropout_prob=attention_probs_dropout_prob,
output_dropout_prob=hidden_dropout_prob,
layernorm_epsilon=layernorm_eps,
init_method=init_method,
output_layer_init_method=scaled_init_method,
padding_idx=padding_idx,
layer_idx=i,
model_type=model_type,
has_relative_attention_bias=bool(i - hidden_layers == 0),
)
for i in range(hidden_layers, 2 * hidden_layers)
]
)
decoder_final_layernorm = RMSLayerNorm(
(hidden_size,),
eps=layernorm_eps,
layer_idx=2 * hidden_layers - 1,
)
self.decoder = flow.nn.Sequential()
self.decoder.add_module("layers", decoder_layers)
self.decoder.add_module("final_layernorm", decoder_final_layernorm)
self.past_key_values = [None] * len(self.decoder.layers)
self.encoder_states = None
self.past_length = 0
if model_type == "mt5":
self.lm_head = Linear(
hidden_size, vocab_size, bias=False, layer_idx=2 * hidden_layers - 1
)
else:
self.lm_head = LMLogits(vocab_size, bias=False)
@classmethod
def from_config(cls, cfg):
return {
"vocab_size": cfg.vocab_size,
"hidden_size": cfg.hidden_size,
"hidden_layers": cfg.hidden_layers,
"num_attention_heads": cfg.num_attention_heads,
"head_size": cfg.head_size,
"intermediate_size": cfg.intermediate_size,
"embedding_dropout_prob": cfg.embedding_dropout_prob,
"hidden_dropout_prob": cfg.hidden_dropout_prob,
"attention_probs_dropout_prob": cfg.attention_probs_dropout_prob,
"relative_attention_num_buckets": cfg.relative_attention_num_buckets,
"padding_idx": cfg.padding_idx,
"initializer_range": cfg.initializer_range,
"layernorm_eps": cfg.layernorm_eps,
"amp_enabled": cfg.amp_enabled,
"model_type": cfg.model_type,
"cfg": cfg,
}
def forward(
self,
encoder_input_ids=None,
decoder_input_ids=None,
encoder_attn_mask=None,
decoder_attn_mask=None,
encoder_decoder_attn_mask=None,
use_cache=False,
only_encoder=False,
):
encoder_input_ids = (
encoder_input_ids.to_global(placement=dist.get_layer_placement(0))
if encoder_input_ids is not None
else encoder_input_ids
)
decoder_input_ids = (
decoder_input_ids.to_global(placement=dist.get_layer_placement(0))
if decoder_input_ids is not None
else decoder_input_ids
)
encoder_attn_mask = (
encoder_attn_mask.to_global(placement=dist.get_layer_placement(0))
if encoder_attn_mask is not None
else encoder_attn_mask
)
decoder_attn_mask = (
decoder_attn_mask.to_global(placement=dist.get_layer_placement(0))
if decoder_attn_mask is not None
else decoder_attn_mask
)
encoder_decoder_attn_mask = (
encoder_decoder_attn_mask.to_global(placement=dist.get_layer_placement(0))
if encoder_decoder_attn_mask is not None
else encoder_decoder_attn_mask
)
if use_cache and self.encoder_states is not None:
encoder_states = self.encoder_states
else:
position_bias = None
encoder_decoder_position_bias = None
self.set_cache(encoder_states=None, past_key_values=None)
encoder_attn_mask = self.extended_attn_mask(encoder_attn_mask)
enc_embedding_output = self.embedding(encoder_input_ids)
# transpose [batch_size, seq_len, embed_size] to [seq_len, batch_size, embed_size]
enc_hidden_states = enc_embedding_output.transpose(0, 1)
for layer in self.encoder.layers:
enc_hidden_states, position_bias = layer(
enc_hidden_states,
encoder_attn_mask,
position_bias=position_bias,
)
encoder_states = self.encoder.final_layernorm(enc_hidden_states)
if only_encoder:
return encoder_states
decoder_attn_mask = self.extended_attn_mask(
decoder_attn_mask, decoder_input_ids, is_decoder=True
)
encoder_decoder_attn_mask = self.extended_attn_mask(encoder_decoder_attn_mask)
dec_embedding_output = self.embedding(decoder_input_ids)
# transpose [batch_size, seq_len, embed_size] to [seq_len, batch_size, embed_size]
dec_hidden_states = dec_embedding_output.transpose(0, 1)
if use_cache:
presents = []
position_bias = None
encoder_decoder_position_bias = None
for layer, past_key_value in zip(self.decoder.layers, self.past_key_values):
dec_hidden_states, position_bias, encoder_decoder_position_bias = layer(
dec_hidden_states,
decoder_attn_mask,
encoder_states,
encoder_decoder_attn_mask,
past_key_value=past_key_value,
position_bias=position_bias,
encoder_decoder_position_bias=encoder_decoder_position_bias,
use_cache=use_cache,
)
if use_cache:
dec_hidden_states, present = dec_hidden_states
presents.append(present)
if use_cache:
self.set_cache(encoder_states, past_key_values=presents)
decoder_states = self.decoder.final_layernorm(dec_hidden_states)
if self.cfg.tie_word_embeddings:
decoder_states = decoder_states * (self.cfg.hidden_size ** -0.5)
if self.model_type == "mt5":
logits = self.lm_head(decoder_states)
else:
logits = self.lm_head(decoder_states, self.embedding.word_embeddings.weight)
return {"logits": logits}
def set_cache(self, encoder_states, past_key_values):
self.encoder_states = encoder_states
self.past_length = 0 if past_key_values is None else past_key_values[0][0].shape[2]
if past_key_values is None:
past_key_values = [None] * len(self.decoder.layers)
assert len(past_key_values) == len(self.decoder.layers), (
f"past_key_values's length {len(past_key_values)} doesn't match "
f"decoder num_layers' length {self.decoder.layers}"
)
self.past_key_values = past_key_values
def _reorder_cache(self, beam_idx):
past_key_values = self.past_key_values
reordered_decoder_past = ()
for layer_past_states in past_key_values:
# get the correct batch idx from layer past batch dim
# batch dim of `past` is at 2nd position
reordered_layer_past_states = ()
for layer_past_state in layer_past_states:
# need to set correct `past` for each of the four key / value states
beam_idx = beam_idx.to_global(placement=layer_past_state.placement)
reordered_layer_past_states = reordered_layer_past_states + (
layer_past_state.index_select(0, beam_idx),
)
assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
assert len(reordered_layer_past_states) == len(layer_past_states)
reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
return reordered_decoder_past
def prepare_inputs_for_generation(
self,
input_ids,
past=None,
encoder_attn_mask=None,
encoder_decoder_attn_mask=None,
use_cache=None,
encoder_outputs=None,
):
# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]
self.past_key_values = past
self.encoder_states = encoder_outputs
decoder_attn_mask = flow.ones(
input_ids.size(),
dtype=flow.bool,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=flow.placement("cuda", list(range(dist.get_world_size()))),
)
return {
"decoder_input_ids": input_ids,
"decoder_attn_mask": decoder_attn_maks,
"encoder_attn_mask": encoder_attn_mask,
"encoder_decoder_attn_mask": encoder_decoder_attn_mask,
"use_cache": use_cache,
}
class MT5ForPreTraining(flow.nn.Module):
def __init__(self, cfg) -> None:
super().__init__()
if cfg.pretrained_model_path is not None:
loader = T5LoaderHuggerFace(MT5Model, cfg, cfg.pretrained_model_path)
self.mt5_model = loader.load()
else:
self.mt5_model = MT5Model(cfg)
self.loss_func = MT5Loss()
def set_cache(self, encoder_states, past_key_values):
self.mt5_model.set_cache(encoder_states, past_key_values)
def forward(
self,
encoder_input_ids,
decoder_input_ids,
encoder_attn_mask,
decoder_attn_mask,
encoder_decoder_attn_mask,
lm_labels=None,
loss_mask=None,
use_cache=False,
):
logits = self.mt5_model(
encoder_input_ids,
decoder_input_ids,
encoder_attn_mask,
decoder_attn_mask,
encoder_decoder_attn_mask,
use_cache=use_cache,
)["logits"]
# transpose [seq_len, batch_size, vocab_size] to [batch_size, seq_len, vocab_size]
logits = logits.transpose(0, 1)
if lm_labels is not None:
lm_loss = self.loss_func(logits, lm_labels, loss_mask)
return lm_loss
else:
return {
"prediction_scores": logits,
}
@staticmethod
def set_pipeline_stage_id(model):
dist_utils = dist.get_dist_util()
# Set pipeline parallelism stage_id
if hasattr(model.mt5_model.encoder.final_layernorm, "config"):
# Old API in OneFlow 0.8
for module_block in model.modules():
if isinstance(module_block.origin, MT5Embedding):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.origin, ExtendedMask):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.origin, TransformerLayer):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
elif isinstance(module_block.origin, MT5Loss):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.mt5_model.encoder.final_layernorm.config.set_stage(
dist_utils.get_layer_stage_id(model.mt5_model.encoder.final_layernorm.layer_idx),
dist.get_layer_placement(model.mt5_model.encoder.final_layernorm.layer_idx),
)
model.mt5_model.decoder.final_layernorm.config.set_stage(
dist_utils.get_layer_stage_id(model.mt5_model.decoder.final_layernorm.layer_idx),
dist.get_layer_placement(model.mt5_model.decoder.final_layernorm.layer_idx),
)
else:
for module_block in model.modules():
if isinstance(module_block.to(nn.Module), MT5Embedding):
module_block.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.to(nn.Module), ExtendedMask):
module_block.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.to(nn.Module), TransformerLayer):
module_block.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
elif isinstance(module_block.to(nn.Module), MT5Loss):
module_block.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.mt5_model.encoder.final_layernorm.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(model.mt5_model.encoder.final_layernorm.layer_idx),
dist.get_layer_placement(model.mt5_model.encoder.final_layernorm.layer_idx),
)
model.mt5_model.decoder.final_layernorm.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(model.mt5_model.decoder.final_layernorm.layer_idx),
dist.get_layer_placement(model.mt5_model.decoder.final_layernorm.layer_idx),
)
@staticmethod
def set_activation_checkpoint(model):
for module_block in model.modules():
# Old API in OneFlow 0.8
if hasattr(module_block, "origin"):
if isinstance(module_block.origin, TransformerLayer):
module_block.config.activation_checkpointing = True
else:
if isinstance(module_block.to(nn.Module), TransformerLayer):
module_block.to(nn.graph.GraphModule).activation_checkpointing = True
# MT5
Reproduction of T5Model and MT5Model in OneFlow, equivalent in effect to HuggingFace's [T5](https://huggingface.co/docs/transformers/v4.19.4/en/model_doc/t5#overview) and [T5v1.1](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511).
## Introduction
The T5 and MT5 pretraining project supports 3D parallelism and [ZeRO](https://arxiv.org/abs/2202.10435).
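As a rough sketch of how an 8-GPU 3D-parallel run with ZeRO could be configured: the `train.dist` and `train.zero_optimization` attribute names below follow LiBai's usual config layout and are assumptions, not settings confirmed by this project.

```python
from libai.config import LazyConfig

# Sketch: 2-way data x 2-way tensor x 2-way pipeline parallelism on 8 GPUs,
# plus ZeRO stage-1 optimizer-state sharding. Attribute names are assumptions
# about the config layout and may differ between LiBai versions.
cfg = LazyConfig.load("projects/MT5/configs/mt5_pretrain.py")
cfg.train.dist.data_parallel_size = 2
cfg.train.dist.tensor_parallel_size = 2
cfg.train.dist.pipeline_parallel_size = 2
cfg.train.zero_optimization.enabled = True
cfg.train.zero_optimization.stage = 1
```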
## Training MT5
Train MT5 on 8 GPUs with 3D parallelism and ZeRO.
### 1. Prepare your training config file
> Set the pretraining parameters in `MT5/configs/mt5_pretrain.py`, such as `vocab_file` and `data_prefix`.
> If you would like to use the T5 model, set `model_type="t5"`.
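For reference, the data-related settings in `mt5_pretrain.py` usually take the shape below; the paths point at the demo data from step 2, and the exact variable names are assumptions that may differ in your config.

```python
# Sketch of the settings to edit in projects/MT5/configs/mt5_pretrain.py.
# Paths assume the demo data downloaded in step 2; replace them with your own files.
vocab_file = "./data_test/bert_data/bert-base-chinese-vocab.txt"
data_prefix = "./data_test/bert_data/loss_compara_content_sentence"

# To train the plain T5 variant instead of mT5 (the exact attribute path depends
# on how the model is built in your config):
# model.cfg.model_type = "t5"
```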
### 2. Prepare the demo training data
Prepare the demo training data by running:
```bash
# path/to/libai
wget https://oneflow-static.oss-cn-beijing.aliyuncs.com/ci-files/dataset/libai/bert_dataset/bert-base-chinese-vocab.txt -P ./data_test/bert_data/
wget https://oneflow-static.oss-cn-beijing.aliyuncs.com/ci-files/dataset/libai/bert_dataset/loss_compara_content_sentence.bin -P ./data_test/bert_data/
wget https://oneflow-static.oss-cn-beijing.aliyuncs.com/ci-files/dataset/libai/bert_dataset/loss_compara_content_sentence.idx -P ./data_test/bert_data/
```
### 3. Prepare your own training data
If you want to use your own training data, skip step 2 and refer to [Preprocessing Dataset](https://libai.readthedocs.io/en/latest/tutorials/basics/Preprocessing_Dataset.html#), for example:
```bash
IMPL=mmap
KEYS=text
python tools/preprocess_data.py \
--input /path/to/libai/projects/MT5/data/test.json \
--json-keys ${KEYS} \
--vocab-file /path/to/libai/projects/MT5/data/vocab.txt \
--dataset-impl ${IMPL} \
--tokenizer-name BertTokenizer \
--do-lower-case \
--do-chinese-wwm \
--split-sentences \
--output-prefix magic_prompt_${IMPL} \
--workers 4 \
--log-interval 2
```
### 4. Run the following code to start training
```bash
# cd /path/to/libai
bash tools/train.sh projects/MT5/train_net.py projects/MT5/configs/mt5_pretrain.py 8
```
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import random
import sys
import numpy as np
import oneflow as flow
from libai.config import LazyConfig, default_argument_parser, try_get_key
from libai.engine import DefaultTrainer, default_setup
from libai.utils.checkpoint import Checkpointer
from libai.utils.events import JSONWriter, TensorboardXWriter
from projects.MT5.utils.mt5_metrc_printer import MT5MetricPrinter
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
logger = logging.getLogger("libai." + __name__)
class Mt5Trainer(DefaultTrainer):
def __init__(self, cfg):
super().__init__(cfg)
def build_writers(self):
"""
Build a list of writers to be used. By default it contains
writers that write metrics to the screen,
a json file, and a tensorboard event file respectively.
If you'd like a different list of writers, you can override it in
your trainer.
Returns:
list[EventWriter]: a list of :class:`EventWriter` objects.
It is now implemented by:
.. code-block:: python
return [
MT5MetricPrinter(self.global_batch_size, self.max_iter),
JSONWriter(os.path.join(self.cfg.train.output_dir, "metrics.json")),
TensorboardXWriter(self.cfg.train.output_dir),
]
"""
# Assume the default print/log frequency.
return [
# It may not always print what you want to see, since it prints "common" metrics only.
MT5MetricPrinter(self.global_batch_size, self.max_iter, self.cfg.train.log_period),
JSONWriter(os.path.join(self.cfg.train.output_dir, "metrics.json")),
TensorboardXWriter(self.cfg.train.output_dir),
]
def main(args):
cfg = LazyConfig.load(args.config_file)
cfg = LazyConfig.apply_overrides(cfg, args.opts)
default_setup(cfg, args)
seed_for_rank = cfg.train.seed + flow.env.get_rank()
flow.manual_seed(seed_for_rank)
flow.cuda.manual_seed(seed_for_rank)
np.random.seed(seed_for_rank)
random.seed(seed_for_rank)
if args.fast_dev_run:
cfg.train.train_epoch = 0
cfg.train.train_iter = 20
cfg.train.evaluation.eval_period = 10
cfg.train.log_period = 1
if args.eval_only:
tokenizer = None
if try_get_key(cfg, "tokenization") is not None:
tokenizer = Mt5Trainer.build_tokenizer(cfg)
model = Mt5Trainer.build_model(cfg)
Checkpointer(model, save_dir=cfg.train.output_dir).resume_or_load(
cfg.train.load_weight, resume=args.resume
)
if try_get_key(cfg, "train.graph.enabled", default=False):
model = Mt5Trainer.build_graph(cfg, model, is_train=False)
test_loader = Mt5Trainer.build_test_loader(cfg, tokenizer)
if len(test_loader) == 0:
logger.info("No dataset in dataloader.test, please set dataset for dataloader.test")
_ = Mt5Trainer.test(cfg, test_loader, model)
return
trainer = Mt5Trainer(cfg)
return trainer.train()
if __name__ == "__main__":
args = default_argument_parser().parse_args()
main(args)
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import oneflow as flow
from libai.models.utils import ModelLoaderHuggerFace, ModelLoaderLiBai
class T5LoaderHuggerFace(ModelLoaderHuggerFace):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
"""NOTE: base_model_prefix_1 is T5's prefix in Transformers.
base_model_prefix_2 is T5's prefix in LiBai."""
self.base_model_prefix_1 = "transformer"
self.base_model_prefix_2 = "mt5_model"
def _convert_state_dict(self, flow_state_dict, cfg):
"""Convert state_dict's keys to match model.
Args:
flow_state_dict (OrderedDict): model state dict.
cfg (dict): model's default config dict in LiBai.
Returns:
OrderedDict: flow state dict.
"""
# The converted checkpoint.
oneflow_state_dict = flow_state_dict.copy()
old_keys = list(oneflow_state_dict.keys())
# Get configs
num_heads = cfg.get("num_attention_heads")
hidden_size = cfg.get("hidden_size")
head_size = cfg.get("head_size", None)
if head_size is None:
head_size = int(hidden_size / num_heads)
has_prefix = any(s.startswith(self.base_model_prefix_1) for s in oneflow_state_dict)
prefix1 = self.base_model_prefix_1 + "." if has_prefix else ""
prefix2 = self.base_model_prefix_2 + "." if has_prefix else ""
encoder_decoder_idx = 1 if has_prefix else 0
layer_idx1 = 3 if has_prefix else 2
layer_idx2 = 5 if has_prefix else 4
op_idx = 6 if has_prefix else 5
# Convert T5's Embedding layers.
# NOTE: Transformers' T5 has no position embedding layer.
new_key = prefix2 + "embedding.word_embeddings.weight"
old_keys.remove(prefix1 + "shared.weight")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(prefix1 + "shared.weight")
# Convert T5's final_layer_norm
new_key = prefix2 + "encoder.final_layernorm.weight"
old_keys.remove(prefix1 + "encoder.final_layer_norm.weight")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(
prefix1 + "encoder.final_layer_norm.weight"
)
new_key = prefix2 + "decoder.final_layernorm.weight"
old_keys.remove(prefix1 + "decoder.final_layer_norm.weight")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(
prefix1 + "decoder.final_layer_norm.weight"
)
# Convert MT5's lm_head
if cfg.model_type == "mt5" and "lm_head.weight" in oneflow_state_dict:
new_key = prefix2 + "lm_head.weight"
old_keys.remove("lm_head.weight")
oneflow_state_dict[new_key] = oneflow_state_dict.pop("lm_head.weight")
# NOTE: The layers have no bias in Transformers' T5.
for key in old_keys:
keys = key.split(".")
if layer_idx1 > len(keys) or layer_idx2 > len(keys):
continue
layer1 = keys[layer_idx1]
layer2 = keys[layer_idx2]
op_name = keys[op_idx]
if keys[op_idx + 1] == "relative_attention_bias" and keys[op_idx] == "SelfAttention":
new_key = (
prefix2
+ keys[encoder_decoder_idx]
+ ".layers.0.self_attention.relative_attention_bias.weight"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
# Convert T5's Encoder layers.
if keys[encoder_decoder_idx] == "encoder":
if op_name == "SelfAttention":
new_key = (
prefix2
+ "encoder.layers."
+ layer1
+ ".self_attention.query_key_value.weight"
)
if new_key in oneflow_state_dict.keys():
continue
q_w = ".".join(keys[: op_idx + 1]) + ".q." + "weight"
k_w = ".".join(keys[: op_idx + 1]) + ".k." + "weight"
v_w = ".".join(keys[: op_idx + 1]) + ".v." + "weight"
qkv_w = flow.cat(
(
oneflow_state_dict.pop(q_w),
oneflow_state_dict.pop(k_w),
oneflow_state_dict.pop(v_w),
),
dim=0,
)
qkv_w = self._fix_qkv_ordering(qkv_w, head_size, num_heads, hidden_size)
oneflow_state_dict[new_key] = qkv_w
o_w = ".".join(keys[: op_idx + 1]) + ".o." + "weight"
new_key = prefix2 + "encoder.layers." + layer1 + ".self_attention.dense.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(o_w)
elif op_name == "layer_norm":
if layer2 == "0":
new_key = prefix2 + "encoder.layers." + layer1 + ".input_layernorm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif layer2 == "1":
new_key = (
prefix2
+ "encoder.layers."
+ layer1
+ ".post_attention_layernorm.weight"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif op_name == "DenseReluDense":
if cfg.get("model_type") == "t5":
if keys[op_idx + 1] == "wi":
new_key = (
prefix2 + "encoder.layers." + layer1 + ".mlp.dense_h_to_4h.weight"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif keys[op_idx + 1] == "wo":
new_key = (
prefix2 + "encoder.layers." + layer1 + ".mlp.dense_4h_to_h.weight"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif cfg.get("model_type") == "mt5":
if keys[op_idx + 1] == "wi_0":
new_key = prefix2 + "encoder.layers." + layer1 + ".mlp.wi_0.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif keys[op_idx + 1] == "wi_1":
new_key = prefix2 + "encoder.layers." + layer1 + ".mlp.wi_1.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif keys[op_idx + 1] == "wo":
new_key = prefix2 + "encoder.layers." + layer1 + ".mlp.wo.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
# Convert T5's decoder Layers.
elif keys[encoder_decoder_idx] == "decoder":
if op_name == "SelfAttention":
new_key = (
prefix2
+ "decoder.layers."
+ layer1
+ ".self_attention.query_key_value.weight"
)
if new_key in oneflow_state_dict.keys():
continue
q_w = ".".join(keys[: op_idx + 1]) + ".q." + "weight"
k_w = ".".join(keys[: op_idx + 1]) + ".k." + "weight"
v_w = ".".join(keys[: op_idx + 1]) + ".v." + "weight"
qkv_w = flow.cat(
(
oneflow_state_dict.pop(q_w),
oneflow_state_dict.pop(k_w),
oneflow_state_dict.pop(v_w),
),
dim=0,
)
qkv_w = self._fix_qkv_ordering(qkv_w, head_size, num_heads, hidden_size)
oneflow_state_dict[new_key] = qkv_w
o_w = ".".join(keys[: op_idx + 1]) + ".o." + "weight"
new_key = prefix2 + "decoder.layers." + layer1 + ".self_attention.dense.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(o_w)
elif op_name == "layer_norm":
if layer2 == "0":
new_key = prefix2 + "decoder.layers." + layer1 + ".input_layernorm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif layer2 == "1":
new_key = (
prefix2
+ "decoder.layers."
+ layer1
+ ".post_attention_layernorm.weight"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif layer2 == "2":
new_key = (
prefix2
+ "decoder.layers."
+ layer1
+ ".post_cross_attention_layernorm.weight"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif op_name == "EncDecAttention":
new_key = prefix2 + "decoder.layers." + layer1 + ".cross_attention.query.weight"
if new_key in oneflow_state_dict.keys():
continue
q_w = ".".join(keys[: op_idx + 1]) + ".q." + "weight"
k_w = ".".join(keys[: op_idx + 1]) + ".k." + "weight"
v_w = ".".join(keys[: op_idx + 1]) + ".v." + "weight"
q_w = oneflow_state_dict.pop(q_w)
kv_w = flow.cat(
(
oneflow_state_dict.pop(k_w),
oneflow_state_dict.pop(v_w),
),
dim=0,
)
q_w = self._fix_qkv_ordering(q_w, head_size, num_heads, hidden_size)
kv_w = self._fix_qkv_ordering(kv_w, head_size, num_heads, hidden_size)
oneflow_state_dict[new_key] = q_w
new_key = (
prefix2 + "decoder.layers." + layer1 + ".cross_attention.key_value.weight"
)
oneflow_state_dict[new_key] = kv_w
o_w = ".".join(keys[: op_idx + 1]) + ".o." + "weight"
new_key = prefix2 + "decoder.layers." + layer1 + ".cross_attention.dense.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(o_w)
elif op_name == "DenseReluDense":
if cfg.get("model_type") == "t5":
if keys[op_idx + 1] == "wi":
new_key = (
prefix2 + "decoder.layers." + layer1 + ".mlp.dense_h_to_4h.weight"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif keys[op_idx + 1] == "wo":
new_key = (
prefix2 + "decoder.layers." + layer1 + ".mlp.dense_4h_to_h.weight"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif cfg.get("model_type") == "mt5":
if keys[op_idx + 1] == "wi_0":
new_key = prefix2 + "decoder.layers." + layer1 + ".mlp.wi_0.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif keys[op_idx + 1] == "wi_1":
new_key = prefix2 + "decoder.layers." + layer1 + ".mlp.wi_1.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif keys[op_idx + 1] == "wo":
new_key = prefix2 + "decoder.layers." + layer1 + ".mlp.wo.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
return oneflow_state_dict
def _load_config_from_json(self, config_file):
"""load config from `config.json`, and update default config.
Args:
config_file (str): Path of config file.
"""
with open(config_file, mode="r", encoding="utf-8") as f:
cfg_dict = json.load(f)
self._update_cfg("vocab_size", cfg_dict["vocab_size"])
self._update_cfg("hidden_size", cfg_dict["d_model"])
self._update_cfg("hidden_layers", cfg_dict["num_layers"])
self._update_cfg("num_attention_heads", cfg_dict["num_heads"])
self._update_cfg("intermediate_size", cfg_dict["d_ff"])
self._update_cfg("hidden_dropout_prob", cfg_dict["dropout_rate"])
self._update_cfg("attention_probs_dropout_prob", cfg_dict["dropout_rate"])
self._update_cfg(
"relative_attention_num_buckets", cfg_dict["relative_attention_num_buckets"]
)
self._update_cfg("embedding_dropout_prob", cfg_dict["dropout_rate"])
self._update_cfg("initializer_range", cfg_dict["initializer_factor"])
self._update_cfg("layernorm_eps", cfg_dict["layer_norm_epsilon"])
self._update_cfg("head_size", cfg_dict["d_kv"])
if "tie_word_embeddings" in self.libai_cfg:
self._update_cfg("tie_word_embeddings", cfg_dict.get("tie_word_embeddings", True))
# update libai_cfg by kwargs
for k, v in self.kwargs.items():
self._update_cfg(k, v)
self._update_cfg_log()
class T5LoaderLibai(ModelLoaderLiBai):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
self.base_model_prefix_2 = "mt5_model"