Overview
========

Fairseq can be extended through user-supplied `plug-ins
<https://en.wikipedia.org/wiki/Plug-in_(computing)>`_. We support five kinds of
plug-ins:

- :ref:`Models` define the neural network architecture and encapsulate all of the
  learnable parameters.
- :ref:`Criterions` compute the loss function given the model outputs and targets.
- :ref:`Tasks` store dictionaries and provide helpers for loading/iterating over
  Datasets, initializing the Model/Criterion and calculating the loss.
- :ref:`Optimizers` update the Model parameters based on the gradients.
- :ref:`Learning Rate Schedulers` update the learning rate over the course of
  training.
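
Each plug-in kind has a matching ``@register_*`` decorator (see *Registering
new plug-ins* below). As a rough orientation, and assuming the current fairseq
source layout, the registries can be imported from the following modules::

  from fairseq.models import register_model
  from fairseq.criterions import register_criterion
  from fairseq.tasks import register_task
  from fairseq.optim import register_optimizer
  from fairseq.optim.lr_scheduler import register_lr_scheduler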

**Training Flow**

Given a ``model``, ``criterion``, ``task``, ``optimizer`` and ``lr_scheduler``,
fairseq implements the following high-level training flow::

  for epoch in range(num_epochs):
      itr = task.get_batch_iterator(task.dataset('train'))
      for num_updates, batch in enumerate(itr):
          task.train_step(batch, model, criterion, optimizer)
          average_and_clip_gradients()
          optimizer.step()
          lr_scheduler.step_update(num_updates)
      lr_scheduler.step(epoch)

where the default implementation for ``task.train_step`` is roughly::

  def train_step(self, batch, model, criterion, optimizer):
      loss = criterion(model, batch)
      optimizer.backward(loss)
      return loss

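Tasks can customize this step by overriding ``train_step``. A minimal sketch,
reusing the simplified signature above (the exact signature differs across
fairseq versions, and ``MyTask`` is purely illustrative)::

  from fairseq.tasks import FairseqTask

  class MyTask(FairseqTask):

      def train_step(self, batch, model, criterion, optimizer):
          # Keep the default forward/backward flow, but make sure the
          # model is in training mode first.
          model.train()
          loss = criterion(model, batch)
          optimizer.backward(loss)
          return loss
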
**Registering new plug-ins**

New plug-ins are *registered* through a set of ``@register`` function
decorators, for example::

  @register_model('my_lstm')
  class MyLSTM(FairseqModel):
      (...)

Once registered, new plug-ins can be used with the existing :ref:`Command-line
Tools`. See the Tutorial sections for more detailed walkthroughs of how to add
new plug-ins.
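
For instance, a new loss can be plugged in by registering a criterion (a
minimal sketch; the concrete ``forward`` signature and return values are
defined by ``FairseqCriterion`` in your fairseq version, and ``my_criterion``
is an illustrative name)::

  from fairseq.criterions import FairseqCriterion, register_criterion

  @register_criterion('my_criterion')
  class MyCriterion(FairseqCriterion):

      def forward(self, model, sample, reduce=True):
          # Compute the loss from the model outputs and the targets in
          # ``sample``.
          (...)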

**Loading plug-ins from another directory**

New plug-ins can be defined in a custom module stored on the user's system. To
import that module and make its plug-ins available to *fairseq*, the
command-line tools support a ``--user-dir`` flag that specifies a custom
location for additional modules to load into *fairseq*.

For example, assuming this directory tree::

  /home/user/my-module/
  └── __init__.py
  
with ``__init__.py``::

  from fairseq.models import register_model_architecture
  from fairseq.models.transformer import transformer_vaswani_wmt_en_de_big

  @register_model_architecture('transformer', 'my_transformer')
  def transformer_mmt_big(args):
      transformer_vaswani_wmt_en_de_big(args)

The :ref:`fairseq-train` script can then be invoked with the new architecture::

  fairseq-train ... --user-dir /home/user/my-module -a my_transformer --task translation