# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
9
10
import os
import shutil
11
import subprocess
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
12
13
14
15
16
17
18
19
20
import tempfile
import warnings
from typing import Optional, Tuple, Union

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

21
# Default ffmpeg executable: the value of the `FFMPEG` environment variable if
# it is set, otherwise the bare command name "ffmpeg".
_DEFAULT_FFMPEG = os.environ.get("FFMPEG", "ffmpeg")
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
22
23
24
25
26
27
28
29
30
31
32
33

# Select matplotlib's non-interactive "Agg" backend at import time, so figures
# are rendered to files without needing a display.
matplotlib.use("Agg")


class VideoWriter:
    """
    A class for exporting videos.

    Frames are written one at a time as PNG files into a cache directory
    (see `write_frame`) and are then encoded into a single video file by
    invoking `ffmpeg` (see `get_video`).
    """

    def __init__(
        self,
        cache_dir: Optional[str] = None,
        ffmpeg_bin: str = _DEFAULT_FFMPEG,
        out_path: str = "/tmp/video.mp4",
        fps: int = 20,
        output_format: str = "visdom",
        rmdir_allowed: bool = False,
        **kwargs,
    ) -> None:
        """
        Args:
            cache_dir: A directory for storing the video frames. If `None`,
                a temporary directory will be used.
            ffmpeg_bin: The path to an `ffmpeg` executable.
            out_path: The path to the output video.
            fps: The speed of the generated video in frames-per-second.
            output_format: Format of the output video. Currently only `"visdom"`
                is supported.
            rmdir_allowed: If `True` delete and create `cache_dir` in case
                it is not empty.
            **kwargs: Accepted for backward compatibility and ignored.
        """
        # Assigned before anything that can raise, so that `__del__` never
        # sees a partially-initialized object.
        self.tmp_dir = None

        self.rmdir_allowed = rmdir_allowed
        self.output_format = output_format
        self.fps = fps
        self.out_path = out_path
        self.cache_dir = cache_dir
        self.ffmpeg_bin = ffmpeg_bin
        self.frames = []
        # printf-style pattern used both to name the frame files and as the
        # ffmpeg input pattern.
        self.regexp = "frame_%08d.png"
        self.frame_num = 0

        if self.cache_dir is not None:
            # Only a directory that actually has contents needs to be cleared
            # (or warned about); an existing empty directory is used as-is.
            if os.path.isdir(self.cache_dir) and os.listdir(self.cache_dir):
                if rmdir_allowed:
                    shutil.rmtree(self.cache_dir)
                else:
                    warnings.warn(
                        f"Warning: cache directory not empty ({self.cache_dir})."
                    )
            os.makedirs(self.cache_dir, exist_ok=True)
        else:
            self.tmp_dir = tempfile.TemporaryDirectory()
            self.cache_dir = self.tmp_dir.name

    def write_frame(
        self,
        frame: Union[matplotlib.figure.Figure, np.ndarray, Image.Image, str],
        resize: Optional[Union[float, Tuple[int, int]]] = None,
    ) -> None:
        """
        Write a frame to the video.

        Args:
            frame: An object containing the frame image. One of: a matplotlib
                figure, a numpy array (floating point arrays are transposed
                with axes `(1, 2, 0)` and multiplied by 255 before conversion
                to `uint8`; other arrays are passed to PIL unchanged),
                a PIL image, or a path to an image file.
            resize: Either a number defining the image rescaling factor
                or a 2-tuple defining the size of the output image.

        Raises:
            ValueError: If `frame` is of an unsupported type.
        """

        # pyre-fixme[6]: For 1st argument expected `Union[PathLike[str], str]` but
        #  got `Optional[str]`.
        outfile = os.path.join(self.cache_dir, self.regexp % self.frame_num)

        if isinstance(frame, matplotlib.figure.Figure):
            # Save the figure that was passed in, not whichever figure pyplot
            # currently considers active.
            frame.savefig(outfile)
            im = Image.open(outfile)
        elif isinstance(frame, np.ndarray):
            if frame.dtype in (np.float64, np.float32, float):
                frame = (np.transpose(frame, (1, 2, 0)) * 255.0).astype(np.uint8)
            im = Image.fromarray(frame)
        elif isinstance(frame, Image.Image):
            im = frame
        elif isinstance(frame, str):
            im = Image.open(frame).convert("RGB")
        else:
            raise ValueError("Cant convert type %s" % str(type(frame)))

        if resize is None:
            size = im.size
        elif isinstance(resize, (int, float)):
            # A scalar is a scale factor applied to both dimensions.
            size = [int(resize * s) for s in im.size]
        else:
            size = resize
        # make sure size is divisible by 2 (odd dimensions are rounded up)
        size = (size[0] + size[0] % 2, size[1] + size[1] % 2)

        # `Image.ANTIALIAS` was removed in Pillow 10; `LANCZOS` is the same
        # resampling filter under its current name.
        im = im.resize(size, Image.LANCZOS)
        im.save(outfile)

        self.frames.append(outfile)
        self.frame_num += 1

    def get_video(self, quiet: bool = True) -> str:
        """
        Generate the video from the written frames.

        Args:
            quiet: If `True`, suppresses logging messages.

        Returns:
            video_path: The path to the generated video if any frames were added.
                Otherwise returns an empty string.

        Raises:
            ValueError: If the ffmpeg executable cannot be found, or if
                `output_format` is not supported.
            subprocess.CalledProcessError: If ffmpeg exits with a non-zero
                status.
        """
        if self.frame_num == 0:
            return ""

        # pyre-fixme[6]: For 1st argument expected `Union[PathLike[str], str]` but
        #  got `Optional[str]`.
        regexp = os.path.join(self.cache_dir, self.regexp)

        if shutil.which(self.ffmpeg_bin) is None:
            raise ValueError(
                f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. "
                + "Please set FFMPEG in the environment or ffmpeg_bin on this class."
            )

        if self.output_format != "visdom":
            raise ValueError("no such output type %s" % str(self.output_format))

        # works for ppt too
        args = [
            self.ffmpeg_bin,
            "-r",
            str(self.fps),
            "-i",
            regexp,
            "-vcodec",
            "h264",
            "-f",
            "mp4",
            "-y",
            "-crf",
            "18",
            "-b",
            "2000k",
            "-pix_fmt",
            "yuv420p",
            self.out_path,
        ]
        # `None` leaves the child's stdout/stderr attached to ours.
        sink = subprocess.DEVNULL if quiet else None
        subprocess.check_call(args, stdout=sink, stderr=sink)

        return self.out_path

    def __del__(self) -> None:
        # `getattr` guards against `__init__` having failed before `tmp_dir`
        # was assigned; only a temporary directory we created is cleaned up.
        tmp_dir = getattr(self, "tmp_dir", None)
        if tmp_dir is not None:
            tmp_dir.cleanup()