"vscode:/vscode.git/clone" did not exist on "ae50ad91ea2fedb64ecd2e7c8e2d0d4778dc03aa"
Commit dc667ce1 authored by thomwolf

double check cc @LysandreJik

parent 3fd71c44
@@ -75,7 +75,7 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
         n_batch = len(dataset)
         input_ids = np.zeros((n_batch, 2, input_len), dtype=np.int64)
         mc_token_ids = np.zeros((n_batch, 2), dtype=np.int64)
-        lm_labels = np.full((n_batch, 2, input_len), fill_value=-1, dtype=np.int64)
+        lm_labels = np.full((n_batch, 2, input_len), fill_value=-100, dtype=np.int64)
         mc_labels = np.zeros((n_batch,), dtype=np.int64)
         for i, (story, cont1, cont2, mc_label), in enumerate(dataset):
             with_cont1 = [start_token] + story[:cap_length] + [delimiter_token] + cont1[:cap_length] + [clf_token]
...
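Context for the change (not part of the commit): PyTorch's `nn.CrossEntropyLoss` skips targets equal to -100 by default (`ignore_index=-100`), whereas -1 is only skipped if `ignore_index=-1` is passed explicitly. A minimal sketch of the behavior the new fill value relies on:

```python
import torch
import torch.nn as nn

# Illustrative only: positions labeled -100 contribute nothing to the loss,
# because CrossEntropyLoss has ignore_index=-100 by default.
logits = torch.randn(4, 10)                 # (num_tokens, vocab_size)
labels = torch.tensor([3, -100, 7, -100])   # -100 marks "nothing to predict"

loss_fct = nn.CrossEntropyLoss()            # ignore_index defaults to -100
loss = loss_fct(logits, labels)             # only positions 0 and 2 are counted
print(loss)
```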
@@ -186,7 +186,7 @@ class Distiller:
         -------
         token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
         attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
-        mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels. There is a -1 where there is nothing to predict.
+        mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels. There is a -100 where there is nothing to predict.
         """
         token_ids, lengths = batch
         token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
@@ -246,7 +246,7 @@ class Distiller:
         -------
         token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
         attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
-        clm_labels: `torch.tensor(bs, seq_length)` - The causal language modeling labels. There is a -1 where there is nothing to predict.
+        clm_labels: `torch.tensor(bs, seq_length)` - The causal language modeling labels. There is a -100 where there is nothing to predict.
         """
         token_ids, lengths = batch
         token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
...
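As a rough illustration of what the updated docstrings describe (a hypothetical sketch, not code from this repository): MLM labels keep the original token id at the positions selected for prediction and -100 everywhere else, so the loss ignores positions where there is nothing to predict.

```python
import torch

# Hypothetical toy example: build MLM inputs and labels with -100 as the
# "nothing to predict" marker (matching CrossEntropyLoss's default ignore_index).
original_ids = torch.tensor([[101, 2054, 2003, 4658, 102]])
pred_mask = torch.tensor([[False, False, False, True, False]])  # positions to predict
mask_token_id = 103  # assumed [MASK] id for this toy vocabulary

token_ids = original_ids.masked_fill(pred_mask, mask_token_id)   # model input
mlm_labels = original_ids.masked_fill(~pred_mask, -100)          # loss targets
print(token_ids)   # tensor([[ 101, 2054, 2003,  103,  102]])
print(mlm_labels)  # tensor([[-100, -100, -100, 4658, -100]])
```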