Commit a8f28ecb authored by Bairen Yi, committed by Facebook Github Bot

Python3.5 compat (#794)

Summary:
See #467. Ping myleott to review.

This is a work-related contribution. Ping lark to review.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/794

Differential Revision: D15756816

Pulled By: myleott

fbshipit-source-id: 6dce3ff3a713bf5f60e5782bc260b2ca9d2c0a9b
parent 9b40999e
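Background for the diff below (context added here, not part of the commit message): f-string literals are a Python 3.6 feature (PEP 498), so a module containing even one f-string fails to parse under Python 3.5 — the SyntaxError is raised at import time, before any code runs. Every hunk in this commit therefore rewrites an f-string as an equivalent `str.format()` call. A minimal sketch of the pattern:

```python
name = "fairseq"

# Python >= 3.6 only; under 3.5 this line is a SyntaxError at import time:
#   print(f"hello {name}")

# Python 3.5-compatible equivalent used throughout this diff:
print("hello {}".format(name))  # -> hello fairseq
```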
```diff
@@ -39,7 +39,7 @@ translation and language modeling datasets.
 
 # Requirements and Installation
 * [PyTorch](http://pytorch.org/) version >= 1.0.0
-* Python version >= 3.6
+* Python version >= 3.5
 * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
 
 Please follow the instructions here to install PyTorch: https://github.com/pytorch/pytorch#installation.
```
```diff
@@ -166,8 +166,8 @@ def batch_by_size(
         sample_lens.append(num_tokens_fn(idx))
         sample_len = max(sample_len, sample_lens[-1])
         assert sample_len <= max_tokens, (
-            f"sentence at index {idx} of size {sample_len} exceeds max_tokens "
-            f"limit of {max_tokens}!"
+            "sentence at index {} of size {} exceeds max_tokens "
+            "limit of {}!".format(idx, sample_len, max_tokens)
         )
         num_tokens = (len(batch) + 1) * sample_len
         if is_batch_full(num_tokens):
```
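A subtlety worth noting about the rewrite above (an observation, not from the commit): adjacent string literals are concatenated by the parser before the method call is applied, so the single `.format()` on the last literal formats the whole two-line message:

```python
idx, sample_len, max_tokens = 7, 120, 100

# Implicit literal concatenation happens first, then .format() applies
# to the combined string, so both lines share one argument list.
msg = (
    "sentence at index {} of size {} exceeds max_tokens "
    "limit of {}!".format(idx, sample_len, max_tokens)
)
print(msg)  # sentence at index 7 of size 120 exceeds max_tokens limit of 100!
```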
```diff
@@ -280,7 +280,7 @@ class TransformerEncoder(FairseqEncoder):
         state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1)
         for i in range(len(self.layers)):
             # update layer norms
-            self.layers[i].upgrade_state_dict_named(state_dict, f"{name}.layers.{i}")
+            self.layers[i].upgrade_state_dict_named(state_dict, "{}.layers.{}".format(name, i))
         version_key = '{}.version'.format(name)
         if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
```
```diff
@@ -540,10 +540,10 @@ class TransformerEncoderLayer(nn.Module):
         }
         for old, new in layer_norm_map.items():
             for m in ('weight', 'bias'):
-                k = f'{name}.layer_norms.{old}.{m}'
+                k = '{}.layer_norms.{}.{}'.format(name, old, m)
                 if k in state_dict:
                     state_dict[
-                        f'{name}.{new}.{m}'
+                        '{}.{}.{}'.format(name, new, m)
                     ] = state_dict[k]
                     del state_dict[k]
```
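To make the rename loop above concrete, here is a toy run with made-up keys (the `layer_norm_map` contents and the `'encoder.layers.0'` prefix are illustrative, not copied from fairseq):

```python
# Illustrative values only; real fairseq builds these from the model.
layer_norm_map = {'0': 'self_attn_layer_norm', '1': 'final_layer_norm'}
name = 'encoder.layers.0'
state_dict = {'encoder.layers.0.layer_norms.0.weight': 'W0'}

for old, new in layer_norm_map.items():
    for m in ('weight', 'bias'):
        k = '{}.layer_norms.{}.{}'.format(name, old, m)
        if k in state_dict:
            # Move the tensor to its new, renamed key.
            state_dict['{}.{}.{}'.format(name, new, m)] = state_dict[k]
            del state_dict[k]

print(state_dict)  # {'encoder.layers.0.self_attn_layer_norm.weight': 'W0'}
```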
```diff
@@ -90,7 +90,7 @@ def upgrade_state_dict_with_xlm_weights(
         decoder and the pretrained_xlm_checkpoint
     """
     if not os.path.exists(pretrained_xlm_checkpoint):
-        raise IOError(f"Model file not found: {pretrained_xlm_checkpoint}")
+        raise IOError("Model file not found: {}".format(pretrained_xlm_checkpoint))
     state = checkpoint_utils.load_checkpoint_to_cpu(pretrained_xlm_checkpoint)
     xlm_state_dict = state["model"]
```
```diff
@@ -100,10 +100,12 @@ def upgrade_state_dict_with_xlm_weights(
         if search_key in key:
             subkey = key[key.find(search_key):]
             assert subkey in state_dict, (
-                f"{str(state_dict.keys())} Transformer encoder / decoder "
-                f"state_dict does not contain {subkey}. Cannot "
-                f"load {key} from pretrained XLM checkpoint "
-                f"{pretrained_xlm_checkpoint} into Transformer."
+                "{} Transformer encoder / decoder "
+                "state_dict does not contain {}. Cannot "
+                "load {} from pretrained XLM checkpoint "
+                "{} into Transformer.".format(
+                    str(state_dict.keys()),
+                    subkey, key, pretrained_xlm_checkpoint)
             )
             state_dict[subkey] = xlm_state_dict[key]
```
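For readers skimming the slice above: `str.find` returns the index of the first occurrence, so `key[key.find(search_key):]` keeps the suffix starting at the shared sub-module path. A toy illustration (the key names are made up):

```python
search_key = "embed_tokens.weight"  # made-up name for illustration
key = "encoder.sentence_encoder.embed_tokens.weight"

# find() locates the start of the shared suffix; the slice drops the
# checkpoint-specific prefix so the key lines up with the target state_dict.
subkey = key[key.find(search_key):]
print(subkey)  # embed_tokens.weight
```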
```diff
@@ -27,14 +27,14 @@ from . import FairseqTask, register_task
 
 def _lang_token(lang: str):
-    return f'__{lang}__'
+    return '__{}__'.format(lang)
 
 
 def _lang_token_index(dic: Dictionary, lang: str):
     """Return language token index."""
     idx = dic.index(_lang_token(lang))
     assert idx != dic.unk_index, \
-        f'cannot find language token for lang {lang}'
+        'cannot find language token for lang {}'.format(lang)
     return idx
```
```diff
@@ -362,7 +362,7 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
             for lang_pair in self.args.lang_pairs:
                 _, tgt = lang_pair.split('-')
                 sample_key = _get_denoising_dataset_key(lang_pair)
-                forward_backward(model.models[f'{tgt}-{tgt}'], sample[sample_key], sample_key, self.lambda_denoising)
+                forward_backward(model.models['{0}-{0}'.format(tgt)], sample[sample_key], sample_key, self.lambda_denoising)
             return agg_loss, agg_sample_size, agg_logging_output
```
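One conversion here deserves a note (an aside, not from the commit): an f-string can simply repeat the expression (`f'{tgt}-{tgt}'`), whereas `str.format` can reuse a single argument via an explicit positional index, so the variable is passed only once:

```python
tgt = "de"

# Explicit index {0} reuses the single argument:
assert '{0}-{0}'.format(tgt) == 'de-de'

# Auto-numbered fields would need the argument passed twice:
assert '{}-{}'.format(tgt, tgt) == 'de-de'
```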
```diff
@@ -313,7 +313,7 @@ def get_activation_fn(activation: str) -> Callable:
     elif activation == 'tanh':
         return torch.tanh
     else:
-        raise RuntimeError(f"--activation-fn {activation} not supported")
+        raise RuntimeError("--activation-fn {} not supported".format(activation))
 
 
 def get_available_activation_fns() -> List:
```
```diff
@@ -35,6 +35,7 @@ setup(
     classifiers=[
         'Intended Audience :: Science/Research',
         'License :: OSI Approved :: BSD License',
+        'Programming Language :: Python :: 3.5',
         'Programming Language :: Python :: 3.6',
         'Topic :: Scientific/Engineering :: Artificial Intelligence',
     ],
```
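A caveat on the classifier added above (a note, not part of the PR): Trove classifiers are descriptive metadata only. To have pip refuse installation on older interpreters, one would additionally set `python_requires`, which this commit does not touch — a hypothetical sketch:

```python
# Hypothetical, NOT part of this commit: python_requires enforces the
# version floor at install time; classifiers merely advertise it.
from setuptools import setup

setup(
    name='fairseq',
    python_requires='>=3.5',
    classifiers=[
        'Programming Language :: Python :: 3.5',
        'Programming Language :: Python :: 3.6',
    ],
)
```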
```diff
@@ -303,7 +303,7 @@ class TestMaskedLanguageModel(unittest.TestCase):
             "--encoder-ffn-embed-dim",
             "32",
             "--pretrained-xlm-checkpoint",
-            f"{data_dir}/checkpoint_last.pt",
+            "{}/checkpoint_last.pt".format(data_dir),
             "--activation-fn",
             "gelu",
             "--max-source-positions",
```