chenpangpang / transformers

Commit fa37b4da, authored May 09, 2019 by burcturkoglu

    Merge branch 'master' of https://github.com/huggingface/pytorch-pretrained-BERT

Parents: 5289b4b9, 701bd59b
Showing 8 changed files with 161 additions and 144 deletions (+161, -144)
examples/lm_finetuning/simple_lm_finetuning.py   +29 -28
examples/run_classifier.py                       +30 -29
examples/run_openai_gpt.py                       +14 -13
examples/run_squad.py                            +34 -33
examples/run_swag.py                             +33 -32
hubconf.py                                       +1 -1
pytorch_pretrained_bert/file_utils.py            +12 -3
pytorch_pretrained_bert/modeling.py              +8 -5
examples/lm_finetuning/simple_lm_finetuning.py (view file @ fa37b4da)

@@ -534,6 +534,7 @@ def main():
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    if args.do_train:
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
    ...
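Across the example scripts in this merge, optimizer preparation is moved under if args.do_train: so that evaluation-only runs never build an optimizer. Below is a minimal, runnable sketch of that pattern; the Args class and the nn.Linear placeholder are hypothetical stand-ins for the script's argparse namespace and BERT model, and plain torch.optim.Adam stands in for the BertAdam optimizer the examples actually use. The grouped-parameter lists mirror the diff: biases and LayerNorm weights are excluded from weight decay.

import torch
from torch import nn

# Hypothetical stand-ins for the script's argparse namespace and model.
class Args:
    do_train = True
    learning_rate = 3e-5

args = Args()
model = nn.Linear(768, 2)  # placeholder for the BERT model

if args.do_train:
    # Build the optimizer only when training; eval-only runs skip this block.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=args.learning_rate)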
examples/run_classifier.py (view file @ fa37b4da)

@@ -271,7 +271,7 @@ class StsbProcessor(DataProcessor):
 class QqpProcessor(DataProcessor):
-    """Processor for the STS-B data set (GLUE version)."""
+    """Processor for the QQP data set (GLUE version)."""

     def get_train_examples(self, data_dir):
         """See base class."""

@@ -306,7 +306,7 @@ class QqpProcessor(DataProcessor):
 class QnliProcessor(DataProcessor):
-    """Processor for the STS-B data set (GLUE version)."""
+    """Processor for the QNLI data set (GLUE version)."""

     def get_train_examples(self, data_dir):
         """See base class."""

@@ -763,6 +763,7 @@ def main():
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    if args.do_train:
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
    ...
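The two docstring fixes above correct copy-paste leftovers: QqpProcessor and QnliProcessor previously described themselves as STS-B processors. For orientation, here is a hedged sketch of the processor interface these classes implement; the MyPairProcessor class, its TSV layout, and its labels are hypothetical, while InputExample mirrors the structure used in run_classifier.py.

# A minimal sketch of the processor interface assumed by run_classifier.py.
# "MyPairProcessor" and its TSV layout are hypothetical; real processors
# subclass DataProcessor and return InputExample objects as shown here.
import csv
import os


class InputExample(object):
    def __init__(self, guid, text_a, text_b=None, label=None):
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label


class MyPairProcessor(object):
    """Processor for a hypothetical sentence-pair data set stored as TSV."""

    def get_train_examples(self, data_dir):
        return self._create_examples(os.path.join(data_dir, "train.tsv"), "train")

    def get_dev_examples(self, data_dir):
        return self._create_examples(os.path.join(data_dir, "dev.tsv"), "dev")

    def get_labels(self):
        return ["0", "1"]

    def _create_examples(self, path, set_type):
        examples = []
        with open(path, encoding="utf-8") as f:
            for i, row in enumerate(csv.reader(f, delimiter="\t")):
                examples.append(InputExample(guid="%s-%d" % (set_type, i),
                                             text_a=row[0], text_b=row[1], label=row[2]))
        return examples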
examples/run_openai_gpt.py (view file @ fa37b4da)

@@ -183,6 +183,7 @@ def main():
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # Prepare optimizer
    if args.do_train:
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
    ...
examples/run_squad.py (view file @ fa37b4da)

@@ -922,6 +922,7 @@ def main():
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    if args.do_train:
        param_optimizer = list(model.named_parameters())

        # hack to remove pooler, which is not used
    ...
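run_squad.py and run_swag.py (below) apply the same do_train gating, plus the step flagged by the "hack to remove pooler, which is not used" comment: pooler weights are filtered out of the parameter list so they are never passed to the optimizer. A small runnable sketch of that filtering, using a toy ModuleDict in place of the real BERT model:

from torch import nn

# Toy model with a 'pooler' submodule, standing in for the real BERT model.
model = nn.ModuleDict({
    'encoder': nn.Linear(8, 8),
    'pooler': nn.Linear(8, 8),
})

param_optimizer = list(model.named_parameters())
# "hack to remove pooler, which is not used": drop pooler weights so they
# never reach the optimizer during SQuAD/SWAG fine-tuning.
param_optimizer = [(n, p) for n, p in param_optimizer if 'pooler' not in n]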
examples/run_swag.py (view file @ fa37b4da)

@@ -385,6 +385,7 @@ def main():
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    if args.do_train:
        param_optimizer = list(model.named_parameters())

        # hack to remove pooler, which is not used
    ...
hubconf.py (view file @ fa37b4da)

@@ -84,7 +84,7 @@ def bertTokenizer(*args, **kwargs):
     Example:
         >>> sentence = 'Hello, World!'
-        >>> tokenizer = torch.hub.load('ailzhang/pytorch-pretrained-BERT:hubconf', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False)
+        >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT:hubconf', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False)
         >>> toks = tokenizer.tokenize(sentence)
         ['Hello', '##,', 'World', '##!']
         >>> ids = tokenizer.convert_tokens_to_ids(toks)
    ...
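The hubconf.py change only updates the docstring example to point at the huggingface fork instead of the ailzhang fork, so the documented torch.hub entry point matches the upstream repository. For reference, a usage sketch based directly on that docstring (requires network access to fetch the hub repo and the bert-base-cased vocabulary):

import torch

# Load the tokenizer through torch.hub, as documented in hubconf.py.
tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT:hubconf',
                           'bertTokenizer', 'bert-base-cased',
                           do_basic_tokenize=False)

toks = tokenizer.tokenize('Hello, World!')
ids = tokenizer.convert_tokens_to_ids(toks)
print(toks, ids)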
pytorch_pretrained_bert/file_utils.py (view file @ fa37b4da)

@@ -22,6 +22,15 @@ import requests
 from botocore.exceptions import ClientError
 from tqdm import tqdm

+try:
+    from torch.hub import _get_torch_home
+    torch_cache_home = _get_torch_home()
+except ImportError:
+    torch_cache_home = os.path.expanduser(
+        os.getenv('TORCH_HOME', os.path.join(
+            os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
+default_cache_path = os.path.join(torch_cache_home, 'pytorch_pretrained_bert')
+
 try:
     from urllib.parse import urlparse
 except ImportError:

@@ -29,11 +38,11 @@ except ImportError:
 try:
     from pathlib import Path
-    PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', Path.home() / '.pytorch_pretrained_bert'))
+    PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path))
 except (AttributeError, ImportError):
-    PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
+    PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)

 CONFIG_NAME = "config.json"
 WEIGHTS_NAME = "pytorch_model.bin"
    ...
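The new block in file_utils.py derives the download cache from the Torch cache directory instead of the hard-coded ~/.pytorch_pretrained_bert: it prefers torch.hub._get_torch_home() when available, then falls back to TORCH_HOME, then XDG_CACHE_HOME/torch, then ~/.cache/torch. A standalone sketch of the same resolution order (safe to run even without torch installed, since the import failure is caught):

import os

try:
    # Newer torch versions expose the canonical cache location here.
    from torch.hub import _get_torch_home
    torch_cache_home = _get_torch_home()
except ImportError:
    # Fallback: TORCH_HOME, else $XDG_CACHE_HOME/torch, else ~/.cache/torch.
    torch_cache_home = os.path.expanduser(
        os.getenv('TORCH_HOME',
                  os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))

default_cache_path = os.path.join(torch_cache_home, 'pytorch_pretrained_bert')
print(default_cache_path)  # e.g. ~/.cache/torch/pytorch_pretrained_bert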
pytorch_pretrained_bert/modeling.py (view file @ fa37b4da)

@@ -145,7 +145,8 @@ class BertConfig(object):
                  attention_probs_dropout_prob=0.1,
                  max_position_embeddings=512,
                  type_vocab_size=2,
-                 initializer_range=0.02):
+                 initializer_range=0.02,
+                 layer_norm_eps=1e-12):
         """Constructs BertConfig.

         Args:

@@ -169,6 +170,7 @@ class BertConfig(object):
                 `BertModel`.
             initializer_range: The sttdev of the truncated_normal_initializer for
                 initializing all weight matrices.
+            layer_norm_eps: The epsilon used by LayerNorm.
         """
         if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                         and isinstance(vocab_size_or_config_json_file, unicode)):

@@ -188,6 +190,7 @@ class BertConfig(object):
             self.max_position_embeddings = max_position_embeddings
             self.type_vocab_size = type_vocab_size
             self.initializer_range = initializer_range
+            self.layer_norm_eps = layer_norm_eps
         else:
             raise ValueError("First argument must be either a vocabulary size (int)"
                              "or the path to a pretrained model config file (str)")

@@ -254,7 +257,7 @@ class BertEmbeddings(nn.Module):
         # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
         # any TensorFlow checkpoint file
-        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

     def forward(self, input_ids, token_type_ids=None):

@@ -329,7 +332,7 @@ class BertSelfOutput(nn.Module):
     def __init__(self, config):
         super(BertSelfOutput, self).__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

     def forward(self, hidden_states, input_tensor):

@@ -370,7 +373,7 @@ class BertOutput(nn.Module):
     def __init__(self, config):
         super(BertOutput, self).__init__()
         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
-        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

     def forward(self, hidden_states, input_tensor):

@@ -434,7 +437,7 @@ class BertPredictionHeadTransform(nn.Module):
             self.transform_act_fn = ACT2FN[config.hidden_act]
         else:
             self.transform_act_fn = config.hidden_act
-        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

     def forward(self, hidden_states):
         hidden_states = self.dense(hidden_states)
    ...
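Taken together, the modeling.py hunks make the LayerNorm epsilon configurable: BertConfig gains a layer_norm_eps argument defaulting to 1e-12, and each BertLayerNorm construction reads config.layer_norm_eps instead of the hard-coded 1e-12. A short sketch of how a caller could use the new knob, assuming the pytorch_pretrained_bert package at this commit is importable:

# Sketch, assuming pytorch_pretrained_bert is installed at (or after) this commit.
from pytorch_pretrained_bert.modeling import BertConfig, BertModel

# All other hyperparameters keep their defaults; only the LayerNorm epsilon changes.
config = BertConfig(vocab_size_or_config_json_file=30522, layer_norm_eps=1e-5)
model = BertModel(config)

# Every BertLayerNorm in this model now uses eps=1e-5 instead of the old 1e-12.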