Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
fafd4c86
Commit
fafd4c86
authored
Dec 11, 2019
by
thomwolf
Browse files
fix TF 2.0 version of T5 - update conversion script
parent
67a8be8e
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
65 additions
and
31 deletions
+65
-31
transformers/convert_pytorch_checkpoint_to_tf2.py
transformers/convert_pytorch_checkpoint_to_tf2.py
+4
-7
transformers/file_utils.py
transformers/file_utils.py
+3
-0
transformers/modeling_t5.py
transformers/modeling_t5.py
+18
-3
transformers/modeling_tf_t5.py
transformers/modeling_tf_t5.py
+28
-15
transformers/modeling_tf_utils.py
transformers/modeling_tf_utils.py
+2
-4
transformers/modeling_utils.py
transformers/modeling_utils.py
+10
-2
No files found.
transformers/convert_pytorch_checkpoint_to_tf2.py
View file @
fafd4c86
...
@@ -120,24 +120,21 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
...
@@ -120,24 +120,21 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
tf_model
=
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
)
tf_model
=
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
)
if
compare_with_pt_model
:
if
compare_with_pt_model
:
inputs_list
=
[[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]]
tfo
=
tf_model
(
tf_model
.
dummy_inputs
,
training
=
False
)
# build the network
tf_inputs
=
tf_model
.
dummy_inputs
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# build the network
state_dict
=
torch
.
load
(
pytorch_checkpoint_path
,
map_location
=
'cpu'
)
state_dict
=
torch
.
load
(
pytorch_checkpoint_path
,
map_location
=
'cpu'
)
pt_model
=
pt_model_class
.
from_pretrained
(
pretrained_model_name_or_path
=
None
,
pt_model
=
pt_model_class
.
from_pretrained
(
pretrained_model_name_or_path
=
None
,
config
=
config
,
config
=
config
,
state_dict
=
state_dict
)
state_dict
=
state_dict
)
pt_inputs
=
torch
.
tensor
(
inputs_list
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
pto
=
pt_model
(
pt
_inputs
)
pto
=
pt_model
(
**
pt_model
.
dummy
_inputs
)
np_pt
=
pto
[
0
].
detach
().
numpy
()
np_pt
=
pto
[
0
].
numpy
()
np_tf
=
tfo
[
0
].
numpy
()
np_tf
=
tfo
[
0
].
numpy
()
diff
=
np
.
amax
(
np
.
abs
(
np_pt
-
np_tf
))
diff
=
np
.
amax
(
np
.
abs
(
np_pt
-
np_tf
))
print
(
"Max absolute difference between models outputs {}"
.
format
(
diff
))
print
(
"Max absolute difference between models outputs {}"
.
format
(
diff
))
assert
diff
<=
2e-2
,
"Error, model absolute difference is >2e-2
"
assert
diff
<=
2e-2
,
"Error, model absolute difference is >2e-2
: {}"
.
format
(
diff
)
# Save pytorch-model
# Save pytorch-model
print
(
"Save TensorFlow model to {}"
.
format
(
tf_dump_path
))
print
(
"Save TensorFlow model to {}"
.
format
(
tf_dump_path
))
...
...
transformers/file_utils.py
View file @
fafd4c86
...
@@ -73,6 +73,9 @@ TF2_WEIGHTS_NAME = 'tf_model.h5'
...
@@ -73,6 +73,9 @@ TF2_WEIGHTS_NAME = 'tf_model.h5'
TF_WEIGHTS_NAME
=
'model.ckpt'
TF_WEIGHTS_NAME
=
'model.ckpt'
CONFIG_NAME
=
"config.json"
CONFIG_NAME
=
"config.json"
DUMMY_INPUTS
=
[[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]]
DUMMY_MASK
=
[[
1
,
1
,
1
,
1
,
1
],
[
1
,
1
,
1
,
0
,
0
],
[
0
,
0
,
0
,
1
,
1
]]
def
is_torch_available
():
def
is_torch_available
():
return
_torch_available
return
_torch_available
...
...
transformers/modeling_t5.py
View file @
fafd4c86
...
@@ -32,7 +32,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
...
@@ -32,7 +32,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
from
.modeling_utils
import
PreTrainedModel
from
.modeling_utils
import
PreTrainedModel
from
.configuration_t5
import
T5Config
from
.configuration_t5
import
T5Config
from
.file_utils
import
add_start_docstrings
from
.file_utils
import
add_start_docstrings
,
DUMMY_INPUTS
,
DUMMY_MASK
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -451,6 +451,15 @@ class T5PreTrainedModel(PreTrainedModel):
...
@@ -451,6 +451,15 @@ class T5PreTrainedModel(PreTrainedModel):
load_tf_weights
=
load_tf_weights_in_t5
load_tf_weights
=
load_tf_weights_in_t5
base_model_prefix
=
"transformer"
base_model_prefix
=
"transformer"
@
property
def
dummy_inputs
(
self
):
input_ids
=
torch
.
tensor
(
DUMMY_INPUTS
)
input_mask
=
torch
.
tensor
(
DUMMY_MASK
)
dummy_inputs
=
{
'decoder_input_ids'
:
input_ids
,
'encoder_input_ids'
:
input_ids
,
'decoder_attention_mask'
:
input_mask
}
return
dummy_inputs
def
_init_weights
(
self
,
module
):
def
_init_weights
(
self
,
module
):
""" Initialize the weights """
""" Initialize the weights """
factor
=
self
.
config
.
initializer_factor
# Used for testing weights initialization
factor
=
self
.
config
.
initializer_factor
# Used for testing weights initialization
...
@@ -534,9 +543,10 @@ class T5Stack(T5PreTrainedModel):
...
@@ -534,9 +543,10 @@ class T5Stack(T5PreTrainedModel):
# Since we are adding it to the raw scores before the softmax, this is
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
# effectively the same as removing these entirely.
# T5 has a mask that can compare sequence ids, we simulate this here with this transposi
s
tion
# T5 has a mask that can compare sequence ids, we
can
simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
extended_attention_mask
=
(
extended_attention_mask
==
extended_attention_mask
.
transpose
(
-
1
,
-
2
))
# extended_attention_mask = (extended_attention_mask == extended_attention_mask.transpose(-1, -2))
extended_attention_mask
=
extended_attention_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# fp16 compatibility
extended_attention_mask
=
extended_attention_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# fp16 compatibility
extended_attention_mask
=
(
1.0
-
extended_attention_mask
)
*
-
1e9
extended_attention_mask
=
(
1.0
-
extended_attention_mask
)
*
-
1e9
...
@@ -548,6 +558,10 @@ class T5Stack(T5PreTrainedModel):
...
@@ -548,6 +558,10 @@ class T5Stack(T5PreTrainedModel):
if
encoder_attention_mask
.
dim
()
==
2
:
if
encoder_attention_mask
.
dim
()
==
2
:
encoder_extended_attention_mask
=
encoder_attention_mask
[:,
None
,
None
,
:]
encoder_extended_attention_mask
=
encoder_attention_mask
[:,
None
,
None
,
:]
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
# encoder_extended_attention_mask = (encoder_extended_attention_mask == encoder_extended_attention_mask.transpose(-1, -2))
encoder_extended_attention_mask
=
encoder_extended_attention_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# fp16 compatibility
encoder_extended_attention_mask
=
encoder_extended_attention_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# fp16 compatibility
encoder_extended_attention_mask
=
(
1.0
-
encoder_extended_attention_mask
)
*
-
1e9
encoder_extended_attention_mask
=
(
1.0
-
encoder_extended_attention_mask
)
*
-
1e9
else
:
else
:
...
@@ -590,6 +604,7 @@ class T5Stack(T5PreTrainedModel):
...
@@ -590,6 +604,7 @@ class T5Stack(T5PreTrainedModel):
hidden_states
=
layer_outputs
[
0
]
hidden_states
=
layer_outputs
[
0
]
if
i
==
0
:
if
i
==
0
:
# We share the position biases between the layers - the first layer store them
# We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
position_bias
=
layer_outputs
[
2
if
self
.
output_attentions
else
1
]
position_bias
=
layer_outputs
[
2
if
self
.
output_attentions
else
1
]
if
self
.
is_decoder
:
if
self
.
is_decoder
:
encoder_decoder_position_bias
=
layer_outputs
[
4
if
self
.
output_attentions
else
2
]
encoder_decoder_position_bias
=
layer_outputs
[
4
if
self
.
output_attentions
else
2
]
...
...
transformers/modeling_tf_t5.py
View file @
fafd4c86
...
@@ -26,7 +26,7 @@ import tensorflow as tf
...
@@ -26,7 +26,7 @@ import tensorflow as tf
from
.configuration_t5
import
T5Config
from
.configuration_t5
import
T5Config
from
.modeling_tf_utils
import
TFPreTrainedModel
,
TFSharedEmbeddings
,
shape_list
from
.modeling_tf_utils
import
TFPreTrainedModel
,
TFSharedEmbeddings
,
shape_list
from
.file_utils
import
add_start_docstrings
from
.file_utils
import
add_start_docstrings
,
DUMMY_INPUTS
,
DUMMY_MASK
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -61,7 +61,7 @@ class TFT5LayerNorm(tf.keras.layers.Layer):
...
@@ -61,7 +61,7 @@ class TFT5LayerNorm(tf.keras.layers.Layer):
super
(
TFT5LayerNorm
,
self
).
build
(
input_shape
)
super
(
TFT5LayerNorm
,
self
).
build
(
input_shape
)
def
call
(
self
,
x
):
def
call
(
self
,
x
):
variance
=
tf
.
math
.
reduce_m
i
n
(
tf
.
math
.
square
(
x
),
axis
=-
1
,
keepdims
=
True
)
variance
=
tf
.
math
.
reduce_m
ea
n
(
tf
.
math
.
square
(
x
),
axis
=-
1
,
keepdims
=
True
)
x
=
x
*
tf
.
math
.
rsqrt
(
variance
+
self
.
variance_epsilon
)
x
=
x
*
tf
.
math
.
rsqrt
(
variance
+
self
.
variance_epsilon
)
return
self
.
weight
*
x
return
self
.
weight
*
x
...
@@ -231,19 +231,19 @@ class TFT5Attention(tf.keras.layers.Layer):
...
@@ -231,19 +231,19 @@ class TFT5Attention(tf.keras.layers.Layer):
cache
[
self
.
layer_id
]
=
(
k
,
v
)
cache
[
self
.
layer_id
]
=
(
k
,
v
)
# q = q / math.sqrt(dim_per_head) # No scaling in T5
# q = q / math.sqrt(dim_per_head) # No scaling in T5
scores
=
tf
.
matmul
(
q
,
k
,
transpose_b
=
True
)
# (bs, n_heads, qlen, klen)
# scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, qlen, klen)
scores
=
tf
.
einsum
(
'bnqd,bnkd->bnqk'
,
q
,
k
)
# (bs, n_heads, qlen, klen)
if
position_bias
is
None
:
if
position_bias
is
None
:
if
not
self
.
has_relative_attention_bias
:
if
not
self
.
has_relative_attention_bias
:
raise
ValueError
(
"No position_bias provided and no weights to compute position_bias"
)
raise
ValueError
(
"No position_bias provided and no weights to compute position_bias"
)
position_bias
=
self
.
compute_bias
(
qlen
,
klen
)
position_bias
=
self
.
compute_bias
(
qlen
,
klen
)
scores
+=
position_bias
if
mask
is
not
None
:
position_bias
=
position_bias
+
mask
if
mask
is
not
None
:
# mask = (mask == 0).expand_as(scores) # (bs, n_heads, qlen, klen)
scores
+=
mask
# scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen)
# mask = (mask == 0).expand_as(scores) # (bs, n_heads, qlen, klen)
# scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen)
scores
+=
position_bias
weights
=
tf
.
nn
.
softmax
(
scores
,
axis
=-
1
)
# (bs, n_heads, qlen, klen)
weights
=
tf
.
nn
.
softmax
(
scores
,
axis
=-
1
)
# (bs, n_heads, qlen, klen)
weights
=
self
.
dropout
(
weights
,
training
=
training
)
# (bs, n_heads, qlen, klen)
weights
=
self
.
dropout
(
weights
,
training
=
training
)
# (bs, n_heads, qlen, klen)
...
@@ -350,11 +350,11 @@ class TFT5Block(tf.keras.layers.Layer):
...
@@ -350,11 +350,11 @@ class TFT5Block(tf.keras.layers.Layer):
head_mask
=
head_mask
,
head_mask
=
head_mask
,
training
=
training
)
training
=
training
)
hidden_states
=
cross_attention_outputs
[
0
]
hidden_states
=
cross_attention_outputs
[
0
]
outputs
=
cross_attention_outputs
[
1
:]
+
outputs
outputs
=
outputs
+
cross_attention_outputs
[
1
:]
hidden_states
=
self
.
layer
[
2
](
hidden_states
,
training
=
training
)
hidden_states
=
self
.
layer
[
2
](
hidden_states
,
training
=
training
)
outputs
=
(
hidden_states
,)
+
outputs
# add attentions if we output them
outputs
=
(
hidden_states
,)
+
outputs
# add attentions if we output them
return
outputs
return
outputs
# hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
####################################################
####################################################
...
@@ -418,7 +418,13 @@ class TFT5MainLayer(tf.keras.layers.Layer):
...
@@ -418,7 +418,13 @@ class TFT5MainLayer(tf.keras.layers.Layer):
# positions we want to attend and -10000.0 for masked positions.
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
# effectively the same as removing these entirely.
extended_attention_mask
=
(
1.0
-
extended_attention_mask
)
*
-
10000.0
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposistion
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
# extended_attention_mask = tf.math.equal(extended_attention_mask,
# tf.transpose(extended_attention_mask, perm=(-1, -2)))
extended_attention_mask
=
(
1.0
-
extended_attention_mask
)
*
-
1e9
if
self
.
is_decoder
:
if
self
.
is_decoder
:
# If a 2D ou 3D attention mask is provided for the cross-attention
# If a 2D ou 3D attention mask is provided for the cross-attention
...
@@ -430,7 +436,12 @@ class TFT5MainLayer(tf.keras.layers.Layer):
...
@@ -430,7 +436,12 @@ class TFT5MainLayer(tf.keras.layers.Layer):
if
num_dims_encoder_attention_mask
==
2
:
if
num_dims_encoder_attention_mask
==
2
:
encoder_extended_attention_mask
=
encoder_attention_mask
[:,
None
,
None
,
:]
encoder_extended_attention_mask
=
encoder_attention_mask
[:,
None
,
None
,
:]
encoder_extended_attention_mask
=
(
1.0
-
encoder_extended_attention_mask
)
*
-
10000.0
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposistion
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
# encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
# tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
encoder_extended_attention_mask
=
(
1.0
-
encoder_extended_attention_mask
)
*
-
1e9
else
:
else
:
encoder_extended_attention_mask
=
None
encoder_extended_attention_mask
=
None
...
@@ -463,6 +474,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
...
@@ -463,6 +474,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
training
=
training
)
training
=
training
)
hidden_states
=
layer_outputs
[
0
]
hidden_states
=
layer_outputs
[
0
]
if
i
==
0
:
if
i
==
0
:
# We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
position_bias
=
layer_outputs
[
2
if
self
.
output_attentions
else
1
]
position_bias
=
layer_outputs
[
2
if
self
.
output_attentions
else
1
]
if
self
.
is_decoder
:
if
self
.
is_decoder
:
encoder_decoder_position_bias
=
layer_outputs
[
4
if
self
.
output_attentions
else
2
]
encoder_decoder_position_bias
=
layer_outputs
[
4
if
self
.
output_attentions
else
2
]
...
@@ -502,8 +515,8 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
...
@@ -502,8 +515,8 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
@
property
@
property
def
dummy_inputs
(
self
):
def
dummy_inputs
(
self
):
input_ids
=
tf
.
constant
(
[[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]]
)
input_ids
=
tf
.
constant
(
DUMMY_INPUTS
)
input_mask
=
tf
.
constant
(
[[
1
,
1
,
0
,
0
,
1
],
[
1
,
1
,
1
,
0
,
0
],
[
1
,
0
,
0
,
1
,
1
]]
)
input_mask
=
tf
.
constant
(
DUMMY_MASK
)
dummy_inputs
=
{
'decoder_input_ids'
:
input_ids
,
dummy_inputs
=
{
'decoder_input_ids'
:
input_ids
,
'encoder_input_ids'
:
input_ids
,
'encoder_input_ids'
:
input_ids
,
'decoder_attention_mask'
:
input_mask
}
'decoder_attention_mask'
:
input_mask
}
...
...
transformers/modeling_tf_utils.py
View file @
fafd4c86
...
@@ -24,13 +24,11 @@ import os
...
@@ -24,13 +24,11 @@ import os
import
tensorflow
as
tf
import
tensorflow
as
tf
from
.configuration_utils
import
PretrainedConfig
from
.configuration_utils
import
PretrainedConfig
from
.file_utils
import
cached_path
,
WEIGHTS_NAME
,
TF_WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
from
.file_utils
import
cached_path
,
WEIGHTS_NAME
,
TF_WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
,
DUMMY_INPUTS
from
.modeling_tf_pytorch_utils
import
load_pytorch_checkpoint_in_tf2_model
from
.modeling_tf_pytorch_utils
import
load_pytorch_checkpoint_in_tf2_model
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
DUMMY_INPUTS
=
[[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]]
class
TFPreTrainedModel
(
tf
.
keras
.
Model
):
class
TFPreTrainedModel
(
tf
.
keras
.
Model
):
r
""" Base class for all TF models.
r
""" Base class for all TF models.
...
@@ -59,7 +57,7 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -59,7 +57,7 @@ class TFPreTrainedModel(tf.keras.Model):
Returns:
Returns:
tf.Tensor with dummy inputs
tf.Tensor with dummy inputs
"""
"""
return
tf
.
constant
(
DUMMY_INPUTS
)
return
{
'input_ids'
:
tf
.
constant
(
DUMMY_INPUTS
)
}
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
(
TFPreTrainedModel
,
self
).
__init__
(
*
inputs
,
**
kwargs
)
super
(
TFPreTrainedModel
,
self
).
__init__
(
*
inputs
,
**
kwargs
)
...
...
transformers/modeling_utils.py
View file @
fafd4c86
...
@@ -31,11 +31,10 @@ from torch.nn import CrossEntropyLoss
...
@@ -31,11 +31,10 @@ from torch.nn import CrossEntropyLoss
from
torch.nn
import
functional
as
F
from
torch.nn
import
functional
as
F
from
.configuration_utils
import
PretrainedConfig
from
.configuration_utils
import
PretrainedConfig
from
.file_utils
import
cached_path
,
WEIGHTS_NAME
,
TF_WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
from
.file_utils
import
cached_path
,
WEIGHTS_NAME
,
TF_WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
,
DUMMY_INPUTS
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
try
:
try
:
from
torch.nn
import
Identity
from
torch.nn
import
Identity
except
ImportError
:
except
ImportError
:
...
@@ -71,6 +70,15 @@ class PreTrainedModel(nn.Module):
...
@@ -71,6 +70,15 @@ class PreTrainedModel(nn.Module):
load_tf_weights
=
lambda
model
,
config
,
path
:
None
load_tf_weights
=
lambda
model
,
config
,
path
:
None
base_model_prefix
=
""
base_model_prefix
=
""
@
property
def
dummy_inputs
(
self
):
""" Dummy inputs to do a forward pass in the network.
Returns:
torch.Tensor with dummy inputs
"""
return
{
'input_ids'
:
torch
.
tensor
(
DUMMY_INPUTS
)}
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
(
PreTrainedModel
,
self
).
__init__
()
super
(
PreTrainedModel
,
self
).
__init__
()
if
not
isinstance
(
config
,
PretrainedConfig
):
if
not
isinstance
(
config
,
PretrainedConfig
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment