custom_directives.py 3.29 KB
Newer Older
moto's avatar
moto committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import hashlib
from pathlib import Path
from typing import List
from urllib.parse import quote, urlencode

import requests
from docutils import nodes
from docutils.parsers.rst.directives.images import Image


# Directory containing this module. The generated-image cache is created
# beneath it, and cached image paths are reported relative to it.
_THIS_DIR = Path(__file__).parent

# Color palette from PyTorch Developer Day 2021 Presentation Template
# Each value is a six-digit hexadecimal color string without a leading "#".
# The directives below pass them as the "color" / "labelColor" query
# parameters of the badge URL.
YELLOW = "F9DB78"
GREEN = "70AD47"
BLUE = "00B0F0"
PINK = "FF71DA"
ORANGE = "FF8300"
TEAL = "00E5D1"
GRAY = "7F7F7F"


def _get_cache_path(key, ext):
    """Return the cache file path for ``key``, creating the cache directory.

    Args:
        key: Bytes-like value identifying the cached item. Its SHA-256 hex
            digest becomes the stem of the file name.
        ext: Suffix appended to the digest (for example ``".svg"``).

    Returns:
        Path of the cache file inside the ``gen_images`` directory located
        next to this module. The directory is created when missing; the
        file itself is not created here.
    """
    digest = hashlib.sha256(key).hexdigest()
    target_dir = _THIS_DIR.joinpath("gen_images")
    # exist_ok keeps repeated calls from failing once the directory exists.
    target_dir.mkdir(parents=True, exist_ok=True)
    return target_dir.joinpath(f"{digest}{ext}")


def _download(url, path, timeout=30):
    """Download ``url`` with an HTTP GET and write the response body to ``path``.

    Args:
        url: Address of the resource to fetch.
        path: Destination file. It is opened in binary write mode, so an
            existing file is overwritten.
        timeout: Seconds to wait for the server before giving up. It is
            passed straight to ``requests.get``; without it ``requests``
            waits indefinitely, which would hang the caller on a stalled
            connection.

    Raises:
        requests.HTTPError: If the server answers with an error status
            (raised by ``raise_for_status``).
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    # Fail before touching the file so an error page is never stored.
    response.raise_for_status()
    with open(path, "wb") as file:
        file.write(response.content)


def _fetch_image(url):
    """Return the path of the cached copy of ``url``, downloading it if needed.

    The cache file name is derived from the UTF-8 encoded URL and carries
    the ``.svg`` extension. The download only happens when that file does
    not exist yet.

    Returns:
        The cache file path as a string, relative to this module's directory.
    """
    svg_file = _get_cache_path(url.encode("utf-8"), ext=".svg")
    already_cached = svg_file.exists()
    if not already_cached:
        _download(url, svg_file)
    relative = svg_file.relative_to(_THIS_DIR)
    return str(relative)


class BaseShield(Image):
    """Image directive whose picture is a badge fetched from img.shields.io."""

    def run(self, params, alt, section) -> List[nodes.Node]:
        """Fetch the badge described by ``params`` and emit it as an image.

        Args:
            params: Mapping of query parameters for the shields.io
                ``static/v1`` endpoint.
            alt: Alternative text stored in the directive's ``alt`` option.
            section: Anchor name inside ``supported_features.html`` that the
                image links to.

        Returns:
            Whatever the parent ``Image.run`` produces for the configured
            arguments and options.
        """
        query = urlencode(params, quote_via=quote)
        badge_path = _fetch_image(f"https://img.shields.io/static/v1?{query}")

        # Configure the directive as if the badge file had been given in the
        # reST source, then let the stock image directive do the rendering.
        self.arguments = [badge_path]
        self.options["alt"] = alt
        self.options.setdefault("class", []).append("shield-badge")
        self.options["target"] = f"supported_features.html#{section}"
        return super().run()


def _parse_devices(arg: str):
    devices = sorted(arg.strip().split())

    valid_values = {"CPU", "CUDA"}
    if any(val not in valid_values for val in devices):
        raise ValueError(
            f"One or more device values are not valid. The valid values are {valid_values}. Given value: '{arg}'"
        )
    return ", ".join(sorted(devices))


def _parse_properties(arg: str):
    properties = sorted(arg.strip().split())

    valid_values = {"Autograd", "TorchScript"}
    if any(val not in valid_values for val in properties):
        raise ValueError(
            "One or more property values are not valid. "
            f"The valid values are {valid_values}. "
            f"Given value: '{arg}'"
        )
    return ", ".join(sorted(properties))


class SupportedDevices(BaseShield):
    """List the supported devices"""

    # The directive takes exactly one argument: the device list. The list is
    # whitespace-separated, so the final argument is allowed to hold spaces.
    required_arguments = 1
    final_argument_whitespace = True

    def run(self) -> List[nodes.Node]:
        """Render a "Devices" badge for the devices named in the argument.

        Raises:
            ValueError: If the argument names an unsupported device.
        """
        device_text = _parse_devices(self.arguments[0])
        badge_params = {
            "label": "Devices",
            "message": device_text,
            "labelColor": GRAY,
            "color": BLUE,
            "style": "flat-square",
        }
        alt_text = f"This feature supports the following devices: {device_text}"
        return super().run(badge_params, alt_text, "devices")


class SupportedProperties(BaseShield):
    """List the supported properties"""

    # The directive takes exactly one argument: the property list. The list
    # is whitespace-separated, so the final argument is allowed to hold spaces.
    required_arguments = 1
    final_argument_whitespace = True

    def run(self) -> List[nodes.Node]:
        """Render a "Properties" badge for the properties named in the argument.

        Raises:
            ValueError: If the argument names an unsupported property.
        """
        property_text = _parse_properties(self.arguments[0])
        badge_params = {
            "label": "Properties",
            "message": property_text,
            "labelColor": GRAY,
            "color": GREEN,
            "style": "flat-square",
        }
        alt_text = f"This API supports the following properties: {property_text}"
        return super().run(badge_params, alt_text, "properties")