Commit 9770f367 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Fix PyTorch deprecation warnings

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/618

Differential Revision: D15552599

Pulled By: myleott

fbshipit-source-id: 2192a30a9c5af31b954a3a1716166dd6ba27b23a
parent 47313d85
...@@ -57,8 +57,12 @@ class AdaptiveLoss(FairseqCriterion): ...@@ -57,8 +57,12 @@ class AdaptiveLoss(FairseqCriterion):
for i in range(len(target)): for i in range(len(target)):
if target[i] is not None: if target[i] is not None:
assert (target[i].min() >= 0 and target[i].max() <= logits[i].size(1)) assert (target[i].min() >= 0 and target[i].max() <= logits[i].size(1))
loss += F.cross_entropy(logits[i], target[i], size_average=False, ignore_index=self.padding_idx, loss += F.cross_entropy(
reduce=reduce) logits[i],
target[i],
ignore_index=self.padding_idx,
reduction='sum' if reduce else 'none',
)
orig = utils.strip_pad(orig_target, self.padding_idx) orig = utils.strip_pad(orig_target, self.padding_idx)
ntokens = orig.numel() ntokens = orig.numel()
......
...@@ -42,8 +42,12 @@ class CrossEntropyCriterion(FairseqCriterion): ...@@ -42,8 +42,12 @@ class CrossEntropyCriterion(FairseqCriterion):
lprobs = model.get_normalized_probs(net_output, log_probs=True) lprobs = model.get_normalized_probs(net_output, log_probs=True)
lprobs = lprobs.view(-1, lprobs.size(-1)) lprobs = lprobs.view(-1, lprobs.size(-1))
target = model.get_targets(sample, net_output).view(-1) target = model.get_targets(sample, net_output).view(-1)
loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx, loss = F.nll_loss(
reduce=reduce) lprobs,
target,
ignore_index=self.padding_idx,
reduction='sum' if reduce else 'none',
)
return loss, loss return loss, loss
@staticmethod @staticmethod
......
...@@ -311,7 +311,7 @@ def get_activation_fn(activation: str) -> Callable: ...@@ -311,7 +311,7 @@ def get_activation_fn(activation: str) -> Callable:
elif activation == 'gelu_accurate': elif activation == 'gelu_accurate':
return gelu_accurate return gelu_accurate
elif activation == 'tanh': elif activation == 'tanh':
return F.tanh return torch.tanh
else: else:
raise RuntimeError(f"--activation-fn {activation} not supported") raise RuntimeError(f"--activation-fn {activation} not supported")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment