Unverified Commit 0486ccdd authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

small improvements (#10773)

parent d7e0d59b
......@@ -162,7 +162,7 @@ class Wav2Vec2ModelTester:
model.eval()
input_values = input_values[:3]
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.bool)
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
......@@ -171,7 +171,7 @@ class Wav2Vec2ModelTester:
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
attention_mask[i, input_lengths[i] :] = 0.0
attention_mask[i, input_lengths[i] :] = 0
model.config.ctc_loss_reduction = "sum"
sum_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment