transformers · Commit 847e7f33 (unverified)

MarianMTModel.from_pretrained('Helsinki-NLP/opus-marian-en-de') (#3908)

Authored Apr 28, 2020 by Sam Shleifer; committed via GitHub on Apr 28, 2020.
Co-Authored-By: Stefan Schweter <stefan@schweter.it>
Parent commit: d714dfea

Showing 12 changed files with 887 additions and 26 deletions (+887, -26).
Changed files:

  setup.cfg                                        +1    -0
  src/transformers/__init__.py                     +3    -0
  src/transformers/configuration_bart.py           +12   -2
  src/transformers/configuration_marian.py         +26   -0   (new file)
  src/transformers/convert_marian_to_pytorch.py    +397  -0   (new file)
  src/transformers/modeling_bart.py                +89   -11
  src/transformers/modeling_marian.py               +35   -0   (new file)
  src/transformers/modeling_utils.py               +0    -12
  src/transformers/tokenization_marian.py          +160  -0   (new file)
  src/transformers/tokenization_utils.py           +7    -1
  tests/test_modeling_bart.py                      +39   -0
  tests/test_modeling_marian.py                    +118  -0   (new file)
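Taken together, the changes add end-to-end Marian/OPUS-MT support: a MarianConfig, a SentencePiece-based tokenizer, a MarianMTModel built on top of BART, and a script that converts Marian .npz checkpoints to PyTorch. A minimal usage sketch of the new API, assembled from the commit message and the integration tests at the bottom of this diff (the checkpoint name and the expected translation come from those tests; this snippet is illustrative and not part of the commit):

# Sketch only: mirrors tests/test_modeling_marian.py below.
from transformers import MarianMTModel, MarianSentencePieceTokenizer

model_name = "Helsinki-NLP/opus-mt-en-de"
tokenizer = MarianSentencePieceTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name)

batch = tokenizer.prepare_translation_batch(["I am a small frog."])  # input_ids + attention_mask
generated_ids = model.generate(batch["input_ids"], num_beams=6)
print(tokenizer.decode_batch(generated_ids, skip_special_tokens=True))  # ["Ich bin ein kleiner Frosch."]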
setup.cfg

@@ -27,6 +27,7 @@ known_third_party =
     torchtext
     torchvision
     torch_xla
+    tqdm
 line_length = 119
 lines_after_imports = 2
src/transformers/__init__.py

@@ -44,6 +44,7 @@ from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, Electr
 from .configuration_encoder_decoder import EncoderDecoderConfig
 from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
 from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
+from .configuration_marian import MarianConfig
 from .configuration_mmbt import MMBTConfig
 from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
 from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig

@@ -241,6 +242,8 @@ if is_torch_available():
         BartForConditionalGeneration,
         BART_PRETRAINED_MODEL_ARCHIVE_MAP,
     )
+    from .modeling_marian import MarianMTModel
+    from .tokenization_marian import MarianSentencePieceTokenizer
     from .modeling_roberta import (
         RobertaForMaskedLM,
         RobertaModel,
src/transformers/configuration_bart.py

@@ -65,6 +65,9 @@ class BartConfig(PretrainedConfig):
         normalize_before=False,
         add_final_layer_norm=False,
         scale_embedding=False,
+        normalize_embedding=True,
+        static_position_embeddings=False,
+        add_bias_logits=False,
         **common_kwargs
     ):
         r"""

@@ -73,6 +76,8 @@ class BartConfig(PretrainedConfig):
             config = BartConfig.from_pretrained('bart-large')
             model = BartModel(config)
        """
+        if "hidden_size" in common_kwargs:
+            raise ValueError("hidden size is called d_model")
         super().__init__(
             num_labels=num_labels,
             pad_token_id=pad_token_id,

@@ -94,12 +99,17 @@ class BartConfig(PretrainedConfig):
         self.max_position_embeddings = max_position_embeddings
         self.init_std = init_std  # Normal(0, this parameter)
         self.activation_function = activation_function
-        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
-        # True for mbart, False otherwise
+        # Params introduced for Mbart
+        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
+        self.normalize_embedding = normalize_embedding  # True for mbart, False otherwise
         self.normalize_before = normalize_before  # combo of fairseq's encoder_ and decoder_normalize_before
         self.add_final_layer_norm = add_final_layer_norm
+        # Params introduced for Marian
+        self.add_bias_logits = add_bias_logits
+        self.static_position_embeddings = static_position_embeddings
         # 3 Types of Dropout
         self.attention_dropout = attention_dropout
         self.activation_dropout = activation_dropout
src/transformers/configuration_marian.py  (new file, 0 → 100644)

# coding=utf-8
# Copyright 2020 The OPUS-NMT Team, Marian team, and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Marian model configuration """

from .configuration_bart import BartConfig


PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "marian-en-de": "https://s3.amazonaws.com/models.huggingface.co/bert/Helsinki-NLP/opus-mt-en-de/config.json",
}


class MarianConfig(BartConfig):
    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
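MarianConfig adds no fields of its own: it points the archive map at the OPUS-MT en-de config and inherits everything else from BartConfig, including the normalize_embedding, static_position_embeddings, and add_bias_logits flags added above. An illustrative construction, roughly what OpusState.__init__ in the converter below does; the layer counts and vocab size are assumptions about the en-de checkpoint, and only d_model=512 is actually asserted by the converter:

# Illustrative only; every keyword here is an inherited BartConfig field.
from transformers import MarianConfig

cfg = MarianConfig(
    vocab_size=58101,                 # assumed size of the en-de vocab incl. the appended <pad>
    d_model=512,                      # asserted by the converter (cfg["dim-emb"] == 512)
    encoder_layers=6,                 # assumption: OPUS-MT en-de uses 6+6 layers
    decoder_layers=6,
    static_position_embeddings=True,  # Marian uses sinusoidal, not learned, positions
    scale_embedding=True,             # embeddings are multiplied by sqrt(d_model)
    normalize_embedding=False,        # "n" not in transformer-preprocess ("")
    eos_token_id=0,
    bos_token_id=0,
)
print(cfg.pretrained_config_archive_map["marian-en-de"])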
src/transformers/convert_marian_to_pytorch.py  (new file, 0 → 100644)

import argparse
import json
import os
import shutil
import warnings
from pathlib import Path
from typing import Dict, List, Union
from zipfile import ZipFile

import numpy as np
import torch
from tqdm import tqdm

from transformers import MarianConfig, MarianMTModel, MarianSentencePieceTokenizer


def remove_prefix(text: str, prefix: str):
    if text.startswith(prefix):
        return text[len(prefix) :]
    return text  # or whatever


def convert_encoder_layer(opus_dict, layer_prefix: str, converter: dict):
    sd = {}
    for k in opus_dict:
        if not k.startswith(layer_prefix):
            continue
        stripped = remove_prefix(k, layer_prefix)
        v = opus_dict[k].T  # besides embeddings, everything must be transposed.
        sd[converter[stripped]] = torch.tensor(v).squeeze()
    return sd


def load_layers_(layer_lst: torch.nn.ModuleList, opus_state: dict, converter, is_decoder=False):
    for i, layer in enumerate(layer_lst):
        layer_tag = f"decoder_l{i + 1}_" if is_decoder else f"encoder_l{i + 1}_"
        sd = convert_encoder_layer(opus_state, layer_tag, converter)
        layer.load_state_dict(sd, strict=True)


def add_emb_entries(wemb, final_bias, n_special_tokens=1):
    vsize, d_model = wemb.shape
    embs_to_add = np.zeros((n_special_tokens, d_model))
    new_embs = np.concatenate([wemb, embs_to_add])
    bias_to_add = np.zeros((n_special_tokens, 1))
    new_bias = np.concatenate((final_bias, bias_to_add), axis=1)
    return new_embs, new_bias


def _cast_yaml_str(v):
    bool_dct = {"true": True, "false": False}
    if not isinstance(v, str):
        return v
    elif v in bool_dct:
        return bool_dct[v]
    try:
        return int(v)
    except (TypeError, ValueError):
        return v


def cast_marian_config(raw_cfg: Dict[str, str]) -> Dict:
    return {k: _cast_yaml_str(v) for k, v in raw_cfg.items()}


CONFIG_KEY = "special:model.yml"


def load_config_from_state_dict(opus_dict):
    import yaml

    cfg_str = "".join([chr(x) for x in opus_dict[CONFIG_KEY]])
    yaml_cfg = yaml.load(cfg_str[:-1], Loader=yaml.BaseLoader)
    return cast_marian_config(yaml_cfg)


def find_model_file(dest_dir):  # this one better
    model_files = list(Path(dest_dir).glob("*.npz"))
    assert len(model_files) == 1, model_files
    model_file = model_files[0]
    return model_file


def parse_readmes(repo_path):
    results = {}
    for p in Path(repo_path).ls():
        n_dash = p.name.count("-")
        if n_dash == 0:
            continue
        else:
            lns = list(open(p / "README.md").readlines())
            results[p.name] = _parse_readme(lns)
    return results


def download_all_sentencepiece_models(repo_path="Opus-MT-train/models"):
    """Requires 300GB"""
    save_dir = Path("marian_ckpt")
    if not Path(repo_path).exists():
        raise ValueError("You must run: git clone git@github.com:Helsinki-NLP/Opus-MT-train.git")
    results: dict = parse_readmes(repo_path)
    for k, v in tqdm(list(results.items())):
        if os.path.exists(save_dir / k):
            print(f"already have path {k}")
            continue
        if "SentencePiece" not in v["pre-processing"]:
            continue
        download_and_unzip(v["download"], save_dir / k)


def _parse_readme(lns):
    """Get link and metadata from opus model card equivalent."""
    subres = {}
    for ln in [x.strip() for x in lns]:
        if not ln.startswith("*"):
            continue
        ln = ln[1:].strip()

        for k in ["download", "dataset", "models", "model", "pre-processing"]:
            if ln.startswith(k):
                break
        else:
            continue
        if k in ["dataset", "model", "pre-processing"]:
            splat = ln.split(":")
            _, v = splat
            subres[k] = v
        elif k == "download":
            v = ln.split("(")[-1][:-1]
            subres[k] = v
    return subres


def write_metadata(dest_dir: Path):
    dname = dest_dir.name.split("-")
    dct = dict(target_lang=dname[-1], source_lang="-".join(dname[:-1]))
    save_json(dct, dest_dir / "tokenizer_config.json")


def add_to_vocab_(vocab: Dict[str, int], special_tokens: List[str]):
    start = max(vocab.values()) + 1
    added = 0
    for tok in special_tokens:
        if tok in vocab:
            continue
        vocab[tok] = start + added
        added += 1
    return added


def add_special_tokens_to_vocab(model_dir: Path) -> None:
    vocab = load_yaml(model_dir / "opus.spm32k-spm32k.vocab.yml")
    vocab = {k: int(v) for k, v in vocab.items()}
    num_added = add_to_vocab_(vocab, ["<pad>"])
    print(f"added {num_added} tokens to vocab")
    save_json(vocab, model_dir / "vocab.json")
    write_metadata(model_dir)


def save_tokenizer(self, save_directory):
    dest = Path(save_directory)
    src_path = Path(self.init_kwargs["source_spm"])
    for dest_name in {"source.spm", "target.spm", "tokenizer_config.json"}:
        shutil.copyfile(src_path.parent / dest_name, dest / dest_name)
    save_json(self.encoder, dest / "vocab.json")


def check_equal(marian_cfg, k1, k2):
    v1, v2 = marian_cfg[k1], marian_cfg[k2]
    assert v1 == v2, f"hparams {k1},{k2} differ: {v1} != {v2}"


def check_marian_cfg_assumptions(marian_cfg):
    assumed_settings = {
        "tied-embeddings-all": True,
        "layer-normalization": False,
        "right-left": False,
        "transformer-ffn-depth": 2,
        "transformer-aan-depth": 2,
        "transformer-no-projection": False,
        "transformer-postprocess-emb": "d",
        "transformer-postprocess": "dan",  # Dropout, add, normalize
        "transformer-preprocess": "",
        "type": "transformer",
        "ulr-dim-emb": 0,
        "dec-cell-base-depth": 2,
        "dec-cell-high-depth": 1,
        "transformer-aan-nogate": False,
    }
    for k, v in assumed_settings.items():
        actual = marian_cfg[k]
        assert actual == v, f"Unexpected config value for {k} expected {v} got {actual}"
    check_equal(marian_cfg, "transformer-ffn-activation", "transformer-aan-activation")
    check_equal(marian_cfg, "transformer-ffn-depth", "transformer-aan-depth")
    check_equal(marian_cfg, "transformer-dim-ffn", "transformer-dim-aan")


BIAS_KEY = "decoder_ff_logit_out_b"
BART_CONVERTER = {  # for each encoder and decoder layer
    "self_Wq": "self_attn.q_proj.weight",
    "self_Wk": "self_attn.k_proj.weight",
    "self_Wv": "self_attn.v_proj.weight",
    "self_Wo": "self_attn.out_proj.weight",
    "self_bq": "self_attn.q_proj.bias",
    "self_bk": "self_attn.k_proj.bias",
    "self_bv": "self_attn.v_proj.bias",
    "self_bo": "self_attn.out_proj.bias",
    "self_Wo_ln_scale": "self_attn_layer_norm.weight",
    "self_Wo_ln_bias": "self_attn_layer_norm.bias",
    "ffn_W1": "fc1.weight",
    "ffn_b1": "fc1.bias",
    "ffn_W2": "fc2.weight",
    "ffn_b2": "fc2.bias",
    "ffn_ffn_ln_scale": "final_layer_norm.weight",
    "ffn_ffn_ln_bias": "final_layer_norm.bias",
    # Decoder Cross Attention
    "context_Wk": "encoder_attn.k_proj.weight",
    "context_Wo": "encoder_attn.out_proj.weight",
    "context_Wq": "encoder_attn.q_proj.weight",
    "context_Wv": "encoder_attn.v_proj.weight",
    "context_bk": "encoder_attn.k_proj.bias",
    "context_bo": "encoder_attn.out_proj.bias",
    "context_bq": "encoder_attn.q_proj.bias",
    "context_bv": "encoder_attn.v_proj.bias",
    "context_Wo_ln_scale": "encoder_attn_layer_norm.weight",
    "context_Wo_ln_bias": "encoder_attn_layer_norm.bias",
}


class OpusState:
    def __init__(self, source_dir):
        npz_path = find_model_file(source_dir)
        self.state_dict = np.load(npz_path)
        cfg = load_config_from_state_dict(self.state_dict)
        assert cfg["dim-vocabs"][0] == cfg["dim-vocabs"][1]
        assert "Wpos" not in self.state_dict
        self.state_dict = dict(self.state_dict)
        self.wemb, self.final_bias = add_emb_entries(self.state_dict["Wemb"], self.state_dict[BIAS_KEY], 1)
        self.pad_token_id = self.wemb.shape[0] - 1
        cfg["vocab_size"] = self.pad_token_id + 1
        # self.state_dict['Wemb'].sha
        self.state_keys = list(self.state_dict.keys())
        if "Wtype" in self.state_dict:
            raise ValueError("found Wtype key")
        self._check_layer_entries()
        self.source_dir = source_dir
        self.cfg = cfg
        hidden_size, intermediate_shape = self.state_dict["encoder_l1_ffn_W1"].shape
        assert hidden_size == cfg["dim-emb"] == 512

        # Process decoder.yml
        decoder_yml = cast_marian_config(load_yaml(source_dir / "decoder.yml"))
        # TODO: what are normalize and word-penalty?
        check_marian_cfg_assumptions(cfg)
        self.hf_config = MarianConfig(
            vocab_size=cfg["vocab_size"],
            decoder_layers=cfg["dec-depth"],
            encoder_layers=cfg["enc-depth"],
            decoder_attention_heads=cfg["transformer-heads"],
            encoder_attention_heads=cfg["transformer-heads"],
            decoder_ffn_dim=cfg["transformer-dim-ffn"],
            encoder_ffn_dim=cfg["transformer-dim-ffn"],
            d_model=cfg["dim-emb"],
            activation_function=cfg["transformer-aan-activation"],
            pad_token_id=self.pad_token_id,
            eos_token_id=0,
            bos_token_id=0,
            max_position_embeddings=cfg["dim-emb"],
            scale_embedding=True,
            normalize_embedding="n" in cfg["transformer-preprocess"],
            static_position_embeddings=not cfg["transformer-train-position-embeddings"],
            dropout=0.1,  # see opus-mt-train repo/transformer-dropout param.
            # default: add_final_layer_norm=False,
            num_beams=decoder_yml["beam-size"],
        )

    def _check_layer_entries(self):
        self.encoder_l1 = self.sub_keys("encoder_l1")
        self.decoder_l1 = self.sub_keys("decoder_l1")
        self.decoder_l2 = self.sub_keys("decoder_l2")
        if len(self.encoder_l1) != 16:
            warnings.warn(f"Expected 16 keys for each encoder layer, got {len(self.encoder_l1)}")
        if len(self.decoder_l1) != 26:
            warnings.warn(f"Expected 26 keys for each decoder layer, got {len(self.decoder_l1)}")
        if len(self.decoder_l2) != 26:
            warnings.warn(f"Expected 26 keys for each decoder layer, got {len(self.decoder_l1)}")

    @property
    def extra_keys(self):
        extra = []
        for k in self.state_keys:
            if (
                k.startswith("encoder_l")
                or k.startswith("decoder_l")
                or k in [CONFIG_KEY, "Wemb", "Wpos", "decoder_ff_logit_out_b"]
            ):
                continue
            else:
                extra.append(k)
        return extra

    def sub_keys(self, layer_prefix):
        return [remove_prefix(k, layer_prefix) for k in self.state_dict if k.startswith(layer_prefix)]

    def load_marian_model(self) -> MarianMTModel:
        state_dict, cfg = self.state_dict, self.hf_config
        assert cfg.static_position_embeddings
        model = MarianMTModel(cfg)

        assert "hidden_size" not in cfg.to_dict()
        load_layers_(
            model.model.encoder.layers, state_dict, BART_CONVERTER,
        )
        load_layers_(model.model.decoder.layers, state_dict, BART_CONVERTER, is_decoder=True)

        # handle tensors not associated with layers
        wemb_tensor = torch.nn.Parameter(torch.FloatTensor(self.wemb))
        bias_tensor = torch.nn.Parameter(torch.FloatTensor(self.final_bias))
        model.model.shared.weight = wemb_tensor
        model.model.encoder.embed_tokens = model.model.decoder.embed_tokens = model.model.shared

        model.final_logits_bias = bias_tensor

        if "Wpos" in state_dict:
            print("Unexpected: got Wpos")
            wpos_tensor = torch.tensor(state_dict["Wpos"])
            model.model.encoder.embed_positions.weight = wpos_tensor
            model.model.decoder.embed_positions.weight = wpos_tensor

        if cfg.normalize_embedding:
            assert "encoder_emb_ln_scale_pre" in state_dict
            raise NotImplementedError("Need to convert layernorm_embedding")

        assert not self.extra_keys, f"Failed to convert {self.extra_keys}"
        assert model.model.shared.padding_idx == self.pad_token_id
        return model


def download_and_unzip(url, dest_dir):
    try:
        import wget
    except ImportError:
        raise ImportError("you must pip install wget")
    filename = wget.download(url)
    unzip(filename, dest_dir)
    os.remove(filename)


def main(source_dir, dest_dir):
    dest_dir = Path(dest_dir)
    dest_dir.mkdir(exist_ok=True)
    add_special_tokens_to_vocab(source_dir)
    tokenizer = MarianSentencePieceTokenizer.from_pretrained(str(source_dir))
    save_tokenizer(tokenizer, dest_dir)

    opus_state = OpusState(source_dir)
    assert opus_state.cfg["vocab_size"] == len(tokenizer.encoder)
    # save_json(opus_state.cfg, dest_dir / "marian_original_config.json")
    # ^^ Save human readable marian config for debugging
    model = opus_state.load_marian_model()
    model.save_pretrained(dest_dir)
    model.from_pretrained(dest_dir)  # sanity check


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--src", type=str, help="path to marian model dir", default="en-de")
    parser.add_argument("--dest", type=str, default=None, help="Path to the output PyTorch model.")
    args = parser.parse_args()

    source_dir = Path(args.src)
    assert source_dir.exists()
    dest_dir = f"converted-{source_dir.name}" if args.dest is None else args.dest
    main(source_dir, dest_dir)


def load_yaml(path):
    import yaml

    with open(path) as f:
        return yaml.load(f, Loader=yaml.BaseLoader)


def save_json(content: Union[Dict, List], path: str) -> None:
    with open(path, "w") as f:
        json.dump(content, f)


def unzip(zip_path: str, dest_dir: str) -> None:
    with ZipFile(zip_path, "r") as zipObj:
        zipObj.extractall(dest_dir)
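A possible way to drive the converter on a single checkpoint, assuming an OPUS-MT en-de model has already been downloaded and unzipped into ./en-de (containing the *.npz weights, decoder.yml, the SentencePiece models and opus.spm32k-spm32k.vocab.yml); this mirrors the __main__ block above but calls main() via an import:

# Hypothetical driver for the converter above (paths are assumptions).
from pathlib import Path

from transformers.convert_marian_to_pytorch import main

source_dir = Path("en-de")           # unzipped OPUS-MT checkpoint directory
assert source_dir.exists()
main(source_dir, "converted-en-de")  # writes vocab.json, source/target.spm, config.json, pytorch_model.bin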
src/transformers/modeling_bart.py

@@ -18,6 +18,7 @@ import math
 import random
 from typing import Dict, List, Optional, Tuple

+import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn

@@ -125,7 +126,9 @@ class PretrainedBartModel(PreTrainedModel):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
                 module.bias.data.zero_()
-        if isinstance(module, nn.Embedding):
+        elif isinstance(module, SinusoidalPositionalEmbedding):
+            pass
+        elif isinstance(module, nn.Embedding):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()

@@ -250,10 +253,16 @@ class BartEncoder(nn.Module):
         self.max_source_positions = config.max_position_embeddings
         self.embed_tokens = embed_tokens
-        self.embed_positions = LearnedPositionalEmbedding(config.max_position_embeddings, embed_dim, self.padding_idx,)
+        if config.static_position_embeddings:
+            self.embed_positions = SinusoidalPositionalEmbedding(
+                config.max_position_embeddings, embed_dim, self.padding_idx
+            )
+        else:
+            self.embed_positions = LearnedPositionalEmbedding(
+                config.max_position_embeddings, embed_dim, self.padding_idx,
+            )
         self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
-        self.layernorm_embedding = LayerNorm(embed_dim)
+        self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
         # mbart has one extra layer_norm
         self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None

@@ -422,13 +431,18 @@ class BartDecoder(nn.Module):
         self.max_target_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
         self.embed_tokens = embed_tokens
-        self.embed_positions = LearnedPositionalEmbedding(
-            config.max_position_embeddings, config.d_model, self.padding_idx,
-        )
+        if config.static_position_embeddings:
+            self.embed_positions = SinusoidalPositionalEmbedding(
+                config.max_position_embeddings, config.d_model, config.pad_token_id
+            )
+        else:
+            self.embed_positions = LearnedPositionalEmbedding(
+                config.max_position_embeddings, config.d_model, self.padding_idx,
+            )
         self.layers = nn.ModuleList(
             [DecoderLayer(config) for _ in range(config.decoder_layers)]
         )  # type: List[DecoderLayer]
-        self.layernorm_embedding = LayerNorm(config.d_model)
+        self.layernorm_embedding = LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
         self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None

     def forward(

@@ -470,7 +484,7 @@ class BartDecoder(nn.Module):
         if use_cache:
             input_ids = input_ids[:, -1:]
             positions = positions[:, -1:]  # happens after we embed them
-            assert input_ids.ne(self.padding_idx).any()
+            # assert input_ids.ne(self.padding_idx).any()

         x = self.embed_tokens(input_ids) * self.embed_scale
         x += positions

@@ -859,6 +873,22 @@ class BartForConditionalGeneration(PretrainedBartModel):
         super().__init__(config)
         base_model = BartModel(config)
         self.model = base_model
+        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
+
+    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
+        old_num_tokens = self.model.shared.num_embeddings
+        new_embeddings = super().resize_token_embeddings(new_num_tokens)
+        self.model.shared = new_embeddings
+        self._resize_final_logits_bias(new_num_tokens, old_num_tokens)
+        return new_embeddings
+
+    def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None:
+        if new_num_tokens <= old_num_tokens:
+            new_bias = self.final_logits_bias[:, :new_num_tokens]
+        else:
+            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens))
+            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
+        self.register_buffer("final_logits_bias", new_bias)

     @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
     def forward(

@@ -923,8 +953,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
             decoder_cached_states=decoder_cached_states,
             use_cache=use_cache,
         )
-        lm_logits = F.linear(outputs[0], self.model.shared.weight)
-        outputs = (lm_logits,) + outputs[1:]  # Add hidden states and attention if they are here
+        lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias)
+        outputs = (lm_logits,) + outputs[1:]  # Add cache, hidden states and attention if they are here
         if lm_labels is not None:
             loss_fct = nn.CrossEntropyLoss()
             # TODO(SS): do we need to ignore pad tokens in lm_labels?

@@ -957,6 +987,18 @@ class BartForConditionalGeneration(PretrainedBartModel):
             self._force_token_ids_generation(scores, self.config.eos_token_id)
         return scores

+    def _force_token_ids_generation(self, scores, token_ids) -> None:
+        """force one of token_ids to be generated by setting prob of all other tokens to 0"""
+        if isinstance(token_ids, int):
+            token_ids = [token_ids]
+        all_but_token_ids_mask = torch.tensor(
+            [x for x in range(self.config.vocab_size) if x not in token_ids],
+            dtype=torch.long,
+            device=next(self.parameters()).device,
+        )
+        assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
+        scores[:, all_but_token_ids_mask] = -float("inf")
+
     @staticmethod
     def _reorder_cache(past, beam_idx):
         ((enc_out, enc_mask), decoder_cached_states) = past

@@ -1061,3 +1103,39 @@ class BartForSequenceClassification(PretrainedBartModel):
             outputs = (loss,) + outputs

         return outputs
+
+
+class SinusoidalPositionalEmbedding(nn.Embedding):
+    """This module produces sinusoidal positional embeddings of any length."""
+
+    def __init__(self, num_positions, embedding_dim, padding_idx=None):
+        super().__init__(num_positions, embedding_dim)
+        if embedding_dim % 2 != 0:
+            raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
+        self.weight = self._init_weight(self.weight)
+
+    @staticmethod
+    def _init_weight(out: nn.Parameter):
+        """Identical to the XLM create_sinusoidal_embeddings except features are not interleaved.
+        The cos features are in the 2nd half of the vector. [dim // 2:]
+        """
+        n_pos, dim = out.shape
+        position_enc = np.array(
+            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
+        )
+        out[:, 0 : dim // 2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))  # This line breaks for odd n_pos
+        out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
+        out.detach_()
+        out.requires_grad = False
+        return out
+
+    @torch.no_grad()
+    def forward(self, input_ids, use_cache=False):
+        """Input is expected to be of size [bsz x seqlen]."""
+        bsz, seq_len = input_ids.shape[:2]
+        if use_cache:
+            positions = input_ids.data.new(1, 1).fill_(seq_len - 1)  # called before slicing
+        else:
+            # starts at 0, ends at 1-seq_len
+            positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
+        return super().forward(positions)
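The new SinusoidalPositionalEmbedding stores all sine features in the first half of each embedding vector and all cosine features in the second half, rather than interleaving them as the XLM helper does. A small standalone NumPy sketch of that layout, independent of the module above (the ~0.841 value matches the first desired weight in the new BART test below):

# Standalone sketch of the sin/cos layout used by SinusoidalPositionalEmbedding above.
import numpy as np

n_pos, dim = 4, 6
position_enc = np.array(
    [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
)
weight = np.zeros((n_pos, dim))
weight[:, : dim // 2] = np.sin(position_enc[:, 0::2])  # sines fill the first half
weight[:, dim // 2 :] = np.cos(position_enc[:, 1::2])  # cosines fill the second half

print(weight[0])     # position 0: all sines are 0, all cosines are 1
print(weight[1, 0])  # ~0.841 = sin(1), cf. desired_weights in tests/test_modeling_bart.py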
src/transformers/modeling_marian.py  (new file, 0 → 100644)

# coding=utf-8
# Copyright 2020 Marian Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch MarianMTModel model, ported from the Marian C++ repo."""

from transformers.modeling_bart import BartForConditionalGeneration


PRETRAINED_MODEL_ARCHIVE_MAP = {
    "opus-mt-en-de": "https://s3.amazonaws.com/models.huggingface.co/bert/Helsinki-NLP/opus-mt-en-de/pytorch_model.bin",
}


class MarianMTModel(BartForConditionalGeneration):
    """Pytorch version of marian-nmt's transformer.h (c++). Designed for the OPUS-NMT translation checkpoints.
    Model API is identical to BartForConditionalGeneration"""

    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP

    def prepare_scores_for_generation(self, scores, cur_len, max_length):
        if cur_len == max_length - 1 and self.config.eos_token_id is not None:
            self._force_token_ids_generation(scores, self.config.eos_token_id)
        return scores
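prepare_scores_for_generation only acts on the last generation step: once cur_len == max_length - 1 it forces EOS by setting every other vocabulary score to -inf, via the _force_token_ids_generation helper that this commit moves from PreTrainedModel into BartForConditionalGeneration. A toy illustration of that masking, detached from any model:

# Toy illustration (not model code): forcing eos_token_id=0 on a [batch_size, vocab_size] score matrix.
import torch

vocab_size, eos_token_id = 5, 0
scores = torch.randn(2, vocab_size)

all_but_eos = torch.tensor([x for x in range(vocab_size) if x != eos_token_id], dtype=torch.long)
scores[:, all_but_eos] = -float("inf")

print(scores.argmax(-1))  # tensor([0, 0]): only EOS remains selectable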
src/transformers/modeling_utils.py

@@ -1530,18 +1530,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         return decoded

-    # force one of token_ids to be generated by setting prob of all other tokens to 0.
-    def _force_token_ids_generation(self, scores, token_ids) -> None:
-        if isinstance(token_ids, int):
-            token_ids = [token_ids]
-        all_but_token_ids_mask = torch.tensor(
-            [x for x in range(self.config.vocab_size) if x not in token_ids],
-            dtype=torch.long,
-            device=next(self.parameters()).device,
-        )
-        assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
-        scores[:, all_but_token_ids_mask] = -float("inf")
-
     @staticmethod
     def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]:
         return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
src/transformers/tokenization_marian.py  (new file, 0 → 100644)

import json
import warnings
from typing import Dict, List, Optional, Union

import sentencepiece

from .file_utils import S3_BUCKET_PREFIX
from .tokenization_utils import BatchEncoding, PreTrainedTokenizer


vocab_files_names = {
    "source_spm": "source.spm",
    "target_spm": "target.spm",
    "vocab": "vocab.json",
    "tokenizer_config_file": "tokenizer_config.json",
}
MODEL_NAMES = ("opus-mt-en-de",)
PRETRAINED_VOCAB_FILES_MAP = {
    k: {m: f"{S3_BUCKET_PREFIX}/Helsinki-NLP/{m}/{fname}" for m in MODEL_NAMES} for k, fname in vocab_files_names.items()
}
# Example URL https://s3.amazonaws.com/models.huggingface.co/bert/Helsinki-NLP/opus-mt-en-de/vocab.json


class MarianSentencePieceTokenizer(PreTrainedTokenizer):
    vocab_files_names = vocab_files_names
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = {m: 512 for m in MODEL_NAMES}
    model_input_names = ["attention_mask"]  # actually attention_mask, decoder_attention_mask

    def __init__(
        self,
        vocab=None,
        source_spm=None,
        target_spm=None,
        source_lang=None,
        target_lang=None,
        unk_token="<unk>",
        eos_token="</s>",
        pad_token="<pad>",
        max_len=512,
    ):
        super().__init__(
            # bos_token=bos_token,
            max_len=max_len,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
        )
        self.encoder = load_json(vocab)
        assert self.pad_token in self.encoder
        self.decoder = {v: k for k, v in self.encoder.items()}

        self.source_lang = source_lang
        self.target_lang = target_lang

        # load SentencePiece model for pre-processing
        self.paths = {}
        self.spm_source = sentencepiece.SentencePieceProcessor()
        self.spm_source.Load(source_spm)
        self.spm_target = sentencepiece.SentencePieceProcessor()
        self.spm_target.Load(target_spm)

        # Note(SS): splitter would require lots of book-keeping.
        # self.sentence_splitter = MosesSentenceSplitter(source_lang)
        try:
            from mosestokenizer import MosesPunctuationNormalizer

            self.punc_normalizer = MosesPunctuationNormalizer(source_lang)
        except ImportError:
            warnings.warn("Recommended: pip install mosestokenizer")
            self.punc_normalizer = lambda x: x

    def _convert_token_to_id(self, token):
        return self.encoder[token]

    def _tokenize(self, text: str, src=True) -> List[str]:
        spm = self.spm_source if src else self.spm_target
        return spm.EncodeAsPieces(text)

    def _convert_id_to_token(self, index: int) -> str:
        """Converts an index (integer) in a token (str) using the encoder."""
        return self.decoder.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Uses target language sentencepiece model"""
        return self.spm_target.DecodePieces(tokens)

    def _append_special_tokens_and_truncate(self, tokens: str, max_length: int,) -> List[int]:
        ids: list = self.convert_tokens_to_ids(tokens)[:max_length]
        return ids + [self.eos_token_id]

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
        """Build model inputs from a sequence by appending eos_token_id."""
        if token_ids_1 is None:
            return token_ids_0 + [self.eos_token_id]
        # We don't expect to process pairs, but leave the pair logic for API consistency
        return token_ids_0 + token_ids_1 + [self.eos_token_id]

    def decode_batch(self, token_ids, **kwargs) -> List[str]:
        return [self.decode(ids, **kwargs) for ids in token_ids]

    def prepare_translation_batch(
        self,
        src_texts: List[str],
        tgt_texts: Optional[List[str]] = None,
        max_length: Optional[int] = None,
        pad_to_max_length: bool = True,
        return_tensors: str = "pt",
    ) -> BatchEncoding:
        """
        Arguments:
            src_texts: list of src language texts
            src_lang: default en_XX (english)
            tgt_texts: list of tgt language texts
            tgt_lang: default ro_RO (romanian)
            max_length: (None) defer to config (1024 for mbart-large-en-ro)
            pad_to_max_length: (bool)

        Returns:
            BatchEncoding: with keys [input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]
            all shaped bs, seq_len. (BatchEncoding is a dict of string -> tensor or lists)

        Examples:
            from transformers import MarianS
        """
        model_inputs: BatchEncoding = self.batch_encode_plus(
            src_texts,
            add_special_tokens=True,
            return_tensors=return_tensors,
            max_length=max_length,
            pad_to_max_length=pad_to_max_length,
            src=True,
        )
        if tgt_texts is None:
            return model_inputs
        decoder_inputs: BatchEncoding = self.batch_encode_plus(
            tgt_texts,
            add_special_tokens=True,
            return_tensors=return_tensors,
            max_length=max_length,
            pad_to_max_length=pad_to_max_length,
            src=False,
        )
        for k, v in decoder_inputs.items():
            model_inputs[f"decoder_{k}"] = v
        return model_inputs

    @property
    def vocab_size(self) -> int:
        return len(self.encoder)


def load_json(path: str) -> Union[Dict, List]:
    with open(path, "r") as f:
        return json.load(f)
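A short usage sketch of prepare_translation_batch, mirroring the integration tests at the end of this diff (the expected token ids hold only for the en-de checkpoint and come straight from those tests):

# Sketch based on tests/test_modeling_marian.py below; requires the en-de checkpoint files.
from transformers import MarianSentencePieceTokenizer

tokenizer = MarianSentencePieceTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
batch = tokenizer.prepare_translation_batch(
    ["I am a small frog"],
    tgt_texts=["▁Ich ▁bin ▁ein ▁kleiner ▁Fro sch"],
)
# batch keys: input_ids, attention_mask, decoder_input_ids, decoder_attention_mask
print(batch["input_ids"][0].tolist())  # [38, 121, 14, 697, 38848, 0], where eos_token_id == 0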
src/transformers/tokenization_utils.py

@@ -31,7 +31,7 @@ from tokenizers import Encoding as EncodingFast
 from tokenizers.decoders import Decoder as DecoderFast
 from tokenizers.implementations import BaseTokenizer as BaseTokenizerFast

-from .file_utils import cached_path, hf_bucket_url, is_remote_url, is_tf_available, is_torch_available
+from .file_utils import cached_path, hf_bucket_url, is_remote_url, is_tf_available, is_torch_available, torch_required

 if is_tf_available():

@@ -458,6 +458,12 @@ class BatchEncoding(UserDict):
             char_index = batch_or_char_index
         return self._encodings[batch_index].char_to_word(char_index)

+    @torch_required
+    def to(self, device: str):
+        """Send all values to device by calling v.to(device)"""
+        self.data = {k: v.to(device) for k, v in self.data.items()}
+        return self
+
 class SpecialTokensMixin:
     """ SpecialTokensMixin is derived by ``PreTrainedTokenizer`` and ``PreTrainedTokenizerFast`` and
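The new BatchEncoding.to turns the encode-then-move pattern used by the Marian tests into a one-liner. A minimal sketch, assuming every value in the BatchEncoding is a tensor (which is why the Marian tokenizer defaults to return_tensors="pt"):

# Minimal sketch of BatchEncoding.to as added above.
import torch
from transformers import MarianSentencePieceTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = MarianSentencePieceTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
batch = tokenizer.prepare_translation_batch(["I am a small frog."]).to(device)
print(batch["input_ids"].device)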
tests/test_modeling_bart.py

@@ -42,6 +42,7 @@ if is_torch_available():
         shift_tokens_right,
         invert_mask,
         _prepare_bart_decoder_inputs,
+        SinusoidalPositionalEmbedding,
     )

@@ -650,3 +651,41 @@ class BartModelIntegrationTests(unittest.TestCase):
         )
         # TODO(SS): run fairseq again with num_beams=2, min_len=20.
         # TODO(SS): add test case that hits max_length
+
+
+@require_torch
+class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
+    desired_weights = [
+        [0, 0, 0, 0, 0],
+        [0.84147096, 0.82177866, 0.80180490, 0.78165019, 0.76140374],
+        [0.90929741, 0.93651021, 0.95829457, 0.97505713, 0.98720258],
+    ]
+
+    def test_positional_emb_cache_logic(self):
+        pad = 1
+        input_ids = torch.tensor([[4, 10]], dtype=torch.long, device=torch_device)
+        emb1 = SinusoidalPositionalEmbedding(num_positions=32, embedding_dim=6, padding_idx=pad).to(torch_device)
+        no_cache = emb1(input_ids, use_cache=False)
+        yes_cache = emb1(input_ids, use_cache=True)
+        self.assertEqual((1, 1, 6), yes_cache.shape)  # extra dim to allow broadcasting, feel free to delete!
+        self.assertListEqual(no_cache[-1].tolist(), yes_cache[0][0].tolist())
+
+    def test_odd_embed_dim(self):
+        with self.assertRaises(NotImplementedError):
+            SinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=0).to(torch_device)
+
+        # odd num_positions is allowed
+        SinusoidalPositionalEmbedding(num_positions=5, embedding_dim=4, padding_idx=0).to(torch_device)
+
+    def test_positional_emb_weights_against_marian(self):
+        pad = 1
+        emb1 = SinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512, padding_idx=pad).to(torch_device)
+        weights = emb1.weight.data[:3, :5].tolist()
+        for i, (expected_weight, actual_weight) in enumerate(zip(self.desired_weights, weights)):
+            for j in range(5):
+                self.assertAlmostEqual(expected_weight[j], actual_weight[j], places=3)
+
+        # test that forward pass is just a lookup, there is no ignore padding logic
+        input_ids = torch.tensor([[4, 10, pad, pad, pad]], dtype=torch.long, device=torch_device)
+        no_cache_pad_zero = emb1(input_ids)
+        self.assertTrue(torch.allclose(torch.Tensor(self.desired_weights), no_cache_pad_zero[:3, :5], atol=1e-3))
tests/test_modeling_marian.py  (new file, 0 → 100644)

# coding=utf-8
# Copyright 2020 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import is_torch_available
from transformers.file_utils import cached_property

from .utils import require_torch, slow, torch_device


if is_torch_available():
    import torch

    from transformers import MarianMTModel, MarianSentencePieceTokenizer


@require_torch
class IntegrationTests(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        cls.model_name = "Helsinki-NLP/opus-mt-en-de"
        cls.tokenizer = MarianSentencePieceTokenizer.from_pretrained(cls.model_name)
        cls.eos_token_id = cls.tokenizer.eos_token_id
        return cls

    @cached_property
    def model(self):
        model = MarianMTModel.from_pretrained(self.model_name).to(torch_device)
        if torch_device == "cuda":
            return model.half()
        else:
            return model

    @slow
    def test_forward(self):
        src, tgt = ["I am a small frog"], ["▁Ich ▁bin ▁ein ▁kleiner ▁Fro sch"]
        expected = [38, 121, 14, 697, 38848, 0]

        model_inputs: dict = self.tokenizer.prepare_translation_batch(src, tgt_texts=tgt).to(torch_device)

        self.assertListEqual(expected, model_inputs["input_ids"][0].tolist())

        desired_keys = {
            "input_ids",
            "attention_mask",
            "decoder_input_ids",
            "decoder_attention_mask",
        }
        self.assertSetEqual(desired_keys, set(model_inputs.keys()))
        with torch.no_grad():
            logits, *enc_features = self.model(**model_inputs)
        max_indices = logits.argmax(-1)
        self.tokenizer.decode_batch(max_indices)

    @slow
    def test_repl_generate_one(self):
        src = ["I am a small frog.", "Hello"]
        model_inputs: dict = self.tokenizer.prepare_translation_batch(src).to(torch_device)
        self.assertEqual(self.model.device, model_inputs["input_ids"].device)
        generated_ids = self.model.generate(model_inputs["input_ids"], num_beams=6,)
        generated_words = self.tokenizer.decode_batch(generated_ids)[0]
        expected_words = "Ich bin ein kleiner Frosch."
        self.assertEqual(expected_words, generated_words)

    @slow
    def test_repl_generate_batch(self):
        src = [
            "I am a small frog.",
            "Now I can forget the 100 words of german that I know.",
            "O",
            "Tom asked his teacher for advice.",
            "That's how I would do it.",
            "Tom really admired Mary's courage.",
            "Turn around and close your eyes.",
        ]
        model_inputs: dict = self.tokenizer.prepare_translation_batch(src).to(torch_device)
        self.assertEqual(self.model.device, model_inputs["input_ids"].device)
        generated_ids = self.model.generate(
            model_inputs["input_ids"],
            length_penalty=1.0,
            num_beams=2,  # 6 is the default
            bad_words_ids=[[self.tokenizer.pad_token_id]],
        )
        expected = [
            "Ich bin ein kleiner Frosch.",
            "Jetzt kann ich die 100 Wörter des Deutschen vergessen, die ich kenne.",
            "",
            "Tom bat seinen Lehrer um Rat.",
            "So würde ich das tun.",
            "Tom bewunderte Marias Mut wirklich.",
            "Umdrehen und die Augen schließen.",
        ]
        # actual C++ output differences: (1) des Deutschen removed, (2) ""-> "O", (3) tun -> machen
        generated_words = self.tokenizer.decode_batch(generated_ids, skip_special_tokens=True)
        self.assertListEqual(expected, generated_words)

    def test_marian_equivalence(self):
        batch = self.tokenizer.prepare_translation_batch(["I am a small frog"]).to(torch_device)
        input_ids = batch["input_ids"][0]
        expected = [38, 121, 14, 697, 38848, 0]
        self.assertListEqual(expected, input_ids.tolist())

    def test_pad_not_split(self):
        input_ids_w_pad = self.tokenizer.prepare_translation_batch(["I am a small frog <pad>"])["input_ids"][0]
        expected_w_pad = [38, 121, 14, 697, 38848, self.tokenizer.pad_token_id, 0]  # pad
        self.assertListEqual(expected_w_pad, input_ids_w_pad.tolist())