Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
67ee6d1f
Commit
67ee6d1f
authored
Jul 18, 2018
by
alexeib
Committed by
Myle Ott
Jul 25, 2018
Browse files
remove right-to-left lm support
parent
d2e2a1d4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
18 deletions
+3
-18
fairseq/data/token_block_dataset.py
fairseq/data/token_block_dataset.py
+2
-13
fairseq/tasks/language_modeling.py
fairseq/tasks/language_modeling.py
+1
-5
No files found.
fairseq/data/token_block_dataset.py
View file @
67ee6d1f
...
@@ -29,13 +29,12 @@ class TokenBlockDataset(torch.utils.data.Dataset):
...
@@ -29,13 +29,12 @@ class TokenBlockDataset(torch.utils.data.Dataset):
include_targets: return next tokens as targets
include_targets: return next tokens as targets
"""
"""
def
__init__
(
self
,
tokens
,
sizes
,
block_size
,
break_mode
=
None
,
include_targets
=
False
,
reverse
=
False
):
def
__init__
(
self
,
tokens
,
sizes
,
block_size
,
break_mode
=
None
,
include_targets
=
False
):
super
().
__init__
()
super
().
__init__
()
self
.
tokens
=
tokens
self
.
tokens
=
tokens
self
.
total_size
=
len
(
tokens
)
self
.
total_size
=
len
(
tokens
)
self
.
include_targets
=
include_targets
self
.
include_targets
=
include_targets
self
.
reverse
=
reverse
self
.
slice_indices
=
[]
self
.
slice_indices
=
[]
if
break_mode
is
None
or
break_mode
==
'none'
:
if
break_mode
is
None
or
break_mode
==
'none'
:
...
@@ -78,19 +77,9 @@ class TokenBlockDataset(torch.utils.data.Dataset):
...
@@ -78,19 +77,9 @@ class TokenBlockDataset(torch.utils.data.Dataset):
def
__getitem__
(
self
,
index
):
def
__getitem__
(
self
,
index
):
s
,
e
=
self
.
slice_indices
[
index
]
s
,
e
=
self
.
slice_indices
[
index
]
if
self
.
reverse
:
item
=
torch
.
LongTensor
(
self
.
tokens
[
s
:
e
])
item
=
torch
.
LongTensor
(
np
.
flip
(
self
.
tokens
[
s
:
e
],
0
).
copy
())
else
:
item
=
torch
.
LongTensor
(
self
.
tokens
[
s
:
e
])
if
self
.
include_targets
:
if
self
.
include_targets
:
if
self
.
reverse
:
if
s
==
0
:
target
=
np
.
concatenate
([
self
.
tokens
[
-
1
:],
item
.
numpy
()[
1
:]])
else
:
target
=
np
.
concatenate
([
self
.
tokens
[
s
-
1
:
s
],
item
.
numpy
()[:
-
1
]])
return
item
,
torch
.
LongTensor
(
target
)
# target is the sentence, for source, rotate item one token to the left (would start with eos)
# target is the sentence, for source, rotate item one token to the left (would start with eos)
if
s
==
0
:
if
s
==
0
:
source
=
np
.
concatenate
([
self
.
tokens
[
-
1
:],
self
.
tokens
[
0
:
e
-
1
]])
source
=
np
.
concatenate
([
self
.
tokens
[
-
1
:],
self
.
tokens
[
0
:
e
-
1
]])
...
...
fairseq/tasks/language_modeling.py
View file @
67ee6d1f
...
@@ -36,13 +36,9 @@ class LanguageModelingTask(FairseqTask):
...
@@ -36,13 +36,9 @@ class LanguageModelingTask(FairseqTask):
help
=
'max number of tokens per sample for LM dataset'
)
help
=
'max number of tokens per sample for LM dataset'
)
parser
.
add_argument
(
'--raw-text'
,
default
=
False
,
action
=
'store_true'
,
parser
.
add_argument
(
'--raw-text'
,
default
=
False
,
action
=
'store_true'
,
help
=
'load raw text dataset'
)
help
=
'load raw text dataset'
)
parser
.
add_argument
(
'--right-to-left'
,
default
=
False
,
action
=
'store_true'
,
help
=
'if set, trains a language model right-to-left (instead of left-to-right)'
)
def
__init__
(
self
,
args
,
dictionary
):
def
__init__
(
self
,
args
,
dictionary
):
super
().
__init__
(
args
)
super
().
__init__
(
args
)
args
.
right_to_left
=
getattr
(
args
,
'right_to_left'
,
False
)
self
.
dictionary
=
dictionary
self
.
dictionary
=
dictionary
@
classmethod
@
classmethod
...
@@ -75,7 +71,7 @@ class LanguageModelingTask(FairseqTask):
...
@@ -75,7 +71,7 @@ class LanguageModelingTask(FairseqTask):
loaded_datasets
.
append
(
loaded_datasets
.
append
(
TokenBlockDataset
(
TokenBlockDataset
(
tokens
,
ds
.
sizes
,
self
.
args
.
tokens_per_sample
,
self
.
args
.
sample_break_mode
,
tokens
,
ds
.
sizes
,
self
.
args
.
tokens_per_sample
,
self
.
args
.
sample_break_mode
,
include_targets
=
True
,
reverse
=
self
.
args
.
right_to_left
,
include_targets
=
True
))
))
print
(
'| {} {} {} examples'
.
format
(
self
.
args
.
data
,
split_k
,
len
(
loaded_datasets
[
-
1
])))
print
(
'| {} {} {} examples'
.
format
(
self
.
args
.
data
,
split_k
,
len
(
loaded_datasets
[
-
1
])))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment