download_model_urls.py 1.35 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import asyncio
import sys
from pathlib import Path
from time import perf_counter
from urllib.parse import urlsplit

import aiofiles
import aiohttp
from torchvision import models
from tqdm.asyncio import tqdm


async def main(download_root):
    """Download every torchvision pretrained-weight file into ``download_root``.

    Collects the ``url`` of every member of every weights enum returned by
    ``torchvision.models.get_model_weights`` for each name in
    ``torchvision.models.list_models()``, then fetches all of them
    concurrently over a single shared ``aiohttp`` session, showing a tqdm
    progress bar over the download tasks.

    Args:
        download_root: ``pathlib.Path`` of the target directory. It is
            created (including missing parents) if it does not exist.
    """
    download_root.mkdir(parents=True, exist_ok=True)
    # Weights enums are directly iterable. Collecting into a set collapses any
    # duplicate URLs so each file is fetched only once.
    urls = {
        weight.url
        for name in models.list_models()
        for weight in models.get_model_weights(name)
    }

    # total=None disables aiohttp's overall per-request timeout: checkpoint
    # files can be large and may take a long time to stream.
    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=None)) as session:
        await tqdm.gather(*[download(download_root, session, url) for url in urls])


async def download(download_root, session, url):
    """Stream the file at ``url`` into ``download_root``.

    The file is named after the last path component of ``url``. Data is first
    written to a ``<name>.partial`` sibling and only renamed to the final name
    once the whole body has been received. An interrupted or failed download
    therefore never leaves a truncated file under the real checkpoint name.

    Args:
        download_root: ``pathlib.Path`` of an existing target directory.
        session: open ``aiohttp.ClientSession`` used for the request.
        url: URL of the file to fetch.

    Raises:
        aiohttp.ClientResponseError: if the server answers with status >= 400.
    """
    target = download_root / Path(urlsplit(url).path).name
    partial = target.with_name(target.name + ".partial")

    try:
        # Using the response as a context manager guarantees the connection is
        # released back to the pool even if streaming fails part-way.
        # NOTE(review): the source="ci" query parameter is passed through
        # unchanged; presumably it tags the request for the host — confirm.
        async with session.get(url, params=dict(source="ci")) as response:
            # Unlike `assert`, this check is not stripped under `python -O`.
            response.raise_for_status()
            async with aiofiles.open(partial, "wb") as f:
                async for data in response.content.iter_any():
                    await f.write(data)
    except BaseException:
        # Also covers task cancellation: drop the incomplete file, then re-raise.
        partial.unlink(missing_ok=True)
        raise

    # Only a fully received file takes the final name.
    partial.replace(target)


if __name__ == "__main__":
    # Target directory: the first CLI argument if given, otherwise
    # ~/.cache/torch/hub/checkpoints. "~" is expanded and the path made absolute.
    download_root = (
        (Path(sys.argv[1]) if len(sys.argv) > 1 else Path("~/.cache/torch/hub/checkpoints")).expanduser().resolve()
    )
    print(f"Downloading model weights to {download_root}")
    start = perf_counter()
    # asyncio.run creates, runs and closes a fresh event loop. Calling
    # asyncio.get_event_loop() with no current loop is deprecated on modern Python.
    asyncio.run(main(download_root))
    stop = perf_counter()
    # Round to whole seconds *before* splitting so the seconds field can never
    # display "60" (e.g. 119.6s is reported as " 2m  0s", not " 1m 60s").
    minutes, seconds = divmod(round(stop - start), 60)
    print(f"Download took {minutes:2d}m {seconds:2d}s")