Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
260ac7d9
Commit
260ac7d9
authored
Oct 15, 2019
by
Rémi Louf
Browse files
wip commit, switching computers
parent
fe25eefc
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
85 additions
and
21 deletions
+85
-21
examples/run_seq2seq_finetuning.py
examples/run_seq2seq_finetuning.py
+21
-21
examples/run_seq2seq_finetuning_test.py
examples/run_seq2seq_finetuning_test.py
+64
-0
No files found.
examples/run_seq2seq_finetuning.py
View file @
260ac7d9
...
...
@@ -31,7 +31,7 @@ Natural Language Understanding and Generation.” (May 2019) ArXiv:1905.03197
"""
import
argparse
import
deque
ue
from
collections
import
deque
import
logging
import
pickle
import
random
...
...
@@ -57,9 +57,9 @@ class TextDataset(Dataset):
CNN/Daily News:
The CNN/Daily News raw datasets are downloaded from [1]. The
y consist in stories stored
in different files where the summary sentences are indicat
ed by the special `@highlight`
token.
To process the
data, untar both datasets in the same folder, and pass the path to this
The CNN/Daily News raw datasets are downloaded from [1]. The
stories are stored in different files; the summary appears at the end of the story as
sentences that are prefix
ed by the special `@highlight`
line. To process the
data, untar both datasets in the same folder, and pass the path to this
folder as the "data_dir" argument. The formatting code was inspired by [2].
[1] https://cs.nyu.edu/~kcho/
...
...
@@ -69,7 +69,7 @@ class TextDataset(Dataset):
assert
os
.
path
.
isdir
(
data_dir
)
# Load features that have already been computed if present
cached_features_file
=
os
.
path
.
join
(
directory
,
"cached_lm_{}_{}"
.
format
(
block_size
,
data_dir
)
cached_features_file
=
os
.
path
.
join
(
directory
,
"cached_lm_{}_{}"
.
format
(
block_size
,
data_dir
)
)
if
os
.
path
.
exists
(
cached_features_file
):
logger
.
info
(
"Loading features from cached file %s"
,
cached_features_file
)
with
open
(
cached_features_file
,
"rb"
)
as
source
:
...
...
@@ -86,18 +86,19 @@ class TextDataset(Dataset):
# Encode every story file under `path_to_stories` into (story, summary)
# token-id pairs fitted to `block_size`.
stories_files = os.listdir(path_to_stories)
for story_file in stories_files:
    # BUG FIX: the original joined the *string literal* "story_file" instead
    # of the loop variable, so every iteration pointed at the same
    # (nonexistent) path and no story was ever loaded.
    path_to_story = os.path.join(path_to_stories, story_file)
    if not os.path.isfile(path_to_story):
        continue
    with open(path_to_story, encoding="utf-8") as source:
        try:
            raw_story = source.read()
            # process_story raises IndexError on malformed stories
            # (e.g. empty files); skip those.
            story, summary = process_story(raw_story)
        except IndexError:
            continue
    story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
    summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
    # Truncate/skip so that story + summary (+ special tokens) fit the block.
    story_seq, summary_seq = _fit_to_block_size(story, summary, block_size)
    example = tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
    self.examples.append(example)
...
...
@@ -108,22 +109,22 @@ class TextDataset(Dataset):
def __len__(self):
    """Number of encoded (story, summary) examples in the dataset."""
    return len(self.examples)

def __getitem__(self, items):
    """Return the encoded example(s) at index ``items`` as a torch tensor.

    The pre-change signature ``__getitem__(self)`` was missing the index
    parameter required by the Dataset indexing protocol; the corrected
    two-argument form is kept here.
    """
    return torch.tensor(self.examples[items])
def
process_story
(
story
_file
):
def
process_story
(
raw_
story
):
""" Process the text contained in a story file.
Returns the story and the summary
"""
file_lines
=
list
(
filter
(
lambda
x
:
len
(
x
)
!=
0
,
[
line
.
strip
()
for
line
s
in
story
_file
]))
file_lines
=
list
(
filter
(
lambda
x
:
len
(
x
)
!=
0
,
[
line
.
strip
()
for
line
in
raw_
story
.
split
(
"
\n
"
)
]))
# for some unknown reason some lines miss a period, add it
file_lines
=
[
_add_missing_period
(
line
)
for
line
in
file_lines
]
# gather article lines
story_lines
=
[]
lines
=
deque
ue
(
file_lines
)
lines
=
deque
(
file_lines
)
while
True
:
try
:
element
=
lines
.
popleft
()
...
...
@@ -134,7 +135,7 @@ def process_story(story_file):
raise
ie
# gather summary lines
highlights_lines
=
list
(
filter
(
lambda
t
:
!
t
.
startswith
(
"@highlight"
),
lines
))
highlights_lines
=
list
(
filter
(
lambda
t
:
not
t
.
startswith
(
"@highlight"
),
lines
))
# join the lines
story
=
" "
.
join
(
story_lines
)
...
...
@@ -145,7 +146,7 @@ def process_story(story_file):
def _add_missing_period(line):
    """Append a period to ``line`` if it does not already end in punctuation.

    ``@highlight`` marker lines are returned unchanged.
    """
    # NOTE(review): the original list contained u'\u2019' twice; the CNN/DM
    # preprocessing scripts use the closing single quote (u'\u2019') and the
    # closing double quote (u'\u201d'), so the duplicate looked like a typo.
    END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', u'\u2019', u'\u201d', ")"]
    if line.startswith("@highlight"):
        return line
    if line[-1] in END_TOKENS:
        return line
    # Final return was elided in the diff hunk; restored per the standard
    # CNN/DM preprocessing (add the missing period).
    return line + "."
...
...
@@ -172,22 +173,21 @@ def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
return
None
# the source sequence has `[SEP_i]` special tokens with i \in [0,9]. We keep them for now.
if
len
(
src_sequence
)
>
SRC_MAX_LENGTH
if
len
(
src_sequence
)
>
SRC_MAX_LENGTH
:
if
len
(
tgt_sequence
)
>
TGT_MAX_LENGTH
:
src_sequence
=
src_sequence
[:
SRC_MAX_LENGTH
]
tgt_sequence
=
tgt_sequence
[:
TGT_MAX_LENGTH
]
else
:
src_sequence
=
src_sequence
[
block_size
-
len
(
tgt_sequence
)
-
3
]
else
:
if
len
(
tgt_
tokens
)
>
TGT_MAX_LENGTH
:
if
len
(
tgt_
sequence
)
>
TGT_MAX_LENGTH
:
tgt_sequence
=
tgt_sequence
[
block_size
-
len
(
src_sequence
)
-
3
]
return
src_sequence
,
tgt_sequence
def load_and_cache_examples(args, tokenizer):
    """Build the seq2seq fine-tuning dataset rooted at ``args.data_dir``.

    Args:
        args: parsed command-line namespace; only ``args.data_dir`` is read.
        tokenizer: tokenizer used to encode stories and summaries.

    Returns:
        A ``TextDataset`` over the stories found under ``args.data_dir``.
    """
    return TextDataset(tokenizer, file_path=args.data_dir)
...
...
@@ -200,7 +200,7 @@ def main():
parser
=
argparse
.
ArgumentParser
()
# Required parameters
parser
.
add_argument
(
"--
train_
data_
file
"
,
parser
.
add_argument
(
"--data_
dir
"
,
default
=
None
,
type
=
str
,
required
=
True
,
...
...
examples/run_seq2seq_finetuning_test.py
0 → 100644
View file @
260ac7d9
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
from
.run_seq2seq_finetuning
import
process_story
,
_fit_to_block_size
class DataLoaderTest(unittest.TestCase):
    """Unit tests for the seq2seq pre-processing helpers.

    Fixes over the original:
    - the original overrode ``__init__(self, block_size=10)`` without calling
      ``super().__init__``, which breaks the unittest runner; fixture state now
      lives in ``setUp``;
    - no method was prefixed with ``test_``, so the runner collected nothing;
    - ``assertListEqual(a == b)`` passed a single boolean argument (a
      TypeError); assertions now compare the *fitted* sequences against the
      expected lists.
    """

    def setUp(self):
        self.block_size = 10

    def test_source_and_target_too_small(self):
        """When src+tgt are smaller than the block size (minus the number of
        special tokens), the example is skipped (None is returned)."""
        src_seq = [1, 2, 3, 4]
        tgt_seq = [5, 6]
        self.assertIsNone(_fit_to_block_size(src_seq, tgt_seq, self.block_size))

    def test_source_and_target_fit_exactly(self):
        """When src+tgt exactly fill the block (minus the number of special
        tokens), both sequences are returned unchanged."""
        src_seq = [1, 2, 3, 4]
        tgt_seq = [5, 6, 7]
        fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
        self.assertListEqual(fitted_src, src_seq)
        self.assertListEqual(fitted_tgt, tgt_seq)

    def test_source_too_big_target_ok(self):
        """An over-long source is truncated to fit; the target is kept."""
        src_seq = [1, 2, 3, 4, 5, 6]
        tgt_seq = [1, 2]
        fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
        self.assertListEqual(fitted_src, [1, 2, 3, 4, 5])
        self.assertListEqual(fitted_tgt, tgt_seq)

    def test_target_too_big_source_ok(self):
        """An over-long target is truncated to fit; the source is kept."""
        src_seq = [1, 2, 3, 4]
        tgt_seq = [1, 2, 3, 4]
        fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
        self.assertListEqual(fitted_src, src_seq)
        self.assertListEqual(fitted_tgt, [1, 2, 3])

    def test_source_and_target_too_big(self):
        """When both sequences are over-long, both are truncated."""
        src_seq = [1, 2, 3, 4, 5, 6, 7]
        tgt_seq = [1, 2, 3, 4, 5, 6, 7]
        fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
        self.assertListEqual(fitted_src, [1, 2, 3, 4, 5])
        self.assertListEqual(fitted_tgt, [1, 2])
if
__name__
==
"__main__"
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment