Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
de713fa9
Commit
de713fa9
authored
Jun 20, 2019
by
thomwolf
Browse files
starting
parent
c304593d
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
629 additions
and
2 deletions
+629
-2
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
...ch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
+62
-0
pytorch_pretrained_bert/modeling.py
pytorch_pretrained_bert/modeling.py
+1
-1
pytorch_pretrained_bert/modeling_transfo_xl.py
pytorch_pretrained_bert/modeling_transfo_xl.py
+1
-1
pytorch_pretrained_bert/modeling_xlnet.py
pytorch_pretrained_bert/modeling_xlnet.py
+565
-0
pytorch_pretrained_bert/tokenization_xlnet.py
pytorch_pretrained_bert/tokenization_xlnet.py
+0
-0
No files found.
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
0 → 100755
View file @
de713fa9
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert XLNet checkpoint."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
argparse
import
torch
from
pytorch_pretrained_bert.modeling_xlnet
import
XLNetConfig
,
XLNetRunConfig
,
XLNetModel
,
load_tf_weights_in_xlnet
def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    """Build an XLNet model, load TensorFlow weights into it and save its state dict.

    Args:
        tf_checkpoint_path: Checkpoint location forwarded to
            ``load_tf_weights_in_xlnet``.
        bert_config_file: JSON configuration file read through
            ``XLNetConfig.from_json_file`` (despite the ``bert_`` prefix, it is
            used as the XLNet configuration here).
        pytorch_dump_path: Destination passed to ``torch.save`` for the
            model's ``state_dict()``.
    """
    # Instantiate the PyTorch model from the JSON configuration.
    xlnet_config = XLNetConfig.from_json_file(bert_config_file)
    build_message = "Building PyTorch model from configuration: {}".format(str(xlnet_config))
    print(build_message)
    xlnet_model = XLNetModel(xlnet_config)

    # Hand the model to the TF loader; its return value is not used, so it is
    # presumably expected to update the model in place.
    load_tf_weights_in_xlnet(xlnet_model, tf_checkpoint_path)

    # Only the state dict is written out, not the model object itself.
    save_message = "Save PyTorch model to {}".format(pytorch_dump_path)
    print(save_message)
    weights = xlnet_model.state_dict()
    torch.save(weights, pytorch_dump_path)
if __name__ == "__main__":
    # Command-line entry point: collect the three required paths and run the
    # conversion. Nothing here executes when the module is imported.
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--tf_checkpoint_path",
                        default=None,
                        type=str,
                        required=True,
                        # Help text fixed: the original read
                        # "Path the TensorFlow checkpoint path."
                        help="Path to the TensorFlow checkpoint.")
    parser.add_argument("--xlnet_config_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The config json file corresponding to the pre-trained XLNet model.\n"
                             "This specifies the model architecture.")
    parser.add_argument("--pytorch_dump_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the output PyTorch model.")
    args = parser.parse_args()
    # The config path is passed positionally into the function's
    # `bert_config_file` parameter.
    convert_xlnet_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                        args.xlnet_config_file,
                                        args.pytorch_dump_path)
pytorch_pretrained_bert/modeling.py
View file @
de713fa9
...
@@ -718,7 +718,7 @@ class BertPreTrainedModel(nn.Module):
...
@@ -718,7 +718,7 @@ class BertPreTrainedModel(nn.Module):
state_dict
=
torch
.
load
(
resolved_archive_file
,
map_location
=
'cpu'
)
state_dict
=
torch
.
load
(
resolved_archive_file
,
map_location
=
'cpu'
)
if
from_tf
:
if
from_tf
:
# Directly load from a TensorFlow checkpoint
# Directly load from a TensorFlow checkpoint
return
load_tf_weights_in_bert
(
model
,
weights_path
)
return
load_tf_weights_in_bert
(
model
,
resolved_archive_file
)
# Load from a PyTorch state_dict
# Load from a PyTorch state_dict
old_keys
=
[]
old_keys
=
[]
new_keys
=
[]
new_keys
=
[]
...
...
pytorch_pretrained_bert/modeling_transfo_xl.py
View file @
de713fa9
...
@@ -236,7 +236,7 @@ class TransfoXLConfig(object):
...
@@ -236,7 +236,7 @@ class TransfoXLConfig(object):
dropout: The dropout probability for all fully connected
dropout: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
layers in the embeddings, encoder, and pooler.
dropatt: The dropout ratio for the attention probabilities.
dropatt: The dropout ratio for the attention probabilities.
untie_r: untie relative position biases
untie_r: untie relative position biases
embd_pdrop: The dropout ratio for the embeddings.
embd_pdrop: The dropout ratio for the embeddings.
init: parameter initializer to use
init: parameter initializer to use
init_range: parameters initialized by U(-init_range, init_range).
init_range: parameters initialized by U(-init_range, init_range).
...
...
pytorch_pretrained_bert/modeling_xlnet.py
0 → 100644
View file @
de713fa9
This diff is collapsed.
Click to expand it.
pytorch_pretrained_bert/tokenization_xlnet.py
0 → 100644
View file @
de713fa9
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment