Unverified commit 315e6740, authored by Sylvain Gugger, committed by GitHub
Browse files

Fix tests hub failure (#15580)

* Expose hub test problem

* Fix tests
parent b1ba03e0
...@@ -21,6 +21,7 @@ import random ...@@ -21,6 +21,7 @@ import random
import re import re
import subprocess import subprocess
import tempfile import tempfile
import time
import unittest import unittest
from pathlib import Path from pathlib import Path
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
...@@ -1544,12 +1545,17 @@ class TrainerIntegrationWithHubTester(unittest.TestCase): ...@@ -1544,12 +1545,17 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
) )
trainer.train() trainer.train()
# Wait for the async pushes to be finished
while trainer.push_in_progress is not None and not trainer.push_in_progress.is_done:
time.sleep(0.5)
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
_ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-epoch", use_auth_token=self._token) _ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-epoch", use_auth_token=self._token)
commits = self.get_commit_history(tmp_dir) commits = self.get_commit_history(tmp_dir)
expected_commits = [f"Training in progress, epoch {i}" for i in range(3, 0, -1)] self.assertIn("initial commit", commits)
expected_commits.append("initial commit") # We can't test that epoch 2 and 3 are in the commits without being flaky as those might be skipped if
self.assertListEqual(commits, expected_commits) # the push for epoch 1 wasn't finished at the time.
self.assertIn("Training in progress, epoch 1", commits)
def test_push_to_hub_with_saves_each_n_steps(self): def test_push_to_hub_with_saves_each_n_steps(self):
num_gpus = max(1, get_gpu_count()) num_gpus = max(1, get_gpu_count())
...@@ -1566,13 +1572,17 @@ class TrainerIntegrationWithHubTester(unittest.TestCase): ...@@ -1566,13 +1572,17 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
) )
trainer.train() trainer.train()
# Wait for the async pushes to be finished
while trainer.push_in_progress is not None and not trainer.push_in_progress.is_done:
time.sleep(0.5)
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
_ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-step", use_auth_token=self._token) _ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-step", use_auth_token=self._token)
commits = self.get_commit_history(tmp_dir) commits = self.get_commit_history(tmp_dir)
total_steps = 20 // num_gpus self.assertIn("initial commit", commits)
expected_commits = [f"Training in progress, step {i}" for i in range(total_steps, 0, -5)] # We can't test that the steps after 5 are in the commits without being flaky as those might be skipped if
expected_commits.append("initial commit") # the push for step 5 wasn't finished at the time.
self.assertListEqual(commits, expected_commits) self.assertIn("Training in progress, step 5", commits)
@require_torch @require_torch
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment