Unverified Commit acc439ba authored by Arthur's avatar Arthur Committed by GitHub
Browse files

Ci-jukebox (#20613)



* fix cuda OOM by using single Prior

* only send to device when used

* use custom model

* Skip the big slow test

* Update tests/models/jukebox/test_modeling_jukebox.py
Co-authored-by: default avatarYih-Dar <2521628+ydshieh@users.noreply.github.com>
Co-authored-by: default avatarYih-Dar <2521628+ydshieh@users.noreply.github.com>
parent 9b14c1b6
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import unittest import unittest
from unittest import skip
from transformers import is_torch_available from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow from transformers.testing_utils import require_torch, slow
...@@ -311,6 +312,7 @@ class Jukebox5bModelTester(unittest.TestCase): ...@@ -311,6 +312,7 @@ class Jukebox5bModelTester(unittest.TestCase):
torch.testing.assert_allclose(zs[2][0], torch.tensor(self.EXPECTED_OUTPUT_0)) torch.testing.assert_allclose(zs[2][0], torch.tensor(self.EXPECTED_OUTPUT_0))
@slow @slow
@skip("Not enough GPU memory on CI runners")
def test_slow_sampling(self): def test_slow_sampling(self):
model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval() model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
labels = [i.cuda() for i in self.prepare_inputs(self.model_id)] labels = [i.cuda() for i in self.prepare_inputs(self.model_id)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment