Commit e3a40d9d authored by Changhan Wang's avatar Changhan Wang Committed by Facebook Github Bot
Browse files

fix libnat imports

Summary: Bring back the changes in D17661768

Reviewed By: ailzhang

Differential Revision: D17920299

fbshipit-source-id: be3f93a044a8710c8b475012c39e36a3e6507fad
parent d80ad54f
......@@ -6,7 +6,6 @@
import numpy as np
import torch
import torch.nn.functional as F
from fairseq import libnat
from fairseq.utils import new_arange
from fairseq.models import register_model, register_model_architecture
from fairseq.models.levenshtein_transformer import (
......@@ -52,6 +51,13 @@ neg_scorer = NegativeDistanceScore()
def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx, vocab_size, tau=None):
try:
from fairseq import libnat
except ImportError as e:
import sys
sys.stderr.write('ERROR: missing libnat. run `pip install --editable .`\n')
raise e
B = in_tokens.size(0)
T = in_tokens.size(1)
V = vocab_size
......
......@@ -5,7 +5,6 @@
import torch
import torch.nn.functional as F
from fairseq import libnat
from fairseq.utils import new_arange
from fairseq.models import register_model, register_model_architecture
from fairseq.models.model_utils import fill_tensors as _fill, skip_tensors as _skip
......@@ -19,6 +18,13 @@ from fairseq.modules.transformer_sentence_encoder import init_bert_params
def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx):
try:
from fairseq import libnat
except ImportError as e:
import sys
sys.stderr.write('ERROR: missing libnat. run `pip install --editable .`\n')
raise e
in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1)
with torch.cuda.device_of(in_tokens):
......@@ -61,6 +67,13 @@ def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx):
def _get_del_targets(in_tokens, out_tokens, padding_idx):
try:
from fairseq import libnat
except ImportError as e:
import sys
sys.stderr.write('ERROR: missing libnat. run `pip install --editable .`\n')
raise e
out_seq_len = out_tokens.size(1)
with torch.cuda.device_of(in_tokens):
......@@ -87,6 +100,13 @@ def _get_del_targets(in_tokens, out_tokens, padding_idx):
def _get_del_ins_targets(in_tokens, out_tokens, padding_idx):
try:
from fairseq import libnat
except ImportError as e:
import sys
sys.stderr.write('ERROR: missing libnat. run `pip install --editable .`\n')
raise e
in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1)
with torch.cuda.device_of(in_tokens):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment