chenpangpang / transformers · Commits

Commit 260ac7d9
Authored Oct 15, 2019 by Rémi Louf

    wip commit, switching computers

Parent: fe25eefc

Showing 2 changed files with 85 additions and 21 deletions (+85 -21):

    examples/run_seq2seq_finetuning.py        +21 -21
    examples/run_seq2seq_finetuning_test.py   +64 -0
examples/run_seq2seq_finetuning.py

@@ -31,7 +31,7 @@ Natural Language Understanding and Generation.” (May 2019) ArXiv:1905.03197
 """
 import argparse
-import dequeue
+from collections import deque
 import logging
 import pickle
 import random

@@ -57,9 +57,9 @@ class TextDataset(Dataset):
     CNN/Daily News:

-    The CNN/Daily News raw datasets are downloaded from [1]. They consist in stories stored
-    in different files where the summary sentences are indicated by the special `@highlight`
-    token. To process the
+    The CNN/Daily News raw datasets are downloaded from [1]. The stories are stored
+    in different files; the summary appears at the end of the story as sentences that are
+    prefixed by the special `@highlight` line. To process the
     data, untar both datasets in the same folder, and pass the path to this
     folder as the "data_dir argument. The formatting code was inspired by [2].

     [1] https://cs.nyu.edu/~kcho/
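(For orientation: the two archives from [1] typically untar into per-corpus `stories` directories, so the folder passed as `data_dir` would look roughly like the sketch below. This layout is assumed from the public CNN/DailyMail release and is not part of this diff.)

    data_dir/
        cnn/stories/*.story
        dailymail/stories/*.story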
@@ -69,7 +69,7 @@ class TextDataset(Dataset):
         assert os.path.isdir(data_dir)

         # Load features that have already been computed if present
-        cached_features_file = os.path.join(directory, "cached_lm_{}_{}".format(block_size, data_dir)
+        cached_features_file = os.path.join(directory, "cached_lm_{}_{}".format(block_size, data_dir))
         if os.path.exists(cached_features_file):
             logger.info("Loading features from cached file %s", cached_features_file)
             with open(cached_features_file, "rb") as source:
@@ -86,18 +86,19 @@ class TextDataset(Dataset):
             stories_files = os.listdir(path_to_stories)
             for story_file in stories_files:
                 path_to_story = os.path.join(path_to_stories, "story_file")
-                if ! os.path.isfile(path_to_story):
+                if not os.path.isfile(path_to_story):
                     continue

                 with open(path_to_story, encoding="utf-8") as source:
                     try:
-                        story, summary = process_story(source)
+                        raw_story = source.read()
+                        story, summary = process_story(raw_story)
                     except IndexError:
                         continue

                 story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
                 summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
-                story_seq, summary_seq = _fit_to_block_size(story, summary, blocksize)
+                story_seq, summary_seq = _fit_to_block_size(story, summary, block_size)
                 example = tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
                 self.examples.append(example)
@@ -108,22 +109,22 @@ class TextDataset(Dataset):
     def __len__(self):
         return len(self.examples)

-    def __getitem__(self):
+    def __getitem__(self, items):
         return torch.tensor(self.examples[items])


-def process_story(story_file):
+def process_story(raw_story):
     """ Process the text contained in a story file.
     Returns the story and the summary
     """
-    file_lines = list(filter(lambda x: len(x) != 0, [line.strip() for lines in story_file]))
+    file_lines = list(filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]))

     # for some unknown reason some lines miss a period, add it
     file_lines = [_add_missing_period(line) for line in file_lines]

     # gather article lines
     story_lines = []
-    lines = dequeue(file_lines)
+    lines = deque(file_lines)
     while True:
         try:
             element = lines.popleft()
@@ -134,7 +135,7 @@ def process_story(story_file):
             raise ie

     # gather summary lines
-    highlights_lines = list(filter(lambda t: !t.startswith("@highlight"), lines))
+    highlights_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))

     # join the lines
     story = " ".join(story_lines)
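(A minimal sketch of what `process_story` consumes after this change: the raw text of one story file, with `@highlight` lines marking the summary sentences. The sample text and the commented return values below are illustrative assumptions, not output produced by this commit.)

    raw_story = (
        "London was hit by heavy rain on Monday .\n"
        "\n"
        "@highlight\n"
        "\n"
        "Heavy rain hit London\n"
    )
    story, summary = process_story(raw_story)
    # story   -> the article sentences joined with spaces
    # summary -> built from the lines that follow the `@highlight` markers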
@@ -145,7 +146,7 @@ def process_story(story_file):
 def _add_missing_period(line):
     END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', u'\u2019', u'\u2019', ")"]
-    if line == "@highlight":
+    if line.startswith("@highlight"):
         return line
     if line[-1] in END_TOKENS:
         return line
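(In other words, `_add_missing_period` leaves `@highlight` markers and lines that already end in a terminator untouched; a line that falls through both checks presumably gets a period appended in the part of the function not shown in this hunk. A couple of illustrative calls, assuming that behaviour:)

    _add_missing_period("@highlight")                 # returned unchanged
    _add_missing_period("Heavy rain hit London .")    # already ends with '.', returned unchanged
    _add_missing_period("Heavy rain hit London")      # assumed to come back with a trailing period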
@@ -163,8 +164,8 @@ def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
     [1] Dong, Li, et al. "Unified Language Model Pre-training for Natural
     Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019).
     """
     SRC_MAX_LENGTH = int(0.75 * block_size) - 2  # CLS and EOS token
     TGT_MAX_LENGTH = block_size - SRC_MAX_LENGTH - 1  # EOS token

     # we dump the examples that are too small to fit in the block size for the
     # sake of simplicity. You can modify this by adding model-specific padding.
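(A worked example of this budget, with a block size chosen purely for illustration:)

    block_size = 512
    SRC_MAX_LENGTH = int(0.75 * block_size) - 2        # 384 - 2 = 382
    TGT_MAX_LENGTH = block_size - SRC_MAX_LENGTH - 1   # 512 - 382 - 1 = 129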
@@ -172,22 +173,21 @@ def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
         return None

     # the source sequence has `[SEP_i]` special tokens with i \in [0,9]. We keep them for now.
-    if len(src_sequence) > SRC_MAX_LENGTH
+    if len(src_sequence) > SRC_MAX_LENGTH:
         if len(tgt_sequence) > TGT_MAX_LENGTH:
             src_sequence = src_sequence[:SRC_MAX_LENGTH]
             tgt_sequence = tgt_sequence[:TGT_MAX_LENGTH]
         else:
             src_sequence = src_sequence[block_size - len(tgt_sequence) - 3]
     else:
-        if len(tgt_tokens) > TGT_MAX_LENGTH:
+        if len(tgt_sequence) > TGT_MAX_LENGTH:
             tgt_sequence = tgt_sequence[block_size - len(src_sequence) - 3]

     return src_sequence, tgt_sequence


 def load_and_cache_examples(args, tokenizer):
-    dataset = TextDataset(tokenizer, file_path=args.train_data_file)
+    dataset = TextDataset(tokenizer, file_path=args.data_dir)
     return dataset

@@ -200,7 +200,7 @@ def main():
     parser = argparse.ArgumentParser()

     # Required parameters
-    parser.add_argument("--train_data_file",
+    parser.add_argument("--data_dir",
                         default=None,
                         type=str,
                         required=True,
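(With the argument renamed, the script would presumably be launched along these lines; the path is a placeholder, and any other options the work-in-progress script requires are omitted:)

    python examples/run_seq2seq_finetuning.py --data_dir /path/to/cnn_dailymail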
examples/run_seq2seq_finetuning_test.py (new file, 0 → 100644)
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

from .run_seq2seq_finetuning import process_story, _fit_to_block_size


class DataLoaderTest(unittest.TestCase):
    def __init__(self, block_size=10):
        self.block_size = block_size

    def source_and_target_too_small(self):
        """ When the sum of the lengths of the source and target sequences is
        smaller than the block size (minus the number of special tokens), skip the example. """
        src_seq = [1, 2, 3, 4]
        tgt_seq = [5, 6]
        self.assertEqual(_fit_to_block_size(src_seq, tgt_seq, self.block_size), None)

    def source_and_target_fit_exactly(self):
        """ When the sum of the lengths of the source and target sequences is
        equal to the block size (minus the number of special tokens), return the
        sequences unchanged. """
        src_seq = [1, 2, 3, 4]
        tgt_seq = [5, 6, 7]
        fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
        self.assertListEqual(src_seq == fitted_src)
        self.assertListEqual(tgt_seq == fitted_tgt)

    def source_too_big_target_ok(self):
        src_seq = [1, 2, 3, 4, 5, 6]
        tgt_seq = [1, 2]
        fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
        self.assertListEqual(src_seq == [1, 2, 3, 4, 5])
        self.assertListEqual(tgt_seq == fitted_tgt)

    def target_too_big_source_ok(self):
        src_seq = [1, 2, 3, 4]
        tgt_seq = [1, 2, 3, 4]
        fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
        self.assertListEqual(src_seq == src_seq)
        self.assertListEqual(tgt_seq == [1, 2, 3])

    def source_and_target_too_big(self):
        src_seq = [1, 2, 3, 4, 5, 6, 7]
        tgt_seq = [1, 2, 3, 4, 5, 6, 7]
        fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
        self.assertListEqual(src_seq == [1, 2, 3, 4, 5])
        self.assertListEqual(tgt_seq == [1, 2])


if __name__ == "__main__":
    unittest.main()
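(Note that `unittest` only auto-discovers methods whose names start with `test`, and fixtures normally go in `setUp` rather than an overridden `__init__`; since this is a WIP commit, a minimal runnable variant of the first check might look like the following sketch, which keeps the same imports and a block size of 10.)

    import unittest

    from .run_seq2seq_finetuning import _fit_to_block_size


    class DataLoaderTest(unittest.TestCase):
        def setUp(self):
            self.block_size = 10

        def test_source_and_target_too_small(self):
            # examples shorter than the block size are skipped (None is returned)
            self.assertIsNone(_fit_to_block_size([1, 2, 3, 4], [5, 6], self.block_size))


    if __name__ == "__main__":
        unittest.main()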