Unverified commit 8ff88d25 authored by Stas Bekman, committed by GitHub



[fsmt] rewrite SinusoidalPositionalEmbedding + USE_CUDA test fixes + new TranslationPipeline test (#7224)

* fix USE_CUDA, add pipeline

* USE_CUDA fix

* recode SinusoidalPositionalEmbedding into nn.Embedding subclass

This was needed for torchscript to work. The sinusoidal weights are now part of the state_dict, so these keys have to be removed during save_pretrained.

* back out (ci debug)

* restore

* slow last?

* facilitate not saving certain keys and test

* remove no longer used keys

* style

* fix logging import

* cleanup

* Update src/transformers/modeling_utils.py
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>

* fix bug in max_position_embeddings

* rename keys to keys_to_never_save per suggestion, improve the setup

* Update src/transformers/modeling_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 67c4b0c5
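A toy illustration of the tradeoff described in the commit message (the class names here are invented for the sketch, not part of the commit): keeping the sinusoidal table as a plain attribute hides it from the state_dict but does not work with torchscript, while subclassing nn.Embedding works with torchscript at the cost of the weight showing up in the state_dict, which is why keys_to_never_save is introduced in the diff below.

import torch
from torch import nn

class BufferStyle(nn.Module):
    # old approach: the table is a plain attribute, only a dummy buffer lands in the state_dict
    def __init__(self):
        super().__init__()
        self.weights = torch.randn(4, 4)  # not registered, absent from state_dict
        self.register_buffer("_float_tensor", torch.zeros(1))  # used only to track the device

class EmbeddingStyle(nn.Embedding):
    # new approach: subclassing nn.Embedding puts "weight" into the state_dict,
    # so it has to be filtered out again at save time
    def __init__(self):
        super().__init__(4, 4)

print(list(BufferStyle().state_dict()))     # ['_float_tensor']
print(list(EmbeddingStyle().state_dict()))  # ['weight']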
@@ -31,16 +31,14 @@ import torch
 from fairseq import hub_utils
 from fairseq.data.dictionary import Dictionary

-from transformers import WEIGHTS_NAME
+from transformers import WEIGHTS_NAME, logging
 from transformers.configuration_fsmt import FSMTConfig
 from transformers.modeling_fsmt import FSMTForConditionalGeneration
 from transformers.tokenization_fsmt import VOCAB_FILES_NAMES
 from transformers.tokenization_utils_base import TOKENIZER_CONFIG_FILE

-from .utils import logging
-
-logging.set_verbosity_warning()
+logging.set_verbosity_info()

 json_indent = 2
@@ -229,6 +227,8 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
         "model.decoder.version",
         "model.encoder_embed_tokens.weight",
         "model.decoder_embed_tokens.weight",
+        "model.encoder.embed_positions._float_tensor",
+        "model.decoder.embed_positions._float_tensor",
     ]
     for k in ignore_keys:
         model_state_dict.pop(k, None)
...
@@ -397,24 +397,18 @@ class FSMTEncoder(nn.Module):
     def __init__(self, config: FSMTConfig, embed_tokens):
         super().__init__()
         self.dropout = config.dropout
         self.layerdrop = config.encoder_layerdrop
-        embed_dim = embed_tokens.embedding_dim
-        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
         self.padding_idx = embed_tokens.padding_idx
-        self.max_source_positions = config.max_position_embeddings
         self.embed_tokens = embed_tokens
-        # print(config.max_position_embeddings, embed_dim, self.padding_idx)
-        num_embeddings = config.src_vocab_size
+        embed_dim = embed_tokens.embedding_dim
+        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
         self.embed_positions = SinusoidalPositionalEmbedding(
-            embed_dim,
-            self.padding_idx,
-            init_size=num_embeddings + self.padding_idx + 1,  # removed: config.max_position_embeddings
+            config.max_position_embeddings + self.padding_idx + 1, embed_dim, self.padding_idx
         )
-        self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
+        self.layers = nn.ModuleList(
+            [EncoderLayer(config) for _ in range(config.encoder_layers)]
+        )  # type: List[EncoderLayer]

     def forward(
         self, input_ids, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=False
@@ -570,15 +564,11 @@ class FSMTDecoder(nn.Module):
         self.dropout = config.dropout
         self.layerdrop = config.decoder_layerdrop
         self.padding_idx = embed_tokens.padding_idx
-        self.max_target_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
         self.embed_tokens = embed_tokens
         embed_dim = embed_tokens.embedding_dim
-        num_embeddings = config.tgt_vocab_size
         self.embed_positions = SinusoidalPositionalEmbedding(
-            embed_dim,
-            self.padding_idx,
-            init_size=num_embeddings + self.padding_idx + 1,
+            config.max_position_embeddings + self.padding_idx + 1, embed_dim, self.padding_idx
         )
         self.layers = nn.ModuleList(
             [DecoderLayer(config) for _ in range(config.decoder_layers)]
@@ -1003,6 +993,14 @@ class FSMTModel(PretrainedFSMTModel):
 )
 class FSMTForConditionalGeneration(PretrainedFSMTModel):
     base_model_prefix = "model"
+    authorized_missing_keys = [
+        "model.encoder.embed_positions.weight",
+        "model.decoder.embed_positions.weight",
+    ]
+    keys_to_never_save = [
+        "model.encoder.embed_positions.weight",
+        "model.decoder.embed_positions.weight",
+    ]

     def __init__(self, config: FSMTConfig):
         super().__init__(config)
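Both lists above name the same keys because the positional embeddings are rebuilt in __init__: they are excluded from saved checkpoints (keys_to_never_save) and therefore must not trigger missing-key warnings when loading (authorized_missing_keys). A rough sketch of the load-side pattern matching, simplified from how from_pretrained applies these patterns (assumed behavior, not copied from this commit):

import re

authorized_missing_keys = [
    "model.encoder.embed_positions.weight",
    "model.decoder.embed_positions.weight",
]
missing_keys = [
    "model.encoder.embed_positions.weight",             # expected to be missing, rebuilt at init
    "model.encoder.layers.0.self_attn.k_proj.weight",   # a genuinely missing weight (made up here)
]

# keys matching any authorized pattern are dropped from the warning list
for pat in authorized_missing_keys:
    missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
print(missing_keys)  # only the genuinely missing key remains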
@@ -1137,36 +1135,34 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
         return self.model.decoder.embed_tokens


-def make_positions(tensor, padding_idx: int):
-    """Replace non-padding symbols with their position numbers.
-    Position numbers begin at padding_idx+1. Padding symbols are ignored.
-    """
-    # The series of casts and type-conversions here are carefully
-    # balanced to both work with ONNX export and XLA. In particular XLA
-    # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
-    # how to handle the dtype kwarg in cumsum.
-    mask = tensor.ne(padding_idx).int()
-    return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
-
-
-class SinusoidalPositionalEmbedding(nn.Module):
-    """This module produces sinusoidal positional embeddings of any length.
+class SinusoidalPositionalEmbedding(nn.Embedding):
+    """
+    This module produces sinusoidal positional embeddings of any length.
+
+    We don't want to save the weight of this embedding since it's not trained
+    (deterministic) and it can be huge.

     Padding symbols are ignored.
+
+    These embeddings get automatically extended in forward if more positions is needed.
     """

-    def __init__(self, embedding_dim, padding_idx, init_size=1024):
-        super().__init__()
-        self.embedding_dim = embedding_dim
-        self.padding_idx = padding_idx
-        self.weights = SinusoidalPositionalEmbedding.get_embedding(init_size, embedding_dim, padding_idx)
-        self.register_buffer("_float_tensor", torch.zeros(1))  # used for getting the right device
-        self.max_positions = int(1e5)
-        # XXX: bart uses s/num_embeddings/num_positions/, s/weights/weight/ - could make those match
+    def __init__(self, num_positions, embedding_dim, padding_idx):
+        self.make_weight(num_positions, embedding_dim, padding_idx)
+
+    def make_weight(self, num_positions, embedding_dim, padding_idx, device=None):
+        weight = self.get_embedding(num_positions, embedding_dim, padding_idx)
+        if device is not None:
+            weight = weight.to(device)
+        if not hasattr(self, "weight"):
+            super().__init__(num_positions, embedding_dim, padding_idx, _weight=weight)
+        else:
+            self.weight = nn.Parameter(weight)
+        self.weight.detach_()
+        self.weight.requires_grad = False

     @staticmethod
-    def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
+    def get_embedding(num_embeddings, embedding_dim, padding_idx):
         """Build sinusoidal embeddings.

         This matches the implementation in tensor2tensor, but differs slightly
@@ -1184,28 +1180,30 @@ class SinusoidalPositionalEmbedding(nn.Module):
         emb[padding_idx, :] = 0
         return emb

+    @staticmethod
+    def make_positions(tensor, padding_idx: int):
+        """Replace non-padding symbols with their position numbers.
+        Position numbers begin at padding_idx+1. Padding symbols are ignored.
+        """
+        # The series of casts and type-conversions here are carefully
+        # balanced to both work with ONNX export and XLA. In particular XLA
+        # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
+        # how to handle the dtype kwarg in cumsum.
+        mask = tensor.ne(padding_idx).int()
+        return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
+
     def forward(
         self,
         input,
         incremental_state: Optional[Any] = None,
         timestep: Optional[Tensor] = None,
-        positions: Optional[Any] = None,
     ):
         """Input is expected to be of size [bsz x seqlen]."""
-        # bspair = torch.onnx.operators.shape_as_tensor(input)
-        # bsz, seq_len = bspair[0], bspair[1]
         bsz, seq_len = input.shape[:2]
         max_pos = self.padding_idx + 1 + seq_len
-        if self.weights is None or max_pos > self.weights.size(0):
-            # recompute/expand embeddings if needed
-            self.weights = SinusoidalPositionalEmbedding.get_embedding(max_pos, self.embedding_dim, self.padding_idx)
-            self.weights = self.weights.to(self._float_tensor)
-
-        if incremental_state is not None:
-            # positions is the same for every token when decoding a single step
-            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
-            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
-
-        positions = make_positions(input, self.padding_idx)
-        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
+        if max_pos > self.weight.size(0):
+            # expand embeddings if needed
+            self.make_weight(max_pos, self.embedding_dim, self.padding_idx, device=input.device)
+        positions = self.make_positions(input, self.padding_idx)
+        return super().forward(positions)
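A quick demonstration of the make_positions contract used in forward above (standalone sketch; the input values are made up): padding tokens keep position padding_idx, real tokens get consecutive positions starting at padding_idx + 1.

import torch

def make_positions(tensor, padding_idx: int):
    # same logic as the staticmethod above
    mask = tensor.ne(padding_idx).int()
    return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx

input_ids = torch.tensor([[5, 7, 9, 1, 1]])  # 1 is the padding_idx here
print(make_positions(input_ids, padding_idx=1))  # tensor([[2, 3, 4, 1, 1]])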
@@ -391,10 +391,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
           derived classes of the same architecture adding modules on top of the base model.
         - **authorized_missing_keys** (:obj:`Optional[List[str]]`) -- A list of re pattern of tensor names to ignore
           when loading the model (and avoid unnecessary warnings).
+        - **keys_to_never_save** (:obj:`Optional[List[str]]`) -- A list of of tensor names to ignore
+          when saving the model (useful for keys that aren't trained, but which are deterministic)
     """
     config_class = None
     base_model_prefix = ""
     authorized_missing_keys = None
+    keys_to_never_save = None

     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
@@ -688,6 +692,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         # Attach architecture to the config
         model_to_save.config.architectures = [model_to_save.__class__.__name__]

+        state_dict = model_to_save.state_dict()
+
+        # Handle the case where some state_dict keys shouldn't be saved
+        if self.keys_to_never_save is not None:
+            state_dict = {k: v for k, v in state_dict.items() if k not in self.keys_to_never_save}
+
         # If we save using the predefined names, we can load using `from_pretrained`
         output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
@@ -698,10 +708,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
             # Save configuration file
             model_to_save.config.save_pretrained(save_directory)
             # xm.save takes care of saving only from master
-            xm.save(model_to_save.state_dict(), output_model_file)
+            xm.save(state_dict, output_model_file)
         else:
             model_to_save.config.save_pretrained(save_directory)
-            torch.save(model_to_save.state_dict(), output_model_file)
+            torch.save(state_dict, output_model_file)

         logger.info("Model weights saved in {}".format(output_model_file))
...
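The save-side change boils down to filtering the state_dict before writing it; the load side then reports those keys as missing, which is exactly what authorized_missing_keys silences. A self-contained toy version of that round trip (names and shapes invented for the sketch, not the transformers API):

import os
import tempfile

import torch
from torch import nn

class ToyModel(nn.Module):
    keys_to_never_save = ["pos.weight"]  # deterministic, cheap to rebuild

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.pos = nn.Embedding(8, 4)
        self.pos.weight.data.fill_(0.5)  # stand-in for a recomputable sinusoidal table

def save_filtered(model, path):
    # drop the never-save keys before writing, mirroring the diff above
    state_dict = model.state_dict()
    state_dict = {k: v for k, v in state_dict.items() if k not in model.keys_to_never_save}
    torch.save(state_dict, path)

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "pytorch_model.bin")
    save_filtered(ToyModel(), path)
    result = ToyModel().load_state_dict(torch.load(path), strict=False)
    print(result.missing_keys)  # ['pos.weight']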
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import tempfile
 import unittest
@@ -20,7 +21,7 @@ import timeout_decorator  # noqa
 from parameterized import parameterized

 from transformers import is_torch_available
-from transformers.file_utils import cached_property
+from transformers.file_utils import WEIGHTS_NAME, cached_property
 from transformers.testing_utils import require_torch, slow, torch_device

 from .test_configuration_common import ConfigTester
@@ -37,6 +38,7 @@ if is_torch_available():
         invert_mask,
         shift_tokens_right,
     )
+    from transformers.pipelines import TranslationPipeline


 @require_torch
@@ -207,6 +209,27 @@ class FSMTModelTest(ModelTesterMixin, unittest.TestCase):
             model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
             self.assertEqual(info["missing_keys"], [])

+    def test_save_load_no_save_keys(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            state_dict_no_save_keys = getattr(model, "state_dict_no_save_keys", None)
+            if state_dict_no_save_keys is None:
+                continue
+
+            # check the keys are in the original state_dict
+            for k in state_dict_no_save_keys:
+                self.assertIn(k, model.state_dict())
+
+            # check that certain keys didn't get saved with the model
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+                output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME)
+                state_dict_saved = torch.load(output_model_file)
+                for k in state_dict_no_save_keys:
+                    self.assertNotIn(k, state_dict_saved)
+
     @unittest.skip("can't be implemented for FSMT due to dual vocab.")
     def test_resize_tokens_embeddings(self):
         pass
@@ -219,14 +242,6 @@ class FSMTModelTest(ModelTesterMixin, unittest.TestCase):
     def test_tie_model_weights(self):
         pass

-    @unittest.skip("failing on CI - needs review")
-    def test_torchscript_output_attentions(self):
-        pass
-
-    @unittest.skip("failing on CI - needs review")
-    def test_torchscript_output_hidden_state(self):
-        pass
-
     # def test_auto_model(self):
     #     # XXX: add a tiny model to s3?
     #     model_name = "facebook/wmt19-ru-en-tiny"
@@ -366,6 +381,14 @@ def _long_tensor(tok_lst):
 TOLERANCE = 1e-4

+pairs = [
+    ["en-ru"],
+    ["ru-en"],
+    ["en-de"],
+    ["de-en"],
+]
+

 @require_torch
 class FSMTModelIntegrationTests(unittest.TestCase):
     tokenizers_cache = {}
@@ -399,7 +422,7 @@ class FSMTModelIntegrationTests(unittest.TestCase):
         src_text = "My friend computer will translate this for me"

         input_ids = tokenizer([src_text], return_tensors="pt")["input_ids"]
-        input_ids = _long_tensor(input_ids)
+        input_ids = _long_tensor(input_ids).to(torch_device)
         inputs_dict = prepare_fsmt_inputs_dict(model.config, input_ids)
         with torch.no_grad():
             output = model(**inputs_dict)[0]
@@ -409,19 +432,10 @@ class FSMTModelIntegrationTests(unittest.TestCase):
         # may have to adjust if switched to a different checkpoint
         expected_slice = torch.tensor(
             [[-1.5753, -1.5753, 2.8975], [-0.9540, -0.9540, 1.0299], [-3.3131, -3.3131, 0.5219]]
-        )
+        ).to(torch_device)
         self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE))

-    @parameterized.expand(
-        [
-            ["en-ru"],
-            ["ru-en"],
-            ["en-de"],
-            ["de-en"],
-        ]
-    )
-    @slow
-    def test_translation(self, pair):
+    def translation_setup(self, pair):
         text = {
             "en": "Machine learning is great, isn't it?",
             "ru": "Машинное обучение - это здорово, не так ли?",
@@ -432,16 +446,32 @@ class FSMTModelIntegrationTests(unittest.TestCase):
         print(f"Testing {src} -> {tgt}")
         mname = f"facebook/wmt19-{pair}"

-        src_sentence = text[src]
-        tgt_sentence = text[tgt]
+        src_text = text[src]
+        tgt_text = text[tgt]

         tokenizer = self.get_tokenizer(mname)
         model = self.get_model(mname)
+        return tokenizer, model, src_text, tgt_text
+
+    @parameterized.expand(pairs)
+    @slow
+    def test_translation_direct(self, pair):
+        tokenizer, model, src_text, tgt_text = self.translation_setup(pair)

-        input_ids = tokenizer.encode(src_sentence, return_tensors="pt")
+        input_ids = tokenizer.encode(src_text, return_tensors="pt").to(torch_device)
         outputs = model.generate(input_ids)
         decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        assert decoded == tgt_sentence, f"\n\ngot: {decoded}\nexp: {tgt_sentence}\n"
+        assert decoded == tgt_text, f"\n\ngot: {decoded}\nexp: {tgt_text}\n"
+
+    @parameterized.expand(pairs)
+    @slow
+    def test_translation_pipeline(self, pair):
+        tokenizer, model, src_text, tgt_text = self.translation_setup(pair)
+        device = 0 if torch_device == "cuda" else -1
+        pipeline = TranslationPipeline(model, tokenizer, framework="pt", device=device)
+        output = pipeline([src_text])
+        self.assertEqual([tgt_text], [x["translation_text"] for x in output])
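Outside of the test harness, the pipeline path exercised by test_translation_pipeline would look roughly like this (a hedged sketch based on the test above; downloading the checkpoint is required and the exact output is not guaranteed here):

from transformers import FSMTForConditionalGeneration, FSMTTokenizer
from transformers.pipelines import TranslationPipeline

mname = "facebook/wmt19-en-ru"
tokenizer = FSMTTokenizer.from_pretrained(mname)
model = FSMTForConditionalGeneration.from_pretrained(mname)

# device=-1 keeps the pipeline on CPU; pass a GPU index (e.g. 0) to mirror the USE_CUDA path
pipeline = TranslationPipeline(model, tokenizer, framework="pt", device=-1)
print(pipeline(["Machine learning is great, isn't it?"])[0]["translation_text"])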
 @require_torch
@@ -449,10 +479,9 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
     padding_idx = 1
     tolerance = 1e-4

-    @unittest.skip("failing on CI - needs review")
     def test_basic(self):
         input_ids = torch.tensor([[4, 10]], dtype=torch.long, device=torch_device)
-        emb1 = SinusoidalPositionalEmbedding(embedding_dim=6, padding_idx=self.padding_idx, init_size=6).to(
+        emb1 = SinusoidalPositionalEmbedding(num_positions=6, embedding_dim=6, padding_idx=self.padding_idx).to(
             torch_device
         )
         emb = emb1(input_ids)
@@ -461,7 +490,7 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
                 [9.0930e-01, 1.9999e-02, 2.0000e-04, -4.1615e-01, 9.9980e-01, 1.0000e00],
                 [1.4112e-01, 2.9995e-02, 3.0000e-04, -9.8999e-01, 9.9955e-01, 1.0000e00],
             ]
-        )
+        ).to(torch_device)
         self.assertTrue(
             torch.allclose(emb[0], desired_weights, atol=self.tolerance),
             msg=f"\nexp:\n{desired_weights}\ngot:\n{emb[0]}\n",
@@ -469,14 +498,10 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase):

     def test_odd_embed_dim(self):
         # odd embedding_dim is allowed
-        SinusoidalPositionalEmbedding.get_embedding(
-            num_embeddings=4, embedding_dim=5, padding_idx=self.padding_idx
-        ).to(torch_device)
+        SinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=self.padding_idx).to(torch_device)

         # odd num_embeddings is allowed
-        SinusoidalPositionalEmbedding.get_embedding(
-            num_embeddings=5, embedding_dim=4, padding_idx=self.padding_idx
-        ).to(torch_device)
+        SinusoidalPositionalEmbedding(num_positions=5, embedding_dim=4, padding_idx=self.padding_idx).to(torch_device)

     @unittest.skip("different from marian (needs more research)")
     def test_positional_emb_weights_against_marian(self):
@@ -488,7 +513,7 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
                 [0.90929741, 0.93651021, 0.95829457, 0.97505713, 0.98720258],
             ]
         )
-        emb1 = SinusoidalPositionalEmbedding(init_size=512, embedding_dim=512, padding_idx=self.padding_idx).to(
+        emb1 = SinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512, padding_idx=self.padding_idx).to(
             torch_device
         )
         weights = emb1.weights.data[:3, :5]
...
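A numeric sanity check for the values expected in test_basic above (standalone sketch, not part of the commit): with padding_idx=1 the two tokens sit at positions 2 and 3, and for embedding_dim=6 the inverse frequencies are 1, 1e-2, 1e-4, which reproduces the first expected row.

import torch

positions = torch.tensor([2.0, 3.0])        # padding_idx + 1 and padding_idx + 2
inv_freq = torch.tensor([1.0, 1e-2, 1e-4])  # exp(arange(3) * -log(10000) / (3 - 1))
args = positions.unsqueeze(1) * inv_freq    # shape (2, 3)
rows = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
print(rows[0])  # ~[0.9093, 0.0200, 0.0002, -0.4161, 0.9998, 1.0000]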