"vscode:/vscode.git/clone" did not exist on "66fc8d25a5de43c08baeea8b22b9bcf57346c325"
Commit 3fd71c44 authored by LysandreJik's avatar LysandreJik
Browse files

Update example scripts

parent b72f9d34
......@@ -112,7 +112,7 @@ class Distiller:
self.last_log = 0
self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
if self.alpha_mse > 0.:
self.mse_loss_fct = nn.MSELoss(reduction='sum')
if self.alpha_cos > 0.:
......@@ -224,7 +224,7 @@ class Distiller:
_token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long()
token_ids = token_ids.masked_scatter(pred_mask, _token_ids)
mlm_labels[~pred_mask] = -1 # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility
mlm_labels[~pred_mask] = -100 # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility
# sanity checks
assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
......@@ -254,7 +254,7 @@ class Distiller:
attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None])
clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
clm_labels[~attn_mask] = -1 # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility
clm_labels[~attn_mask] = -100 # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility
# sanity checks
assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
......
......@@ -94,7 +94,7 @@ def convert_examples_to_features(examples,
pad_on_left=False,
pad_token=0,
pad_token_segment_id=0,
pad_token_label_id=-1,
pad_token_label_id=-100,
sequence_a_segment_id=0,
mask_padding_with_zero=True):
""" Loads a data file into a list of `InputBatch`s
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment