"examples/pytorch/multigpu/multi_gpu_link_prediction.py" did not exist on "44638b93364efcf74ed9fe99eba45c73a0196799"
Commit 3658fa32 authored by Jay Mahadeokar's avatar Jay Mahadeokar Committed by Facebook Github Bot
Browse files

aligned training task and CE related changes

Summary:
This diff adds:

1. Aligned training task specifically for doing cross entropy criterion training using prod data and prod like models
2. Few changes to correctly register the task and criterions.
3. Changes to trainer code for propogating accuracy metrics which we care about for training.

Couple of things are hacky right now:
- The reporting is not modular (this needs to be thought about in general for fairseq).

- The get dummy batch could be specific to task instead of specific for dataset.

Reviewed By: myleott

Differential Revision: D14670482

fbshipit-source-id: dc077247b2ae9d26a8e842a386ec5faa5771e836
parent 3a64aced
......@@ -271,6 +271,10 @@ class Trainer(object):
1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
)
self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size)
if 'train_acc' in self.meters:
self.meters['train_acc'].update(
logging_output.get('acc', 0), sample_size)
if 'nll_loss' in logging_output:
self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
except OverflowError as e:
......@@ -340,6 +344,10 @@ class Trainer(object):
# update meters for validation
ntokens = logging_output.get('ntokens', 0)
self.meters['valid_loss'].update(logging_output.get('loss', 0), sample_size)
if 'valid_acc' in self.meters:
self.meters['valid_acc'].update(
logging_output.get('acc', 0), sample_size)
if 'nll_loss' in logging_output:
self.meters['valid_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment