chenpangpang / transformers / Commits

Commit 834b485b authored Nov 04, 2018 by thomwolf

logging + update copyright

Parent: 1701291e

Showing 7 changed files with 38 additions and 36 deletions.
Changed files (7):

  convert_tf_checkpoint_to_pytorch.py   +13   -0
  extract_features.py                    +2   -2
  modeling.py                            +2   -2
  optimization.py                       +16   -0
  run_classifier.py                      +1   -1
  run_squad.py                           +3  -30
  tokenization.py                        +1   -1
convert_tf_checkpoint_to_pytorch.py  (+13 -0)

 # coding=utf-8
+# Copyright 2018 The HugginFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Convert BERT checkpoint."""
 from __future__ import absolute_import
 ...
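Only the license header changes here. For context, reading a TensorFlow checkpoint into PyTorch tensors generally follows the pattern sketched below; this is a minimal illustration under my own assumptions, not the script's actual implementation, and the checkpoint path is a placeholder.

    import tensorflow as tf   # assumed to be installed alongside torch
    import torch

    def load_tf_checkpoint_as_torch(ckpt_path):
        """Read every variable of a TF checkpoint into a dict of torch tensors."""
        tensors = {}
        for name, _shape in tf.train.list_variables(ckpt_path):
            array = tf.train.load_variable(ckpt_path, name)   # returns a NumPy array
            tensors[name] = torch.from_numpy(array)
        return tensors

    # Hypothetical usage; "bert_model.ckpt" is a placeholder path.
    # weights = load_tf_checkpoint_as_torch("bert_model.ckpt")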
extract_features.py  (+2 -2)

 # coding=utf-8
-# Copyright 2018 The Google AI Language Team Authors.
+# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Extract pre-computed feature vectors from BERT."""
+"""Extract pre-computed feature vectors from a PyTorch BERT model."""
 from __future__ import absolute_import
 from __future__ import division
 ...
modeling.py  (+2 -2)

 # coding=utf-8
-# Copyright 2018 The Google AI Language Team Authors.
+# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Common utility functions related to TensorFlow."""
+"""PyTorch BERT model."""
 from __future__ import absolute_import
 from __future__ import division
 ...
optimization.py  (+16 -0)

+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch optimization for BERT model."""
 import math
 import torch
 from torch.optim import Optimizer
 ...
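optimization.py builds a BERT-specific optimizer on top of torch.optim.Optimizer. As a rough, generic illustration of what subclassing Optimizer involves (not the algorithm this file implements), a bare-bones gradient-descent optimizer looks like this:

    import torch
    from torch.optim import Optimizer

    class PlainSGD(Optimizer):
        """Minimal Optimizer subclass: plain gradient descent, no momentum, no schedule."""

        def __init__(self, params, lr=1e-3):
            super(PlainSGD, self).__init__(params, dict(lr=lr))

        def step(self, closure=None):
            loss = closure() if closure is not None else None
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is None:
                        continue
                    # p <- p - lr * grad, updated in place
                    p.data.add_(p.grad, alpha=-group["lr"])
            return loss

    # Usage follows the standard torch pattern:
    # opt = PlainSGD(model.parameters(), lr=0.01); loss.backward(); opt.step(); opt.zero_grad()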
run_classifier.py  (+1 -1)

 # coding=utf-8
-# Copyright 2018 The Google AI Language Team Authors.
+# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
run_squad.py  (+3 -30)

 # coding=utf-8
-# Copyright 2018 The Google AI Language Team Authors.
+# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -720,22 +720,6 @@ def main():
                         help="The maximum length of an answer that can be generated. This is needed because the start "
                              "and end predictions are not conditioned on one another.")
-    ### BEGIN - TO DELETE EVENTUALLY --> NO SENSE IN PYTORCH ###
-    # parser.add_argument("--use_tpu", default=False, action='store_true', help="Whether to use TPU or GPU/CPU.")
-    # parser.add_argument("--tpu_name", default=None, type=str,
-    #                     help="The Cloud TPU to use for training. This should be either the name used when creating the "
-    #                     "Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.")
-    # parser.add_argument("--tpu_zone", default=None, type=str,
-    #                     help="[Optional] GCE zone where the Cloud TPU is located in. If not specified, we will attempt "
-    #                     "to automatically detect the GCE project from metadata.")
-    # parser.add_argument("--gcp_project", default=None, type=str,
-    #                     help="[Optional] Project name for the Cloud TPU-enabled project. If not specified, we will attempt "
-    #                     "to automatically detect the GCE project from metadata.")
-    # parser.add_argument("--master", default=None, type=str, help="[Optional] TensorFlow master URL.")
-    # parser.add_argument("--num_tpu_cores", default=8, type=int, help="Only used if `use_tpu` is True. "
-    #                     "Total number of TPU cores to use.")
-    ### END - TO DELETE EVENTUALLY --> NO SENSE IN PYTORCH ###
     parser.add_argument("--verbose_logging", default=False, action='store_true',
                         help="If true, all of the warnings related to data processing will be printed. "
                              "A number of warnings are expected for a normal SQuAD evaluation.")
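The deleted block is commented-out TensorFlow/TPU argument parsing with no PyTorch equivalent. The surviving arguments lean on two standard Python idioms: argparse's store_true flags and implicit concatenation of adjacent string literals for long help texts. A self-contained sketch using the same --verbose_logging flag:

    import argparse

    parser = argparse.ArgumentParser()
    # Adjacent string literals are joined at compile time, which keeps long help texts readable.
    parser.add_argument("--verbose_logging", default=False, action="store_true",
                        help="If true, all of the warnings related to data processing will be printed. "
                             "A number of warnings are expected for a normal SQuAD evaluation.")
    args = parser.parse_args(["--verbose_logging"])
    print(args.verbose_logging)   # True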
 ...
@@ -871,7 +855,6 @@ def main():
                 loss.backward()
                 optimizer.step()
                 global_step += 1
-                logger.info("Done %s steps", global_step)

     if args.do_predict:
         eval_examples = read_squad_examples(
 ...
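Judging by the hunk header, this change removes a per-step logger.info call from the training loop; logging at INFO level on every optimizer step floods the output, so progress reporting is usually throttled. A minimal, self-contained sketch of the surrounding pattern (model, data, and the reporting interval are stand-ins, not the script's real objects):

    import logging
    import torch

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    model = torch.nn.Linear(4, 1)                        # stand-in for the real model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    data = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(250)]

    global_step = 0
    for inputs, targets in data:                         # stand-in for the train dataloader
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        global_step += 1
        if global_step % 100 == 0:                       # throttle instead of logging every step
            logger.info("Done %s steps", global_step)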
 ...
@@ -892,10 +875,8 @@ def main():
         all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
         all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
         all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
-        #all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
         all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
-        #eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_example_index)
         eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
         if args.local_rank == -1:
             eval_sampler = SequentialSampler(eval_data)
 ...
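The commented-out label tensors go away because SQuAD prediction is keyed by example index rather than by label. As a self-contained illustration of how such a TensorDataset feeds a sequential DataLoader (the token ids below are fabricated):

    import torch
    from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

    # Fabricated stand-ins for eval_features (2 examples, sequence length 6).
    all_input_ids = torch.tensor([[101, 2054, 2003, 102, 0, 0],
                                  [101, 2073, 2001, 2009, 102, 0]], dtype=torch.long)
    all_input_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                                   [1, 1, 1, 1, 1, 0]], dtype=torch.long)
    all_segment_ids = torch.zeros_like(all_input_ids)
    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)

    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
    eval_sampler = SequentialSampler(eval_data)           # deterministic order for evaluation
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=2)

    for input_ids, input_mask, segment_ids, example_index in eval_dataloader:
        print(example_index.tolist())                     # [0, 1]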
 ...
@@ -906,7 +887,6 @@ def main():
         model.eval()
         all_results = []
         logger.info("Start evaluating")
-        #for input_ids, input_mask, segment_ids, label_ids, example_index in eval_dataloader:
         for input_ids, input_mask, segment_ids, example_index in tqdm(eval_dataloader, descr="Evaluating"):
             if len(all_results) % 1000 == 0:
                 logger.info("Processing example: %d" % (len(all_results)))
 ...
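One note on the context line above: tqdm's keyword for the progress-bar label is desc; descr is not a tqdm argument. A minimal evaluation-loop skeleton with gradients disabled, using stand-ins for the model and dataloader rather than the script's real objects:

    import torch
    from tqdm import tqdm

    model = torch.nn.Linear(6, 2)                        # stand-in for the real model
    batches = [torch.randn(4, 6) for _ in range(3)]      # stand-in for eval_dataloader

    model.eval()
    all_results = []
    with torch.no_grad():                                # no gradients needed at eval time
        for batch in tqdm(batches, desc="Evaluating"):   # note: desc, not descr
            logits = model(batch)
            all_results.extend(logits.cpu().numpy())
    print(len(all_results))                              # 3 batches x 4 rows = 12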
 ...
@@ -918,9 +898,7 @@ def main():
             start_logits, end_logits = model(input_ids, segment_ids, input_mask)
             unique_id = [int(eval_features[e.item()].unique_id) for e in example_index]
-            #start_logits = [x.item() for x in start_logits]
             start_logits = [x.view(-1).detach().cpu().numpy() for x in start_logits]
-            #end_logits = [x.item() for x in end_logits]
             end_logits = [x.view(-1).detach().cpu().numpy() for x in end_logits]
             for idx, i in enumerate(unique_id):
                 s = [float(x) for x in start_logits[idx]]
 ...
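The surviving lines replace per-element .item() calls with a single .view(-1).detach().cpu().numpy() per tensor: flatten, cut the autograd graph, move to host memory, and expose the values as a NumPy array. A small self-contained illustration (the logits are random stand-ins):

    import torch

    # Stand-in for one example's start logits (sequence length 8), attached to the autograd graph.
    start_logits = torch.randn(8, requires_grad=True) * 2.0

    # .numpy() on a tensor that requires grad raises an error; detach() first, then move to CPU.
    as_numpy = start_logits.view(-1).detach().cpu().numpy()
    print(as_numpy.shape, [float(x) for x in as_numpy[:3]])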
 ...
@@ -932,11 +910,6 @@ def main():
                                  end_logits=e
                                  ))
-                # all_results.append(
-                #     RawResult(
-                #         unique_id=unique_id,
-                #         start_logits=start_logits,
-                #         end_logits=end_logits))
         output_prediction_file = os.path.join(args.output_dir, "predictions.json")
         output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
 ...
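Downstream of this hunk, each prediction is appended as a RawResult and the outputs are written under args.output_dir. RawResult is presumably a namedtuple, since the field names match the keyword arguments above; a rough sketch of the bookkeeping, with fabricated values and a placeholder output directory:

    import collections
    import os

    RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"])

    all_results = []
    all_results.append(RawResult(unique_id=1000001,
                                 start_logits=[0.1, 0.7, 0.2],
                                 end_logits=[0.05, 0.15, 0.8]))

    output_dir = "/tmp/squad_out"                        # placeholder for args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    output_prediction_file = os.path.join(output_dir, "predictions.json")
    output_nbest_file = os.path.join(output_dir, "nbest_predictions.json")
    print(output_prediction_file, output_nbest_file)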
tokenization.py  (+1 -1)

 # coding=utf-8
-# Copyright 2018 The Google AI Language Team Authors.
+# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...