Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
ec0031df
"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "0368483b61ad690aef5b0d92611270868c0ce8ea"
Unverified
Commit
ec0031df
authored
May 24, 2018
by
Myle Ott
Committed by
GitHub
May 24, 2018
Browse files
Merge internal changes (#163)
parent
29153e27
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
155 additions
and
57 deletions
+155
-57
fairseq/bleu.py
fairseq/bleu.py
+3
-2
fairseq/dictionary.py
fairseq/dictionary.py
+21
-4
fairseq/models/lstm.py
fairseq/models/lstm.py
+122
-45
fairseq/utils.py
fairseq/utils.py
+8
-3
tests/test_binaries.py
tests/test_binaries.py
+1
-0
tests/test_utils.py
tests/test_utils.py
+0
-3
No files found.
fairseq/bleu.py
View file @
ec0031df
...
...
@@ -57,9 +57,10 @@ class Scorer(object):
raise
TypeError
(
'pred must be a torch.IntTensor(got {})'
.
format
(
type
(
pred
)))
assert
self
.
unk
>
0
,
'unknown token index must be >0'
# don't match unknown words
rref
=
ref
.
clone
()
rref
.
apply_
(
lambda
x
:
x
if
x
!=
self
.
unk
else
-
x
)
assert
not
rref
.
lt
(
0
).
any
()
rref
[
rref
.
eq
(
self
.
unk
)]
=
-
999
rref
=
rref
.
contiguous
().
view
(
-
1
)
pred
=
pred
.
contiguous
().
view
(
-
1
)
...
...
fairseq/dictionary.py
View file @
ec0031df
...
...
@@ -81,6 +81,19 @@ class Dictionary(object):
self
.
count
.
append
(
n
)
return
idx
def
update
(
self
,
new_dict
):
"""Updates counts from new dictionary."""
for
word
in
new_dict
.
symbols
:
idx2
=
new_dict
.
indices
[
word
]
if
word
in
self
.
indices
:
idx
=
self
.
indices
[
word
]
self
.
count
[
idx
]
=
self
.
count
[
idx
]
+
new_dict
.
count
[
idx2
]
else
:
idx
=
len
(
self
.
symbols
)
self
.
indices
[
word
]
=
idx
self
.
symbols
.
append
(
word
)
self
.
count
.
append
(
new_dict
.
count
[
idx2
])
def
finalize
(
self
):
"""Sort symbols by frequency in descending order, ignoring special ones."""
self
.
count
,
self
.
symbols
=
zip
(
...
...
@@ -102,7 +115,7 @@ class Dictionary(object):
return
self
.
unk_index
@
classmethod
def
load
(
cls
,
f
):
def
load
(
cls
,
f
,
ignore_utf_errors
=
False
):
"""Loads the dictionary from a text file with the format:
```
...
...
@@ -114,8 +127,12 @@ class Dictionary(object):
if
isinstance
(
f
,
str
):
try
:
if
not
ignore_utf_errors
:
with
open
(
f
,
'r'
,
encoding
=
'utf-8'
)
as
fd
:
return
cls
.
load
(
fd
)
else
:
with
open
(
f
,
'r'
,
encoding
=
'utf-8'
,
errors
=
'ignore'
)
as
fd
:
return
cls
.
load
(
fd
)
except
FileNotFoundError
as
fnfe
:
raise
fnfe
except
Exception
:
...
...
@@ -141,6 +158,6 @@ class Dictionary(object):
cnt
=
0
for
i
,
t
in
enumerate
(
zip
(
self
.
symbols
,
self
.
count
)):
if
i
>=
self
.
nspecial
and
t
[
1
]
>=
threshold
\
and
(
nwords
<
0
or
cnt
<
nwords
):
and
(
nwords
<
=
0
or
cnt
<
nwords
):
print
(
'{} {}'
.
format
(
t
[
0
],
t
[
1
]),
file
=
f
)
cnt
+=
1
fairseq/models/lstm.py
View file @
ec0031df
...
...
@@ -30,12 +30,18 @@ class LSTMModel(FairseqModel):
help
=
'encoder embedding dimension'
)
parser
.
add_argument
(
'--encoder-embed-path'
,
default
=
None
,
type
=
str
,
metavar
=
'STR'
,
help
=
'path to pre-trained encoder embedding'
)
parser
.
add_argument
(
'--encoder-hidden-size'
,
type
=
int
,
metavar
=
'N'
,
help
=
'encoder hidden size'
)
parser
.
add_argument
(
'--encoder-layers'
,
type
=
int
,
metavar
=
'N'
,
help
=
'number of encoder layers'
)
parser
.
add_argument
(
'--encoder-bidirectional'
,
action
=
'store_true'
,
help
=
'make all layers of encoder bidirectional'
)
parser
.
add_argument
(
'--decoder-embed-dim'
,
type
=
int
,
metavar
=
'N'
,
help
=
'decoder embedding dimension'
)
parser
.
add_argument
(
'--decoder-embed-path'
,
default
=
None
,
type
=
str
,
metavar
=
'STR'
,
help
=
'path to pre-trained decoder embedding'
)
parser
.
add_argument
(
'--decoder-hidden-size'
,
type
=
int
,
metavar
=
'N'
,
help
=
'decoder hidden size'
)
parser
.
add_argument
(
'--decoder-layers'
,
type
=
int
,
metavar
=
'N'
,
help
=
'number of decoder layers'
)
parser
.
add_argument
(
'--decoder-out-embed-dim'
,
type
=
int
,
metavar
=
'N'
,
...
...
@@ -60,68 +66,102 @@ class LSTMModel(FairseqModel):
args
.
encoder_embed_path
=
None
if
not
hasattr
(
args
,
'decoder_embed_path'
):
args
.
decoder_embed_path
=
None
if
not
hasattr
(
args
,
'encoder_hidden_size'
):
args
.
encoder_hidden_size
=
args
.
encoder_embed_dim
if
not
hasattr
(
args
,
'decoder_hidden_size'
):
args
.
decoder_hidden_size
=
args
.
decoder_embed_dim
if
not
hasattr
(
args
,
'encoder_bidirectional'
):
args
.
encoder_bidirectional
=
False
def
load_pretrained_embedding_from_file
(
embed_path
,
dictionary
,
embed_dim
):
num_embeddings
=
len
(
dictionary
)
padding_idx
=
dictionary
.
pad
()
embed_tokens
=
Embedding
(
num_embeddings
,
embed_dim
,
padding_idx
)
embed_dict
=
utils
.
parse_embedding
(
embed_path
)
utils
.
print_embed_overlap
(
embed_dict
,
dictionary
)
return
utils
.
load_embedding
(
embed_dict
,
dictionary
,
embed_tokens
)
encoder_embed
_dict
=
None
pretrained_
encoder_embed
=
None
if
args
.
encoder_embed_path
:
encoder_embed_dict
=
utils
.
parse_embedding
(
args
.
encoder_embed_path
)
utils
.
print_embed_overlap
(
encoder_embed_dict
,
src_dict
)
decoder_embed_dict
=
None
pretrained_encoder_embed
=
load_pretrained_embedding_from_file
(
args
.
encoder_embed_path
,
src_dict
,
args
.
encoder_embed_dim
)
pretrained_decoder_embed
=
None
if
args
.
decoder_embed_path
:
decoder_embed
_dict
=
utils
.
parse_embedding
(
args
.
decoder_embed_path
)
utils
.
print_embed_overlap
(
decoder_embed_di
ct
,
dst_dict
)
pretrained_
decoder_embed
=
load_pretrained_embedding_from_file
(
args
.
decoder_embed_path
,
dst_dict
,
args
.
decoder_embed_di
m
)
encoder
=
LSTMEncoder
(
src_dict
,
dictionary
=
src_dict
,
embed_dim
=
args
.
encoder_embed_dim
,
embed_dict
=
encoder_embed_dict
,
hidden_size
=
args
.
encoder_hidden_size
,
num_layers
=
args
.
encoder_layers
,
dropout_in
=
args
.
encoder_dropout_in
,
dropout_out
=
args
.
encoder_dropout_out
,
bidirectional
=
args
.
encoder_bidirectional
,
pretrained_embed
=
pretrained_encoder_embed
,
)
try
:
attention
=
bool
(
eval
(
args
.
decoder_attention
))
except
TypeError
:
attention
=
bool
(
args
.
decoder_attention
)
decoder
=
LSTMDecoder
(
dst_dict
,
encoder_embed_dim
=
args
.
encoder_embed_dim
,
dictionary
=
dst_dict
,
embed_dim
=
args
.
decoder_embed_dim
,
embed_dict
=
decoder_embed_dict
,
hidden_size
=
args
.
decoder_hidden_size
,
out_embed_dim
=
args
.
decoder_out_embed_dim
,
num_layers
=
args
.
decoder_layers
,
attention
=
bool
(
eval
(
args
.
decoder_attention
)),
dropout_in
=
args
.
decoder_dropout_in
,
dropout_out
=
args
.
decoder_dropout_out
,
attention
=
attention
,
encoder_embed_dim
=
args
.
encoder_embed_dim
,
encoder_output_units
=
encoder
.
output_units
,
pretrained_embed
=
pretrained_decoder_embed
,
)
return
cls
(
encoder
,
decoder
)
class
LSTMEncoder
(
FairseqEncoder
):
"""LSTM encoder."""
def
__init__
(
self
,
dictionary
,
embed_dim
=
512
,
embed_dict
=
None
,
num_layers
=
1
,
dropout_in
=
0.1
,
dropout_out
=
0.1
):
def
__init__
(
self
,
dictionary
,
embed_dim
=
512
,
hidden_size
=
512
,
num_layers
=
1
,
dropout_in
=
0.1
,
dropout_out
=
0.1
,
bidirectional
=
False
,
left_pad_source
=
LanguagePairDataset
.
LEFT_PAD_SOURCE
,
pretrained_embed
=
None
,
padding_value
=
0.
,
):
super
().
__init__
(
dictionary
)
self
.
num_layers
=
num_layers
self
.
dropout_in
=
dropout_in
self
.
dropout_out
=
dropout_out
self
.
bidirectional
=
bidirectional
self
.
hidden_size
=
hidden_size
num_embeddings
=
len
(
dictionary
)
self
.
padding_idx
=
dictionary
.
pad
()
if
pretrained_embed
is
None
:
self
.
embed_tokens
=
Embedding
(
num_embeddings
,
embed_dim
,
self
.
padding_idx
)
if
embed_dict
:
self
.
embed_tokens
=
utils
.
load_embedding
(
embed_dict
,
self
.
dictionary
,
self
.
embed_tokens
)
else
:
self
.
embed_tokens
=
pretrained_embed
self
.
lstm
=
LSTM
(
input_size
=
embed_dim
,
hidden_size
=
embed_dim
,
hidden_size
=
hidden_size
,
num_layers
=
num_layers
,
dropout
=
self
.
dropout_out
,
bidirectional
=
False
,
bidirectional
=
bidirectional
,
)
self
.
left_pad_source
=
left_pad_source
self
.
padding_value
=
padding_value
self
.
output_units
=
hidden_size
if
bidirectional
:
self
.
output_units
*=
2
def
forward
(
self
,
src_tokens
,
src_lengths
):
if
LanguagePairDataset
.
LEFT_PAD_SOURCE
:
if
self
.
left_pad_source
:
# convert left-padding to right-padding
src_tokens
=
utils
.
convert_padding_direction
(
src_tokens
,
src_lengths
,
self
.
padding_idx
,
left_to_right
=
True
,
)
...
...
@@ -131,7 +171,6 @@ class LSTMEncoder(FairseqEncoder):
# embed tokens
x
=
self
.
embed_tokens
(
src_tokens
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout_in
,
training
=
self
.
training
)
embed_dim
=
x
.
size
(
2
)
# B x T x C -> T x B x C
x
=
x
.
transpose
(
0
,
1
)
...
...
@@ -140,17 +179,35 @@ class LSTMEncoder(FairseqEncoder):
packed_x
=
nn
.
utils
.
rnn
.
pack_padded_sequence
(
x
,
src_lengths
.
data
.
tolist
())
# apply LSTM
h0
=
Variable
(
x
.
data
.
new
(
self
.
num_layers
,
bsz
,
embed_dim
).
zero_
())
c0
=
Variable
(
x
.
data
.
new
(
self
.
num_layers
,
bsz
,
embed_dim
).
zero_
())
if
self
.
bidirectional
:
state_size
=
2
*
self
.
num_layers
,
bsz
,
self
.
hidden_size
else
:
state_size
=
self
.
num_layers
,
bsz
,
self
.
hidden_size
h0
=
Variable
(
x
.
data
.
new
(
*
state_size
).
zero_
())
c0
=
Variable
(
x
.
data
.
new
(
*
state_size
).
zero_
())
packed_outs
,
(
final_hiddens
,
final_cells
)
=
self
.
lstm
(
packed_x
,
(
h0
,
c0
),
)
# unpack outputs and apply dropout
x
,
_
=
nn
.
utils
.
rnn
.
pad_packed_sequence
(
packed_outs
,
padding_value
=
0.
)
x
,
_
=
nn
.
utils
.
rnn
.
pad_packed_sequence
(
packed_outs
,
padding_value
=
self
.
padding_value
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout_out
,
training
=
self
.
training
)
assert
list
(
x
.
size
())
==
[
seqlen
,
bsz
,
embed_dim
]
assert
list
(
x
.
size
())
==
[
seqlen
,
bsz
,
self
.
output_units
]
if
self
.
bidirectional
:
bi_final_hiddens
,
bi_final_cells
=
[],
[]
for
i
in
range
(
self
.
num_layers
):
bi_final_hiddens
.
append
(
torch
.
cat
(
(
final_hiddens
[
2
*
i
],
final_hiddens
[
2
*
i
+
1
]),
dim
=
0
).
view
(
bsz
,
self
.
output_units
))
bi_final_cells
.
append
(
torch
.
cat
(
(
final_cells
[
2
*
i
],
final_cells
[
2
*
i
+
1
]),
dim
=
0
).
view
(
bsz
,
self
.
output_units
))
return
x
,
bi_final_hiddens
,
bi_final_cells
return
x
,
final_hiddens
,
final_cells
...
...
@@ -166,7 +223,7 @@ class AttentionLayer(nn.Module):
self
.
input_proj
=
Linear
(
input_embed_dim
,
output_embed_dim
,
bias
=
False
)
self
.
output_proj
=
Linear
(
2
*
output_embed_dim
,
output_embed_dim
,
bias
=
False
)
def
forward
(
self
,
input
,
source_hids
):
def
forward
(
self
,
input
,
source_hids
,
src_lengths
=
None
):
# input: bsz x input_embed_dim
# source_hids: srclen x bsz x output_embed_dim
...
...
@@ -186,27 +243,39 @@ class AttentionLayer(nn.Module):
class
LSTMDecoder
(
FairseqIncrementalDecoder
):
"""LSTM decoder."""
def
__init__
(
self
,
dictionary
,
encoder_embed_dim
=
512
,
embed_dim
=
512
,
embed_dict
=
None
,
out_embed_dim
=
512
,
num_layers
=
1
,
dropout_in
=
0.1
,
dropout_out
=
0.1
,
attention
=
True
):
def
__init__
(
self
,
dictionary
,
embed_dim
=
512
,
hidden_size
=
512
,
out_embed_dim
=
512
,
num_layers
=
1
,
dropout_in
=
0.1
,
dropout_out
=
0.1
,
attention
=
True
,
encoder_embed_dim
=
512
,
encoder_output_units
=
512
,
pretrained_embed
=
None
,
):
super
().
__init__
(
dictionary
)
self
.
dropout_in
=
dropout_in
self
.
dropout_out
=
dropout_out
self
.
hidden_size
=
hidden_size
num_embeddings
=
len
(
dictionary
)
padding_idx
=
dictionary
.
pad
()
if
pretrained_embed
is
None
:
self
.
embed_tokens
=
Embedding
(
num_embeddings
,
embed_dim
,
padding_idx
)
if
embed_dict
:
self
.
embed_tokens
=
utils
.
load_embedding
(
embed_dict
,
self
.
dictionary
,
self
.
embed_tokens
)
else
:
self
.
embed_tokens
=
pretrained_embed
self
.
encoder_output_units
=
encoder_output_units
assert
encoder_output_units
==
hidden_size
,
\
'{} {}'
.
format
(
encoder_output_units
,
hidden_size
)
# TODO another Linear layer if not equal
self
.
layers
=
nn
.
ModuleList
([
LSTMCell
(
encoder_embed_dim
+
embed_dim
if
layer
==
0
else
embed_dim
,
embed_dim
)
LSTMCell
(
input_size
=
encoder_output_units
+
embed_dim
if
layer
==
0
else
hidden_size
,
hidden_size
=
hidden_size
,
)
for
layer
in
range
(
num_layers
)
])
self
.
attention
=
AttentionLayer
(
encoder_
embed_dim
,
embed_dim
)
if
attention
else
None
if
embed_dim
!=
out_embed_dim
:
self
.
additional_fc
=
Linear
(
embed_dim
,
out_embed_dim
)
self
.
attention
=
AttentionLayer
(
encoder_
output_units
,
hidden_size
)
if
attention
else
None
if
hidden_size
!=
out_embed_dim
:
self
.
additional_fc
=
Linear
(
hidden_size
,
out_embed_dim
)
self
.
fc_out
=
Linear
(
out_embed_dim
,
num_embeddings
,
dropout
=
dropout_out
)
def
forward
(
self
,
prev_output_tokens
,
encoder_out
,
incremental_state
=
None
):
...
...
@@ -215,13 +284,12 @@ class LSTMDecoder(FairseqIncrementalDecoder):
bsz
,
seqlen
=
prev_output_tokens
.
size
()
# get outputs from encoder
encoder_outs
,
_
,
_
=
encoder_out
encoder_outs
,
_
,
_
=
encoder_out
[:
3
]
srclen
=
encoder_outs
.
size
(
0
)
# embed tokens
x
=
self
.
embed_tokens
(
prev_output_tokens
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout_in
,
training
=
self
.
training
)
embed_dim
=
x
.
size
(
2
)
# B x T x C -> T x B x C
x
=
x
.
transpose
(
0
,
1
)
...
...
@@ -231,11 +299,11 @@ class LSTMDecoder(FairseqIncrementalDecoder):
if
cached_state
is
not
None
:
prev_hiddens
,
prev_cells
,
input_feed
=
cached_state
else
:
_
,
encoder_hiddens
,
encoder_cells
=
encoder_out
_
,
encoder_hiddens
,
encoder_cells
=
encoder_out
[:
3
]
num_layers
=
len
(
self
.
layers
)
prev_hiddens
=
[
encoder_hiddens
[
i
]
for
i
in
range
(
num_layers
)]
prev_cells
=
[
encoder_cells
[
i
]
for
i
in
range
(
num_layers
)]
input_feed
=
Variable
(
x
.
data
.
new
(
bsz
,
embed_dim
).
zero_
())
input_feed
=
Variable
(
x
.
data
.
new
(
bsz
,
self
.
encoder_output_units
).
zero_
())
attn_scores
=
Variable
(
x
.
data
.
new
(
srclen
,
seqlen
,
bsz
).
zero_
())
outs
=
[]
...
...
@@ -272,7 +340,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
self
,
incremental_state
,
'cached_state'
,
(
prev_hiddens
,
prev_cells
,
input_feed
))
# collect outputs across time steps
x
=
torch
.
cat
(
outs
,
dim
=
0
).
view
(
seqlen
,
bsz
,
embed_dim
)
x
=
torch
.
cat
(
outs
,
dim
=
0
).
view
(
seqlen
,
bsz
,
self
.
hidden_size
)
# T x B x C -> B x T x C
x
=
x
.
transpose
(
1
,
0
)
...
...
@@ -342,10 +410,13 @@ def Linear(in_features, out_features, bias=True, dropout=0):
@
register_model_architecture
(
'lstm'
,
'lstm'
)
def
base_architecture
(
args
):
args
.
encoder_embed_dim
=
getattr
(
args
,
'encoder_embed_dim'
,
512
)
args
.
encoder_hidden_size
=
getattr
(
args
,
'encoder_hidden_size'
,
512
)
args
.
encoder_layers
=
getattr
(
args
,
'encoder_layers'
,
1
)
args
.
encoder_bidirectional
=
getattr
(
args
,
'encoder_bidirectional'
,
False
)
args
.
encoder_dropout_in
=
getattr
(
args
,
'encoder_dropout_in'
,
args
.
dropout
)
args
.
encoder_dropout_out
=
getattr
(
args
,
'encoder_dropout_out'
,
args
.
dropout
)
args
.
decoder_embed_dim
=
getattr
(
args
,
'decoder_embed_dim'
,
512
)
args
.
decoder_hidden_size
=
getattr
(
args
,
'decoder_hidden_size'
,
512
)
args
.
decoder_layers
=
getattr
(
args
,
'decoder_layers'
,
1
)
args
.
decoder_out_embed_dim
=
getattr
(
args
,
'decoder_out_embed_dim'
,
512
)
args
.
decoder_attention
=
getattr
(
args
,
'decoder_attention'
,
True
)
...
...
@@ -357,10 +428,13 @@ def base_architecture(args):
def
lstm_wiseman_iwslt_de_en
(
args
):
base_architecture
(
args
)
args
.
encoder_embed_dim
=
256
args
.
encoder_hidden_size
=
256
args
.
encoder_layers
=
1
args
.
encoder_bidirectional
=
False
args
.
encoder_dropout_in
=
0
args
.
encoder_dropout_out
=
0
args
.
decoder_embed_dim
=
256
args
.
decoder_hidden_size
=
256
args
.
decoder_layers
=
1
args
.
decoder_out_embed_dim
=
256
args
.
decoder_attention
=
True
...
...
@@ -371,9 +445,12 @@ def lstm_wiseman_iwslt_de_en(args):
def
lstm_luong_wmt_en_de
(
args
):
base_architecture
(
args
)
args
.
encoder_embed_dim
=
1000
args
.
encoder_hidden_size
=
1000
args
.
encoder_layers
=
4
args
.
encoder_dropout_out
=
0
args
.
encoder_bidirectional
=
False
args
.
decoder_embed_dim
=
1000
args
.
decoder_hidden_size
=
1000
args
.
decoder_layers
=
4
args
.
decoder_out_embed_dim
=
1000
args
.
decoder_attention
=
True
...
...
fairseq/utils.py
View file @
ec0031df
...
...
@@ -266,7 +266,7 @@ def parse_embedding(embed_path):
the -0.0230 -0.0264 0.0287 0.0171 0.1403
at -0.0395 -0.1286 0.0275 0.0254 -0.0932
"""
embed_dict
=
dict
()
embed_dict
=
{}
with
open
(
embed_path
)
as
f_embed
:
_
=
next
(
f_embed
)
# skip header
for
line
in
f_embed
:
...
...
@@ -344,16 +344,21 @@ def buffered_arange(max):
def
convert_padding_direction
(
src_tokens
,
src_lengths
,
padding_idx
,
right_to_left
=
False
,
left_to_right
=
False
,
):
assert
right_to_left
^
left_to_right
pad_mask
=
src_tokens
.
eq
(
padding_idx
)
if
pad_mask
.
max
()
==
0
:
if
not
pad_mask
.
any
()
:
# no padding, return early
return
src_tokens
if
left_to_right
and
not
pad_mask
[:,
0
].
any
():
# already right padded
return
src_tokens
if
right_to_left
and
not
pad_mask
[:,
-
1
].
any
():
# already left padded
return
src_tokens
max_len
=
src_tokens
.
size
(
1
)
range
=
buffered_arange
(
max_len
).
type_as
(
src_tokens
).
expand_as
(
src_tokens
)
num_pads
=
pad_mask
.
long
().
sum
(
dim
=
1
,
keepdim
=
True
)
...
...
tests/test_binaries.py
View file @
ec0031df
...
...
@@ -103,6 +103,7 @@ class TestBinaries(unittest.TestCase):
generate
.
main
(
generate_args
)
# evaluate model interactively
generate_args
.
max_sentences
=
None
orig_stdin
=
sys
.
stdin
sys
.
stdin
=
StringIO
(
'h e l l o
\n
'
)
interactive
.
main
(
generate_args
)
...
...
tests/test_utils.py
View file @
ec0031df
...
...
@@ -27,13 +27,11 @@ class TestUtils(unittest.TestCase):
[
7
,
8
,
9
,
10
,
1
],
[
11
,
12
,
1
,
1
,
1
],
])
lengths
=
torch
.
LongTensor
([
5
,
4
,
2
])
self
.
assertAlmostEqual
(
right_pad
,
utils
.
convert_padding_direction
(
left_pad
,
lengths
,
pad
,
left_to_right
=
True
,
),
...
...
@@ -42,7 +40,6 @@ class TestUtils(unittest.TestCase):
left_pad
,
utils
.
convert_padding_direction
(
right_pad
,
lengths
,
pad
,
right_to_left
=
True
,
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment