Unverified commit 143d078b authored by deekay42, committed by GitHub

Adding GPU acceleration to encode_jpeg (#8391)


Co-authored-by: Nicolas Hug <nicolashug@fb.com>
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
parent f96c42fc
import os
import platform
import statistics

import torch
import torch.utils.benchmark as benchmark
import torchvision


def print_machine_specs():
    print("Processor:", platform.processor())
    print("Platform:", platform.platform())
    print("Logical CPUs:", os.cpu_count())
    print(f"\nCUDA device: {torch.cuda.get_device_name()}")
    print(f"Total Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")


def get_data():
    transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.PILToTensor(),
        ]
    )
    path = os.path.join(os.getcwd(), "data")
    testset = torchvision.datasets.Places365(
        root="./data", download=not os.path.exists(path), transform=transform, split="val"
    )
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=1000, shuffle=False, num_workers=1, collate_fn=lambda batch: [r[0] for r in batch]
    )
    return next(iter(testloader))


def run_benchmark(batch):
    results = []
    for device in ["cpu", "cuda"]:
        batch_device = [t.to(device=device) for t in batch]
        for size in [1, 100, 1000]:
            for num_threads in [1, 12, 24]:
                for stmt, strat in zip(
                    [
                        "[torchvision.io.encode_jpeg(img) for img in batch_input]",
                        "torchvision.io.encode_jpeg(batch_input)",
                    ],
                    ["unfused", "fused"],
                ):
                    batch_input = batch_device[:size]
                    t = benchmark.Timer(
                        stmt=stmt,
                        setup="import torchvision",
                        globals={"batch_input": batch_input},
                        label="Image Encoding",
                        sub_label=f"{device.upper()} ({strat}): {stmt}",
                        description=f"{size} images",
                        num_threads=num_threads,
                    )
                    results.append(t.blocked_autorange())
    compare = benchmark.Compare(results)
    compare.print()


if __name__ == "__main__":
    print_machine_specs()
    batch = get_data()
    mean_h, mean_w = statistics.mean(t.shape[-2] for t in batch), statistics.mean(t.shape[-1] for t in batch)
    print(f"\nMean image size: {int(mean_h)}x{int(mean_w)}")
    run_benchmark(batch)
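
For a quick smoke test of the fused vs. unfused code paths without downloading Places365, a sketch along these lines should work (random uint8 images stand in for dataset samples; a CUDA device, an nvJPEG-enabled torchvision build, and the illustrative batch size are assumptions):

import torch
import torch.utils.benchmark as benchmark
import torchvision

# 100 random 256x256 RGB images on the GPU as stand-ins for real data.
batch_input = [torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8, device="cuda") for _ in range(100)]

for stmt, strat in [
    ("[torchvision.io.encode_jpeg(img) for img in batch_input]", "unfused"),
    ("torchvision.io.encode_jpeg(batch_input)", "fused"),
]:
    t = benchmark.Timer(stmt=stmt, setup="import torchvision", globals={"batch_input": batch_input})
    print(strat, t.blocked_autorange())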
import concurrent.futures
import glob
import io
import os
@@ -10,7 +11,7 @@ import pytest
import requests
import torch
import torchvision.transforms.functional as F
-from common_utils import assert_equal, IN_OSS_CI, needs_cuda
+from common_utils import assert_equal, cpu_and_cuda, IN_OSS_CI, needs_cuda
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence
from torchvision.io.image import (
_read_png_16,
@@ -508,6 +509,200 @@ def test_encode_jpeg(img_path, scripted):
    assert_equal(encoded_jpeg_torch, encoded_jpeg_pil)


@needs_cuda
def test_encode_jpeg_cuda_device_param():
    path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in path)
    data = read_image(path)
    current_device = torch.cuda.current_device()
    current_stream = torch.cuda.current_stream()
    num_devices = torch.cuda.device_count()
    devices = ["cuda", torch.device("cuda")] + [torch.device(f"cuda:{i}") for i in range(num_devices)]
    results = []
    for device in devices:
        results.append(encode_jpeg(data.to(device=device)))
    assert len(results) == len(devices)
    for result in results:
        assert torch.all(result.cpu() == results[0].cpu())
    # Encoding must not leak a device or stream change into the calling context.
    assert current_device == torch.cuda.current_device()
    assert current_stream == torch.cuda.current_stream()


@needs_cuda
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
)
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize("contiguous", (False, True))
def test_encode_jpeg_cuda(img_path, scripted, contiguous):
    decoded_image_tv = read_image(img_path)
    encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg

    if "cmyk" in img_path:
        pytest.xfail("Encoding a CMYK jpeg isn't supported")
    if decoded_image_tv.shape[0] == 1:
        # For more detail as to why, check out:
        # https://github.com/NVIDIA/cuda-samples/issues/23#issuecomment-559283013
        pytest.xfail("Encoding a grayscale jpeg isn't supported")

    if contiguous:
        decoded_image_tv = decoded_image_tv[None].contiguous(memory_format=torch.contiguous_format)[0]
    else:
        decoded_image_tv = decoded_image_tv[None].contiguous(memory_format=torch.channels_last)[0]
    encoded_jpeg_cuda_tv = encode_fn(decoded_image_tv.cuda(), quality=75)
    decoded_jpeg_cuda_tv = decode_jpeg(encoded_jpeg_cuda_tv.cpu())

    # The encoded bytestreams from libnvjpeg and libjpeg-turbo differ for the same
    # quality, so we re-decode the encoded image and compare to the original instead.
    abs_mean_diff = (decoded_jpeg_cuda_tv.float() - decoded_image_tv.float()).abs().mean().item()
    assert abs_mean_diff < 3
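
Because libjpeg-turbo and nvJPEG produce different bytestreams at the same quality, closeness is asserted after a decode round-trip. The same check, pulled out as a standalone helper for experimentation (a sketch, not part of the test suite; it assumes a 3-channel uint8 CHW image and a CUDA device):

import torch
from torchvision.io import decode_jpeg, encode_jpeg

def roundtrip_mean_abs_diff(img_chw_uint8: torch.Tensor, quality: int = 75) -> float:
    # Encode on the GPU, decode back on the CPU, and compare to the original.
    encoded = encode_jpeg(img_chw_uint8.cuda(), quality=quality)
    decoded = decode_jpeg(encoded.cpu())
    return (decoded.float() - img_chw_uint8.float()).abs().mean().item()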
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("scripted", (True, False))
@pytest.mark.parametrize("contiguous", (True, False))
def test_encode_jpegs_batch(scripted, contiguous, device):
if device == "cpu" and IS_MACOS:
pytest.skip("https://github.com/pytorch/vision/issues/8031")
decoded_images_tv = []
for jpeg_path in get_images(IMAGE_ROOT, ".jpg"):
if "cmyk" in jpeg_path:
continue
decoded_image = read_image(jpeg_path)
if decoded_image.shape[0] == 1:
continue
if contiguous:
decoded_image = decoded_image[None].contiguous(memory_format=torch.contiguous_format)[0]
else:
decoded_image = decoded_image[None].contiguous(memory_format=torch.channels_last)[0]
decoded_images_tv.append(decoded_image)
encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg
decoded_images_tv_device = [img.to(device=device) for img in decoded_images_tv]
encoded_jpegs_tv_device = encode_fn(decoded_images_tv_device, quality=75)
encoded_jpegs_tv_device = [decode_jpeg(img.cpu()) for img in encoded_jpegs_tv_device]
for original, encoded_decoded in zip(decoded_images_tv, encoded_jpegs_tv_device):
c, h, w = original.shape
abs_mean_diff = (original.float() - encoded_decoded.float()).abs().mean().item()
assert abs_mean_diff < 3
# test multithreaded decoding
# in the current version we prevent this by using a lock but we still want to test it
num_workers = 10
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [executor.submit(encode_fn, decoded_images_tv_device) for _ in range(num_workers)]
encoded_images_threaded = [future.result() for future in futures]
assert len(encoded_images_threaded) == num_workers
for encoded_images in encoded_images_threaded:
assert len(decoded_images_tv_device) == len(encoded_images)
for i, (encoded_image_cuda, decoded_image_tv) in enumerate(zip(encoded_images, decoded_images_tv_device)):
# make sure all the threads produce identical outputs
assert torch.all(encoded_image_cuda == encoded_images_threaded[0][i])
# make sure the outputs are identical or close enough to baseline
decoded_cuda_encoded_image = decode_jpeg(encoded_image_cuda.cpu())
assert decoded_cuda_encoded_image.shape == decoded_image_tv.shape
assert decoded_cuda_encoded_image.dtype == decoded_image_tv.dtype
assert (decoded_cuda_encoded_image.cpu().float() - decoded_image_tv.cpu().float()).abs().mean() < 3


@needs_cuda
def test_single_encode_jpeg_cuda_errors():
    with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
        encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32, device="cuda"))

    with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 5"):
        encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8, device="cuda"))

    with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 1"):
        encode_jpeg(torch.empty((1, 100, 100), dtype=torch.uint8, device="cuda"))

    with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
        encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8, device="cuda"))

    with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
        encode_jpeg(torch.empty((100, 100), dtype=torch.uint8, device="cuda"))


@needs_cuda
def test_batch_encode_jpegs_cuda_errors():
    with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
        encode_jpeg(
            [
                torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
                torch.empty((3, 100, 100), dtype=torch.float32, device="cuda"),
            ]
        )

    with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 5"):
        encode_jpeg(
            [
                torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
                torch.empty((5, 100, 100), dtype=torch.uint8, device="cuda"),
            ]
        )

    with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 1"):
        encode_jpeg(
            [
                torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
                torch.empty((1, 100, 100), dtype=torch.uint8, device="cuda"),
            ]
        )

    with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
        encode_jpeg(
            [
                torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
                torch.empty((1, 3, 100, 100), dtype=torch.uint8, device="cuda"),
            ]
        )

    with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
        encode_jpeg(
            [
                torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
                torch.empty((100, 100), dtype=torch.uint8, device="cuda"),
            ]
        )

    # A list whose first tensor is on CPU dispatches to the CPU op, which then
    # rejects the CUDA tensor.
    with pytest.raises(RuntimeError, match="Input tensor should be on CPU"):
        encode_jpeg(
            [
                torch.empty((3, 100, 100), dtype=torch.uint8, device="cpu"),
                torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
            ]
        )

    with pytest.raises(
        RuntimeError, match="All input tensors must be on the same CUDA device when encoding with nvjpeg"
    ):
        encode_jpeg(
            [
                torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
                torch.empty((3, 100, 100), dtype=torch.uint8, device="cpu"),
            ]
        )

    if torch.cuda.device_count() >= 2:
        with pytest.raises(
            RuntimeError, match="All input tensors must be on the same CUDA device when encoding with nvjpeg"
        ):
            encode_jpeg(
                [
                    torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda:0"),
                    torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda:1"),
                ]
            )

    with pytest.raises(ValueError, match="encode_jpeg requires at least one input tensor when a list is passed"):
        encode_jpeg([])


@pytest.mark.skipif(IS_MACOS, reason="https://github.com/pytorch/vision/issues/8031")
@pytest.mark.parametrize(
    "img_path",
......
#include "decode_jpeg_cuda.h"
#include "encode_decode_jpegs_cuda.h"
#include <ATen/ATen.h>
......
@@ -2,6 +2,7 @@
#include <torch/types.h>
#include "../image_read_mode.h"
+#include "encode_jpegs_cuda.h"

namespace vision {
namespace image {

@@ -11,5 +12,9 @@ C10_EXPORT torch::Tensor decode_jpeg_cuda(
    ImageReadMode mode,
    torch::Device device);

+C10_EXPORT std::vector<torch::Tensor> encode_jpegs_cuda(
+    const std::vector<torch::Tensor>& decoded_images,
+    const int64_t quality);

} // namespace image
} // namespace vision
#include "encode_jpegs_cuda.h"
#if !NVJPEG_FOUND
namespace vision {
namespace image {
std::vector<torch::Tensor> encode_jpegs_cuda(
const std::vector<torch::Tensor>& decoded_images,
const int64_t quality) {
TORCH_CHECK(
false, "encode_jpegs_cuda: torchvision not compiled with nvJPEG support");
}
} // namespace image
} // namespace vision
#else
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
#include <cuda_runtime.h>
#include <torch/nn/functional.h>
#include <iostream>
#include <memory>
#include <string>
#include "c10/core/ScalarType.h"

namespace vision {
namespace image {

// We use a global variable to cache the encoder instance and reuse it across
// calls to the corresponding pytorch functions.
std::mutex encoderMutex;
std::unique_ptr<CUDAJpegEncoder> cudaJpegEncoder;

std::vector<torch::Tensor> encode_jpegs_cuda(
    const std::vector<torch::Tensor>& decoded_images,
    const int64_t quality) {
  C10_LOG_API_USAGE_ONCE(
      "torchvision.csrc.io.image.cuda.encode_jpegs_cuda.encode_jpegs_cuda");

  // Some nvjpeg structures are not thread safe, so we keep the encoder
  // single-threaded for now. In the future this may be an opportunity to
  // unlock further speedups.
  std::lock_guard<std::mutex> lock(encoderMutex);
  TORCH_CHECK(decoded_images.size() > 0, "Empty input tensor list");
  torch::Device device = decoded_images[0].device();
  at::cuda::CUDAGuard device_guard(device);

  // Lazy init of the encoder class: the encoder object holds on to a lot of
  // state and is expensive to create, so we reuse it across calls.
  // NB: the cached structures are device specific and cannot be reused across
  // devices.
  if (cudaJpegEncoder == nullptr || device != cudaJpegEncoder->target_device) {
    if (cudaJpegEncoder != nullptr)
      delete cudaJpegEncoder.release();

    cudaJpegEncoder = std::make_unique<CUDAJpegEncoder>(device);

    // Unfortunately, we cannot rely on the smart pointer releasing the encoder
    // object correctly upon program exit. This is because, when cudaJpegEncoder
    // gets destroyed, the CUDA runtime may already be shut down, rendering all
    // destroy* calls in the encoder destructor invalid. Instead, we use an
    // atexit hook which executes after main() finishes, but hopefully before
    // CUDA shuts down when the program exits. If CUDA is already shut down, the
    // destructor will detect this and will not attempt to destroy any encoder
    // structures.
    std::atexit([]() { delete cudaJpegEncoder.release(); });
  }

  std::vector<torch::Tensor> contig_images;
  contig_images.reserve(decoded_images.size());
  for (const auto& image : decoded_images) {
    TORCH_CHECK(
        image.dtype() == torch::kU8, "Input tensor dtype should be uint8");

    TORCH_CHECK(
        image.device() == device,
        "All input tensors must be on the same CUDA device when encoding with nvjpeg");

    TORCH_CHECK(
        image.dim() == 3 && image.numel() > 0,
        "Input data should be a 3-dimensional tensor");

    TORCH_CHECK(
        image.size(0) == 3,
        "The number of channels should be 3, got: ",
        image.size(0));

    // nvjpeg requires images to be contiguous
    if (image.is_contiguous()) {
      contig_images.push_back(image);
    } else {
      contig_images.push_back(image.contiguous());
    }
  }

  cudaJpegEncoder->set_quality(quality);
  std::vector<torch::Tensor> encoded_images;
  at::cuda::CUDAEvent event;
  event.record(cudaJpegEncoder->stream);
  for (const auto& image : contig_images) {
    auto encoded_image = cudaJpegEncoder->encode_jpeg(image);
    encoded_images.push_back(encoded_image);
  }

  // We use a dedicated stream to do the encoding, and even though the results
  // may be ready on that stream, we cannot assume that they are also available
  // on the current stream of the calling context when this function returns.
  // We use a blocking event to ensure that this is indeed the case. Crucially,
  // we do not want to block the host at this particular point (which is what
  // cudaStreamSynchronize would do); events allow us to synchronize the
  // streams without blocking the host.
  event.block(at::cuda::getCurrentCUDAStream(
      cudaJpegEncoder->original_device.has_index()
          ? cudaJpegEncoder->original_device.index()
          : 0));
  return encoded_images;
}
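
The event-based handoff above has a direct Python analogue; a minimal sketch of the same pattern using torch.cuda.Stream and torch.cuda.Event (assumes a CUDA device; the matmul is merely a stand-in for the encode work done on the dedicated stream):

import torch

side_stream = torch.cuda.Stream()
event = torch.cuda.Event()

x = torch.randn(1024, 1024, device="cuda")
side_stream.wait_stream(torch.cuda.current_stream())  # make x's creation visible to the side stream
with torch.cuda.stream(side_stream):
    y = x @ x  # stand-in for the encode work
    event.record(side_stream)

event.wait(torch.cuda.current_stream())  # stream-level wait; the host thread is not blocked
print(y.sum().item())  # .item() synchronizes, so y is valid to read here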

CUDAJpegEncoder::CUDAJpegEncoder(const torch::Device& target_device)
    : original_device{torch::kCUDA, torch::cuda::current_device()},
      target_device{target_device},
      stream{
          target_device.has_index()
              ? at::cuda::getStreamFromPool(false, target_device.index())
              : at::cuda::getStreamFromPool(false)} {
  nvjpegStatus_t status;

  status = nvjpegCreateSimple(&nvjpeg_handle);
  TORCH_CHECK(
      status == NVJPEG_STATUS_SUCCESS,
      "Failed to create nvjpeg handle: ",
      status);

  status = nvjpegEncoderStateCreate(nvjpeg_handle, &nv_enc_state, stream);
  TORCH_CHECK(
      status == NVJPEG_STATUS_SUCCESS,
      "Failed to create nvjpeg encoder state: ",
      status);

  status = nvjpegEncoderParamsCreate(nvjpeg_handle, &nv_enc_params, stream);
  TORCH_CHECK(
      status == NVJPEG_STATUS_SUCCESS,
      "Failed to create nvjpeg encoder params: ",
      status);
}

CUDAJpegEncoder::~CUDAJpegEncoder() {
  /*
  The below code works on Mac and Linux, but fails on Windows.
  This is because on Windows, the atexit hook which calls this
  destructor executes after CUDA is already shut down, causing a SIGSEGV.
  We do not have a solution to this problem at the moment, so we'll
  just leak the libnvjpeg & CUDA variables for the time being and hope
  that the CUDA runtime handles cleanup for us.
  Please send a PR if you have a solution for this problem.
  */

  // // We run cudaGetDeviceCount as a dummy to test if the CUDA runtime is
  // // still initialized. If it is not, we can skip the rest of this function
  // // as it is unsafe to execute.
  // int deviceCount = 0;
  // cudaError_t error = cudaGetDeviceCount(&deviceCount);
  // if (error != cudaSuccess)
  //   return; // CUDA runtime has already shut down. There's nothing we can
  //           // do now.

  // nvjpegStatus_t status;
  // status = nvjpegEncoderParamsDestroy(nv_enc_params);
  // TORCH_CHECK(
  //     status == NVJPEG_STATUS_SUCCESS,
  //     "Failed to destroy nvjpeg encoder params: ",
  //     status);

  // status = nvjpegEncoderStateDestroy(nv_enc_state);
  // TORCH_CHECK(
  //     status == NVJPEG_STATUS_SUCCESS,
  //     "Failed to destroy nvjpeg encoder state: ",
  //     status);

  // cudaStreamSynchronize(stream);
  // status = nvjpegDestroy(nvjpeg_handle);
  // TORCH_CHECK(
  //     status == NVJPEG_STATUS_SUCCESS, "nvjpegDestroy failed: ", status);
}
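
The lazy-initialized global encoder plus atexit teardown used above maps onto a simple Python pattern; a hedged sketch of the same idea (all names here are illustrative, not torchvision API):

import atexit

_encoder_cache = {}  # hypothetical: one cached encoder per device, like the C++ singleton

class _Encoder:  # stand-in for CUDAJpegEncoder
    def __init__(self, device):
        self.device = device  # expensive per-device state would be allocated here

    def close(self):
        pass  # destroy handles here, only if the runtime is still alive

def get_encoder(device):
    # Lazy init: create on first use, then reuse; the cache is device specific.
    if device not in _encoder_cache:
        _encoder_cache[device] = _Encoder(device)
    return _encoder_cache[device]

@atexit.register
def _cleanup():
    # Runs after main() finishes but, hopefully, before the runtime shuts down.
    for enc in _encoder_cache.values():
        enc.close()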

torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) {
  int channels = src_image.size(0);
  int height = src_image.size(1);
  int width = src_image.size(2);

  nvjpegStatus_t status;
  cudaError_t cudaStatus;

  status = nvjpegEncoderParamsSetSamplingFactors(
      nv_enc_params, NVJPEG_CSS_444, stream);
  TORCH_CHECK(
      status == NVJPEG_STATUS_SUCCESS,
      "Failed to set nvjpeg encoder params sampling factors: ",
      status);

  nvjpegImage_t target_image;
  for (int c = 0; c < channels; c++) {
    target_image.channel[c] = src_image[c].data_ptr<uint8_t>();
    // the pitch assumes a packed row layout; this is why we need contiguous tensors
    target_image.pitch[c] = width;
  }
  for (int c = channels; c < NVJPEG_MAX_COMPONENT; c++) {
    target_image.channel[c] = nullptr;
    target_image.pitch[c] = 0;
  }

  // Encode the image
  status = nvjpegEncodeImage(
      nvjpeg_handle,
      nv_enc_state,
      nv_enc_params,
      &target_image,
      NVJPEG_INPUT_RGB,
      width,
      height,
      stream);
  TORCH_CHECK(
      status == NVJPEG_STATUS_SUCCESS, "image encoding failed: ", status);

  // Retrieve the length of the encoded image
  size_t length;
  status = nvjpegEncodeRetrieveBitstreamDevice(
      nvjpeg_handle, nv_enc_state, NULL, &length, stream);
  TORCH_CHECK(
      status == NVJPEG_STATUS_SUCCESS,
      "Failed to retrieve encoded image stream state: ",
      status);

  // Synchronize the stream to ensure that the encoded image is ready
  cudaStatus = cudaStreamSynchronize(stream);
  TORCH_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus);

  // Reserve buffer for the encoded image
  torch::Tensor encoded_image = torch::empty(
      {static_cast<long>(length)},
      torch::TensorOptions()
          .dtype(torch::kByte)
          .layout(torch::kStrided)
          .device(target_device)
          .requires_grad(false));
  cudaStatus = cudaStreamSynchronize(stream);
  TORCH_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus);

  // Retrieve the encoded image
  status = nvjpegEncodeRetrieveBitstreamDevice(
      nvjpeg_handle,
      nv_enc_state,
      encoded_image.data_ptr<uint8_t>(),
      &length,
      0);
  TORCH_CHECK(
      status == NVJPEG_STATUS_SUCCESS,
      "Failed to retrieve encoded image: ",
      status);
  return encoded_image;
}

void CUDAJpegEncoder::set_quality(const int64_t quality) {
  nvjpegStatus_t paramsQualityStatus =
      nvjpegEncoderParamsSetQuality(nv_enc_params, quality, stream);
  TORCH_CHECK(
      paramsQualityStatus == NVJPEG_STATUS_SUCCESS,
      "Failed to set nvjpeg encoder params quality: ",
      paramsQualityStatus);
}

} // namespace image
} // namespace vision

#endif // NVJPEG_FOUND
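
Since a build without nvJPEG only fails at call time with the error raised by the stub above, callers wanting a graceful CPU fallback can probe for it; a sketch (encode_jpeg_with_fallback is a hypothetical helper, not torchvision API; it assumes a list input):

import torch
from torchvision.io import encode_jpeg

def encode_jpeg_with_fallback(images, quality=75):
    # Try the GPU path and fall back to CPU encoding when torchvision
    # was built without nvJPEG support.
    try:
        return encode_jpeg(images, quality=quality)
    except RuntimeError as err:
        if "nvJPEG" not in str(err):
            raise  # a genuine error, not a missing-nvJPEG build
        return encode_jpeg([img.cpu() for img in images], quality=quality)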

#pragma once

#include <torch/types.h>
#include <vector>

#if NVJPEG_FOUND
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <nvjpeg.h>

namespace vision {
namespace image {

class CUDAJpegEncoder {
 public:
  CUDAJpegEncoder(const torch::Device& device);
  ~CUDAJpegEncoder();

  torch::Tensor encode_jpeg(const torch::Tensor& src_image);

  void set_quality(const int64_t quality);

  const torch::Device original_device;
  const torch::Device target_device;
  const c10::cuda::CUDAStream stream;

 protected:
  nvjpegEncoderState_t nv_enc_state;
  nvjpegEncoderParams_t nv_enc_params;
  nvjpegHandle_t nvjpeg_handle;
};

} // namespace image
} // namespace vision

#endif
@@ -27,6 +27,7 @@ static auto registry =
        .op("image::decode_image(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor",
            &decode_image)
        .op("image::decode_jpeg_cuda", &decode_jpeg_cuda)
+       .op("image::encode_jpegs_cuda", &encode_jpegs_cuda)
        .op("image::_jpeg_version", &_jpeg_version)
        .op("image::_is_compiled_against_turbo", &_is_compiled_against_turbo);
......
@@ -7,4 +7,4 @@
#include "cpu/encode_jpeg.h"
#include "cpu/encode_png.h"
#include "cpu/read_write_file.h"
-#include "cuda/decode_jpeg_cuda.h"
+#include "cuda/encode_decode_jpegs_cuda.h"
from enum import Enum
from typing import List, Union
from warnings import warn
import torch
@@ -68,7 +69,9 @@ def write_file(filename: str, data: torch.Tensor) -> None:
def decode_png(
-    input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, apply_exif_orientation: bool = False
+    input: torch.Tensor,
+    mode: ImageReadMode = ImageReadMode.UNCHANGED,
+    apply_exif_orientation: bool = False,
) -> torch.Tensor:
"""
Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor.
@@ -180,28 +183,42 @@ def decode_jpeg(
    return output


-def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
+def encode_jpeg(
+    input: Union[torch.Tensor, List[torch.Tensor]], quality: int = 75
+) -> Union[torch.Tensor, List[torch.Tensor]]:
    """
-    Takes an input tensor in CHW layout and returns a buffer with the contents
-    of its corresponding JPEG file.
+    Takes a (list of) input tensor(s) in CHW layout and returns a (list of) buffer(s) with the contents
+    of the corresponding JPEG file(s).
+
+    .. note::
+        Passing a list of CUDA tensors is more efficient than repeated individual calls to ``encode_jpeg``.
+        For CPU tensors the performance is equivalent.

    Args:
-        input (Tensor[channels, image_height, image_width])): int8 image tensor of
-            ``c`` channels, where ``c`` must be 1 or 3.
-        quality (int): Quality of the resulting JPEG file, it must be a number between
+        input (Tensor[channels, image_height, image_width] or List[Tensor[channels, image_height, image_width]]):
+            (list of) uint8 image tensor(s) of ``c`` channels, where ``c`` must be 1 or 3.
+        quality (int): Quality of the resulting JPEG file(s). Must be a number between
            1 and 100. Default: 75

    Returns:
-        output (Tensor[1]): A one dimensional int8 tensor that contains the raw bytes of the
-            JPEG file.
+        output (Tensor[1] or list[Tensor[1]]): A (list of) one dimensional uint8 tensor(s) that contain the raw
+            bytes of the JPEG file(s).
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(encode_jpeg)
    if quality < 1 or quality > 100:
        raise ValueError("Image quality should be a positive number between 1 and 100")
-    output = torch.ops.image.encode_jpeg(input, quality)
-    return output
+    if isinstance(input, list):
+        if not input:
+            raise ValueError("encode_jpeg requires at least one input tensor when a list is passed")
+        if input[0].device.type == "cuda":
+            return torch.ops.image.encode_jpegs_cuda(input, quality)
+        else:
+            return [torch.ops.image.encode_jpeg(image, quality) for image in input]
+    else:  # single input tensor
+        if input.device.type == "cuda":
+            return torch.ops.image.encode_jpegs_cuda([input], quality)[0]
+        else:
+            return torch.ops.image.encode_jpeg(input, quality)


def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
@@ -218,11 +235,14 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_jpeg)
    output = encode_jpeg(input, quality)
+    assert isinstance(output, torch.Tensor)  # Needed for torchscript
    write_file(filename, output)
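
As the docstring's note suggests, the batched CUDA path amortizes per-call overhead across the whole list; a minimal usage sketch (random data; a CUDA device and an nvJPEG-enabled build are assumed, and the shapes and quality are illustrative):

import torch
from torchvision.io import decode_jpeg, encode_jpeg

imgs = [torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8, device="cuda") for _ in range(16)]
buffers = encode_jpeg(imgs, quality=90)  # one call into image::encode_jpegs_cuda
roundtrip = [decode_jpeg(buf.cpu()) for buf in buffers]
print(len(buffers), buffers[0].dtype, roundtrip[0].shape)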


def decode_image(
-    input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, apply_exif_orientation: bool = False
+    input: torch.Tensor,
+    mode: ImageReadMode = ImageReadMode.UNCHANGED,
+    apply_exif_orientation: bool = False,
) -> torch.Tensor:
"""
Detect whether an image is a JPEG, PNG or GIF and performs the appropriate
@@ -251,7 +271,9 @@ def decode_image(
def read_image(
-    path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED, apply_exif_orientation: bool = False
+    path: str,
+    mode: ImageReadMode = ImageReadMode.UNCHANGED,
+    apply_exif_orientation: bool = False,
) -> torch.Tensor:
"""
Reads a JPEG, PNG or GIF image into a 3 dimensional RGB or grayscale Tensor.
......
@@ -78,9 +78,14 @@ def jpeg_image(image: torch.Tensor, quality: int) -> torch.Tensor:
    if image.shape[0] == 0:  # degenerate
        return image.reshape(original_shape).clone()

-    image = [decode_jpeg(encode_jpeg(image[i], quality=quality)) for i in range(image.shape[0])]
-    image = torch.stack(image, dim=0).view(original_shape)
-    return image
+    images = []
+    for i in range(image.shape[0]):
+        encoded_image = encode_jpeg(image[i], quality=quality)
+        assert isinstance(encoded_image, torch.Tensor)  # For torchscript
+        images.append(decode_jpeg(encoded_image))
+
+    images = torch.stack(images, dim=0).view(original_shape)
+    return images
@_register_kernel_internal(jpeg, tv_tensors.Video)
......