Unverified Commit 36ee1283 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix `WhisperModelTest` (#21883)



* force on the same device

* fix tests

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 4edfd2d4
...@@ -284,6 +284,8 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi ...@@ -284,6 +284,8 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
fx_compatible = False fx_compatible = False
test_pruning = False test_pruning = False
test_missing_keys = False test_missing_keys = False
# Needs higher percentages after model tester's vocab_size is changed to 200 (PR #21222)
model_split_percents = [0.8, 0.9]
input_name = "input_features" input_name = "input_features"
...@@ -727,7 +729,17 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi ...@@ -727,7 +729,17 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
input_features = inputs["input_features"] input_features = inputs["input_features"]
decoder_input_ids = inputs["decoder_input_ids"] decoder_input_ids = inputs["decoder_input_ids"]
decoder_attention_mask = inputs["decoder_attention_mask"] decoder_attention_mask = inputs["decoder_attention_mask"]
traced_model = torch.jit.trace(model, (input_features, decoder_input_ids, decoder_attention_mask)) # prepare `attention_mask` with shape (batch_size, sequence_length)
attention_mask = torch.ones(
input_features.shape[0],
input_features.shape[-1],
device=input_features.device,
dtype=input_features.dtype,
)
traced_model = torch.jit.trace(
model, (input_features, attention_mask, decoder_input_ids, decoder_attention_mask)
)
except RuntimeError: except RuntimeError:
self.fail("Couldn't trace module.") self.fail("Couldn't trace module.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment