# Task class and data input system for training
## Task class

In ESPnet1, we had too many duplicated Python modules.
One of the major goals of ESPnet2 is to provide a common interface and
let us focus on the unique parts of each task.

The `Task` class is a common system to build training tools for each task,
e.g. ASR, TTS, and LM, inspired by the `Fairseq` `Task` idea.
To build your own task, all you have to do is inherit the `AbsTask` class:

```python
import argparse

from espnet2.tasks.abs_task import AbsTask
from espnet2.train.abs_espnet_model import AbsESPnetModel

class NewModel(AbsESPnetModel):
    def forward(self, input, target):
        (...)
        # loss: The loss of the task. Must be a scalar value.
        # stats: A dict object, used for logging and validation criterion
        # weight: A scalar value used for normalization of the loss and stats values across mini-batches.
        #     In many cases, this should be equal to the mini-batch size.
        return loss, stats, weight

class NewTask(AbsTask):
    @classmethod
    def add_task_arguments(cls, parser):
        parser.add_argument(...)
        (...)

    @classmethod
    def build_collate_fn(cls, args: argparse.Namespace):
        (...)

    @classmethod
    def build_preprocess_fn(cls, args, train):
        (...)

    @classmethod
    def required_data_names(cls, inference: bool = False):
        (...)

    @classmethod
    def optional_data_names(cls, inference: bool = False):
        (...)

    @classmethod
    def build_model(cls, args):
        return NewModel(...)

if __name__ == "__main__":
    # Start training
    NewTask.main()
```

## Data input system
ESPnet2 also provides a command line interface to describe the training corpus.
Unlike `fairseq` or training systems such as `pytorch-lightning`,
our `Task` class doesn't have an interface for building the dataset explicitly.
This is because we only aim at tasks related to speech/text,
so we don't need such a general system so far.

The following is an example of the command line arguments:

```bash
python -m espnet2.bin.asr_train \
  --train_data_path_and_name_and_type=/some/path/tr/wav.scp,speech,sound \
  --train_data_path_and_name_and_type=/some/path/tr/token_int,text,text_int \
  --valid_data_path_and_name_and_type=/some/path/dev/wav.scp,speech,sound \
  --valid_data_path_and_name_and_type=/some/path/dev/token_int,text,text_int
```

First of all, our mini-batch is always a `dict` object:

```python
# In training iteration
for batch in iterator:
    # e.g. batch = {"speech": ..., "text": ...}
    # Forward
    model(**batch)
```

Here, `model` is the one built by `Task.build_model()`.

You can flexibly construct this mini-batch object
using `--*_data_path_and_name_and_type`.
The option can be repeated as many times as you need, and
each occurrence corresponds to one element in the mini-batch.
Also, keep in mind that **there is no distinction between input and target data**.


The argument of `--train_data_path_and_name_and_type`
should be given as three values separated by commas,
like `<file-path>,<key-name>,<file-format>`.

- `file-path` is a file/directory path for the data source.
- `key-name` specifies the key of the mini-batch dict.
- `file-format` indicates the format of the file given by `file-path`, e.g. `sound`, `kaldi_ark`, etc.


### `scp` file
You can list the supported file formats with the `--help` option.

```bash
python -m espnet2.bin.asr_train --help
```

Almost all formats are referred to as `scp` files, following Kaldi-ASR.
An `scp` file is just a text file with two columns per line:
the first is the sample id and the second is some value,
e.g. a file path, a transcription, or a sequence of numbers
(a minimal parsing sketch is shown after the examples below).


- format=npy
    ```
    sample_id_a /some/path/a.npy
    sample_id_b /some/path/b.npy
    ```
- format=sound
    ```
    sample_id_a /some/path/a.flac
    sample_id_b /some/path/a.wav
    ```
- format=kaldi_ark
    ```
    sample_id_a /some/path/a.ark:1234
    sample_id_b /some/path/a.ark:5678
    ```
- format=text_int
    ```
    sample_id_a 10 2 4 4
    sample_id_b 3 2 0 1 6 2
    ```
- format=text
    ```
    sample_id_a hello world
    sample_id_b It is rainy today
    ```


### `required_data_names()` and `optional_data_names()`
Though an arbitrary dictionary can be created by this system,
each task assumes that specific keys are given for specific purposes,
e.g. the ASR task requires the `speech` and `text` keys, and
their values are used as input data and target data respectively.
See again these methods of the `Task` class:
`required_data_names()` and `optional_data_names()`.


```python
class NewTask(AbsTask):
  @classmethod
  def required_data_names(cls, inference: bool = False):
      if not inference:
          retval = ("input", "target")
      else:
          retval = ("input",)
      return retval

  @classmethod
  def optional_data_names(cls, inference: bool = False):
      retval = ("auxially_feature",)
      return retval
```

`required_data_names()` determines the mandatory data names and `optional_data_names()` gives the optional ones. Any other name is not allowed to be given as a command line argument.

```bash
# The following is the expected argument
python -m new_task \
  --train_data_path_and_name_and_type=filepath,input,sometype \
  --train_data_path_and_name_and_type=filepath,target,sometype \
  --train_data_path_and_name_and_type=filepath,auxiliary_feature,sometype
# The following raises an error
python -m new_task \
  --train_data_path_and_name_and_type=filepath,unknown,sometype
```

The intention of this system is just an assertion check, so if you feel it is unnecessary,
you can turn off this check with `--allow_variable_data_keys true`.

```bash
# Ignore assertion checking for data names
python -m new_task \
  --train_data_path_and_name_and_type=filepath,unknown_name,sometype \
  --allow_variable_data_keys true
```


## Customize `collate_fn` for PyTorch data loader
The `Task` class has a method to customize `collate_fn`:

```python
class NewTask(AbsTask):
  @classmethod
  def build_collate_fn(cls, args: argparse.Namespace):
    ...
```

`collate_fn` is an argument of `torch.utils.data.DataLoader` and
it can modify how the data received from the dataset is formed into a mini-batch, e.g.:

```python
def collate_fn(data):
    # data is a list of the return values of the Dataset class:
    modified_data = ...  # transform the data here
    return modified_data

from torch.utils.data import DataLoader
data_loader = DataLoader(dataset, collate_fn=collate_fn)
for modified_data in data_loader:
    ...
```

The type of the argument is determined by the input `dataset` class, and
our dataset is always `espnet2.train.dataset.ESPnetDataset`,
whose return value is a tuple of a sample id and a dict of tensors:

```python
sample = ("sample_id", {"speech": tensor, "text": tensor})
```

Therefore, the argument of `collate_fn` is a list of such tuples:

```python
data = [
  ("sample_id", {"speech": tensor, "text": tensor}),
  ("sample_id2", {"speech": tensor, "text": tensor}),
  ...
]
```

In ESPnet2, the return type of `collate_fn` is supposed to be a tuple of a list of sample ids and a dict of batched tensors,
so the `collate_fn` for a `Task` must transform the data into that form.

```python
for ids, batch in data_loader:
  model(**batch)
```
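
If you do need a custom `collate_fn`, the following is a minimal sketch of what it could look like, assuming the dataset yields torch tensors whose first axis is the length axis. The function name is made up for this example, and the `*_lengths` keys simply imitate the convention of the common collate_fn described below:

```python
import torch


def custom_collate_fn(data):
    # data: [("sample_id", {"speech": tensor, "text": tensor}), ...]
    ids = [sample_id for sample_id, _ in data]
    batch = {}
    for key in data[0][1]:
        tensors = [d[key] for _, d in data]
        # Zero-pad along the first (length) axis and stack into (Batch, Length, ...).
        max_len = max(t.shape[0] for t in tensors)
        padded = [
            torch.cat([t, t.new_zeros((max_len - t.shape[0],) + t.shape[1:])], dim=0)
            for t in tensors
        ]
        batch[key] = torch.stack(padded, dim=0)
        # Also keep the original length of each sample.
        batch[key + "_lengths"] = torch.tensor(
            [t.shape[0] for t in tensors], dtype=torch.long
        )
    return ids, batch
```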

We provide a common collate_fn and this function can support many cases,
so you might not need to customize it.
This collate_fn is aware of variable-length sequence features for seq2seq tasks:

- The first axis of a sequence tensor from the dataset must be the length axis, e.g. (Length, Dim), (Length, Dim, Dim2), or (Length, ...).
- The lengths of the samples don't need to be unified; they are stacked with zero-padding.
    - The value of padding can be changed.
        ```python
        from espnet2.train.collate_fn import CommonCollateFn
        @classmethod
        def build_collate_fn(cls, args):
            # float_pad_value is used for float-tensor and int_pad_value is used for int-tensor
            return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
        ```
- Tensors representing the length of each sample are also appended:
    ```python
    batch = {"speech": ..., "speech_lengths": ..., "text": ..., "text_lengths": ...}
    ```
- If the feature is not sequential data, this behavior can be disabled.
    ```bash
    python -m new_task --train_data_path_and_name_and_type=filepath,foo,npy
    ```
    ```python
    @classmethod
    def build_collate_fn(cls, args):
        return CommonCollateFn(not_sequence=["foo"])
    ```