OpenDAS / Megatron-LM / Commits / b7f1b050

Commit b7f1b050: Lint whole repo
Authored Apr 14, 2020 by Neel Kant
Parent: c99fa80c
Changes: 63 files in the full commit; this page shows 3 changed files with 57 additions and 62 deletions (+57, -62).
Files on this page of the diff:
  tasks/race/data.py                   +0  / -3
  tasks/zeroshot_gpt2/datasets.py      +12 / -12
  tasks/zeroshot_gpt2/detokenizer.py   +45 / -47
tasks/race/data.py
...
...
@@ -39,16 +39,13 @@ class RaceDataset(Dataset):
        print_rank_0('  >> total number of samples: {}'.format(
            len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def process_single_datapath(datapath, tokenizer, max_qa_length,
                            max_seq_length):
    """Read in RACE files, combine, clean-up, tokenize, and convert to
    samples."""
...
...
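As an aside (not part of this commit): the two dunder methods in the hunk above are the whole contract a torch DataLoader relies on. A minimal self-contained sketch of that protocol, using invented toy integers instead of tokenized RACE samples:

# Illustration only: mirrors RaceDataset's __len__/__getitem__ contract
# with toy integer samples instead of tokenized RACE examples.
from torch.utils.data import DataLoader, Dataset


class ToySamples(Dataset):
    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


loader = DataLoader(ToySamples(list(range(10))), batch_size=4)
print([batch.tolist() for batch in loader])
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]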
tasks/zeroshot_gpt2/datasets.py
...
...
@@ -64,12 +64,12 @@ class _LMDataset(torch.utils.data.Dataset):
    def __getitem__(self, idx):
        start_idx = idx * self.overalapping_eval
        end_idx = start_idx + self.seq_len
        tokens = self.tokens[start_idx:end_idx + 1]
        num_tokens = len(tokens)
        pad_mask = [1] * num_tokens
        if num_tokens < self.seq_len + 1:
            num_pad = (self.seq_len + 1 - num_tokens)
            pad_mask += [0] * (num_pad)
            tokens += [self.pad_idx] * num_pad
        pad_mask = np.array(pad_mask[1:])
        if self.overalapping_eval != self.seq_len and idx != 0:
...
...
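To make the windowing and padding in __getitem__ above concrete, here is a small trace with invented values (seq_len=4, overalapping_eval=2, pad index 0; none of these numbers come from the diff):

# Hypothetical trace of _LMDataset.__getitem__'s indexing and padding.
import numpy as np

seq_len, overalapping_eval, pad_idx = 4, 2, 0
all_tokens = list(range(1, 8))              # 7 tokens total

idx = 2                                     # third overlapping window
start_idx = idx * overalapping_eval         # 4
end_idx = start_idx + seq_len               # 8
tokens = all_tokens[start_idx:end_idx + 1]  # [5, 6, 7]: window runs off the end
num_tokens = len(tokens)
pad_mask = [1] * num_tokens
if num_tokens < seq_len + 1:
    num_pad = seq_len + 1 - num_tokens      # 2
    pad_mask += [0] * num_pad
    tokens += [pad_idx] * num_pad
print(tokens)                  # [5, 6, 7, 0, 0]
print(np.array(pad_mask[1:]))  # [1 1 0 0]; dropping the first entry matches the
                               # one-token shift between inputs and LM targets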
@@ -103,7 +103,7 @@ class _LambadaDataset(torch.utils.data.Dataset):
        last_token = text.split()[-1]
        start_idx = text.rfind(last_token)
        beginning_tokens = self.tokenizer.tokenize(text[:start_idx].strip())
        last_token = self.tokenizer.tokenize(' ' + last_token)
        return beginning_tokens, last_token

    def __len__(self):
...
...
@@ -112,14 +112,14 @@ class _LambadaDataset(torch.utils.data.Dataset):
    def __getitem__(self, idx):
        tokens = self.tokens[idx]
        num_tokens = len(tokens)
        pad_mask = [0] * num_tokens
        labels = self.labels[idx]
        pad_mask += [1] * len(labels)
        tokens = tokens + labels
        num_tokens = len(tokens)
        if num_tokens < self.seq_len + 1:
            num_pad = (self.seq_len + 1 - num_tokens)
            pad_mask += [0] * (num_pad)
            tokens += [self.pad_idx] * num_pad
        pad_mask = np.array(pad_mask[1:])
...
...
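The mask built above scores only the final word of the LAMBADA passage: context positions get 0, label positions get 1, and the whole mask is shifted by one to line up with next-token predictions. A short trace with invented token ids (not from the diff):

# Hypothetical trace of _LambadaDataset.__getitem__'s mask construction.
import numpy as np

seq_len, pad_idx = 8, 0
tokens = [11, 12, 13, 14]      # assumed context token ids
labels = [21, 22]              # assumed ids of the passage's final word

pad_mask = [0] * len(tokens)   # context: not scored
pad_mask += [1] * len(labels)  # target word: scored
tokens = tokens + labels
num_tokens = len(tokens)
if num_tokens < seq_len + 1:
    num_pad = seq_len + 1 - num_tokens
    pad_mask += [0] * num_pad  # padding: not scored
    tokens += [pad_idx] * num_pad
print(tokens)                  # [11, 12, 13, 14, 21, 22, 0, 0, 0]
print(np.array(pad_mask[1:]))  # [0 0 0 1 1 0 0 0]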
tasks/zeroshot_gpt2/detokenizer.py
...
...
@@ -19,64 +19,62 @@ import re
def ptb_detokenizer(string):
    string = string.replace(" '", "'")
    string = string.replace(" \n", "\n")
    string = string.replace("\n ", "\n")
    string = string.replace(" n't", "n't")
    string = string.replace(" N ", "1 ")
    string = string.replace("$ 1", "$1")
    string = string.replace("# 1", "#1")
    return string


def wikitext_detokenizer(string):
    # contractions
    string = string.replace("s '", "s'")
    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
    # number separators
    string = string.replace(" @-@ ", "-")
    string = string.replace(" @,@ ", ",")
    string = string.replace(" @.@ ", ".")
    # punctuation
    string = string.replace(" : ", ": ")
    string = string.replace(" ; ", "; ")
    string = string.replace(" . ", ". ")
    string = string.replace(" ! ", "! ")
    string = string.replace(" ? ", "? ")
    string = string.replace(" , ", ", ")
    # double brackets
    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
    # miscellaneous
    string = string.replace("= = = =", "====")
    string = string.replace("= = =", "===")
    string = string.replace("= =", "==")
    string = string.replace(" " + chr(176) + " ", chr(176))
    string = string.replace(" \n", "\n")
    string = string.replace("\n ", "\n")
    string = string.replace(" N ", " 1 ")
    string = string.replace(" 's", "'s")
    return string


def lambada_detokenizer(string):
    return string


_DETOKENIZERS = {
    'ptb': ptb_detokenizer,
    'wikitext': wikitext_detokenizer,
    'lambada': lambada_detokenizer,
}


def get_detokenizer(path):
    for key in _DETOKENIZERS.keys():
        if key in path:
            return _DETOKENIZERS[key]
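A quick sanity check of the substring dispatch in get_detokenizer and of the WikiText clean-up rules, using an invented path and sample string (assuming the functions above are in scope):

# Invented example, not part of the commit.
detok = get_detokenizer('/data/wikitext-103/test.txt')
assert detok is wikitext_detokenizer   # 'wikitext' is a substring of the path

raw = "It sold 4 @,@ 500 units ( roughly ) . = = History = ="
print(detok(raw))
# It sold 4,500 units (roughly). == History ==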