Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c6207d85
Commit
c6207d85
authored
Nov 04, 2018
by
thomwolf
Browse files
remove old methods
parent
965b2565
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
0 additions
and
31 deletions
+0
-31
run_classifier.py
run_classifier.py
+0
-31
No files found.
run_classifier.py
View file @
c6207d85
...
@@ -305,37 +305,6 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
...
@@ -305,37 +305,6 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
else
:
else
:
tokens_b
.
pop
()
tokens_b
.
pop
()
def
input_fn_builder
(
features
,
seq_length
,
train_batch_size
):
# TODO: delete
"""Creates an `input_fn` closure to be passed to TPUEstimator."""
### ATTENTION - To rewrite ###
all_input_ids
=
[
f
.
input_ids
for
feature
in
features
]
all_input_mask
=
[
f
.
input_mask
for
feature
in
features
]
all_segment_ids
=
[
f
.
segment_ids
for
feature
in
features
]
all_label_ids
=
[
f
.
label_id
for
feature
in
features
]
# for feature in features:
# all_input_ids.append(feature.input_ids)
# all_input_mask.append(feature.input_mask)
# all_segment_ids.append(feature.segment_ids)
# all_label_ids.append(feature.label_id)
input_ids_tensor
=
torch
.
tensor
(
all_input_ids
,
dtype
=
torch
.
Long
)
input_mask_tensor
=
torch
.
tensor
(
all_input_mask
,
dtype
=
torch
.
Long
)
segment_tensor
=
torch
.
tensor
(
all_segment_ids
,
dtype
=
torch
.
Long
)
label_tensor
=
torch
.
tensor
(
all_label_ids
,
dtype
=
torch
.
Long
)
train_data
=
TensorDataset
(
input_ids_tensor
,
input_mask_tensor
,
segment_tensor
,
label_tensor
)
if
args
.
local_rank
==
-
1
:
train_sampler
=
RandomSampler
(
train_data
)
else
:
train_sampler
=
DistributedSampler
(
train_data
)
train_dataloader
=
DataLoader
(
train_data
,
sampler
=
train_sampler
,
batch_size
=
train_batch_size
)
return
train_dataloader
def
accuracy
(
out
,
labels
):
def
accuracy
(
out
,
labels
):
outputs
=
np
.
argmax
(
out
,
axis
=
1
)
outputs
=
np
.
argmax
(
out
,
axis
=
1
)
return
np
.
sum
(
outputs
==
labels
)
return
np
.
sum
(
outputs
==
labels
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment