Unverified commit 4491ca2e, authored by vfdev, committed by GitHub
Browse files

Added support for CMYK in decode_jpeg (#7741)


Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
parent f514ab64
...@@ -83,12 +83,9 @@ def test_decode_jpeg(img_path, pil_mode, mode): ...@@ -83,12 +83,9 @@ def test_decode_jpeg(img_path, pil_mode, mode):
with Image.open(img_path) as img: with Image.open(img_path) as img:
is_cmyk = img.mode == "CMYK" is_cmyk = img.mode == "CMYK"
if pil_mode is not None: if pil_mode is not None:
if is_cmyk:
# libjpeg does not support the conversion
pytest.xfail("Decoding a CMYK jpeg isn't supported")
img = img.convert(pil_mode) img = img.convert(pil_mode)
img_pil = torch.from_numpy(np.array(img)) img_pil = torch.from_numpy(np.array(img))
if is_cmyk: if is_cmyk and mode == ImageReadMode.UNCHANGED:
# flip the colors to match libjpeg # flip the colors to match libjpeg
img_pil = 255 - img_pil img_pil = 255 - img_pil
......
...@@ -67,6 +67,58 @@ static void torch_jpeg_set_source_mgr( ...@@ -67,6 +67,58 @@ static void torch_jpeg_set_source_mgr(
src->pub.next_input_byte = src->data; src->pub.next_input_byte = src->data;
} }
inline unsigned char clamped_cmyk_rgb_convert(
    unsigned char k,
    unsigned char cmy) {
  // Single-channel CMYK -> RGB step, borrowed from Pillow:
  // https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L568-L569
  //
  // `scaled` is an integer-only approximation of k * cmy / 255 (the +128
  // provides rounding, the shift pair stands in for the division). The
  // result is k minus that amount, clamped to the 8-bit range.
  const int product = static_cast<int>(k) * static_cast<int>(cmy) + 128;
  const int scaled = ((product >> 8) + product) >> 8;
  const int channel = static_cast<int>(k) - scaled;
  return static_cast<unsigned char>(std::clamp(channel, 0, 255));
}
// Converts one decoded scanline from 4 bytes per pixel (C, M, Y, K) to
// 3 bytes per pixel (R, G, B).
//
// cinfo:     decompressor state; only output_width is read, as the number of
//            pixels in the line.
// cmyk_line: input, read as 4 * output_width bytes.
// rgb_line:  output, written as 3 * output_width bytes.
void convert_line_cmyk_to_rgb(
    j_decompress_ptr cinfo,
    const unsigned char* cmyk_line,
    unsigned char* rgb_line) {
  const int num_pixels = cinfo->output_width;
  const unsigned char* src = cmyk_line;
  unsigned char* dst = rgb_line;
  for (int px = 0; px < num_pixels; ++px, src += 4, dst += 3) {
    const int cyan = src[0];
    const int magenta = src[1];
    const int yellow = src[2];
    const int black = src[3];
    // Each colour sample is complemented (255 - sample) before being
    // combined with the K sample.
    dst[0] = clamped_cmyk_rgb_convert(black, 255 - cyan);
    dst[1] = clamped_cmyk_rgb_convert(black, 255 - magenta);
    dst[2] = clamped_cmyk_rgb_convert(black, 255 - yellow);
  }
}
inline unsigned char rgb_to_gray(int r, int g, int b) {
  // Fixed-point luma formula, borrowed from Pillow:
  // https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L226
  //
  // The three weights sum to 65536 (1 << 16), so the final shift by 16
  // normalises the weighted sum; 0x8000 is half of that scale and makes the
  // shift round instead of truncate.
  constexpr int kRedWeight = 19595;
  constexpr int kGreenWeight = 38470;
  constexpr int kBlueWeight = 7471;
  constexpr int kRoundingBias = 0x8000;
  const int weighted = r * kRedWeight + g * kGreenWeight + b * kBlueWeight;
  return static_cast<unsigned char>((weighted + kRoundingBias) >> 16);
}
// Converts one decoded scanline from 4 bytes per pixel (C, M, Y, K) to a
// single gray byte per pixel, going through an intermediate RGB value.
//
// cinfo:     decompressor state; only output_width is read, as the number of
//            pixels in the line.
// cmyk_line: input, read as 4 * output_width bytes.
// gray_line: output, written as output_width bytes.
void convert_line_cmyk_to_gray(
    j_decompress_ptr cinfo,
    const unsigned char* cmyk_line,
    unsigned char* gray_line) {
  const int num_pixels = cinfo->output_width;
  const unsigned char* src = cmyk_line;
  for (int px = 0; px < num_pixels; ++px, src += 4) {
    const int cyan = src[0];
    const int magenta = src[1];
    const int yellow = src[2];
    const int black = src[3];
    // Same per-channel CMYK -> RGB step as the RGB line converter, then the
    // three channels are collapsed into one luma value.
    const int red = clamped_cmyk_rgb_convert(black, 255 - cyan);
    const int green = clamped_cmyk_rgb_convert(black, 255 - magenta);
    const int blue = clamped_cmyk_rgb_convert(black, 255 - yellow);
    gray_line[px] = rgb_to_gray(red, green, blue);
  }
}
} // namespace } // namespace
torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) { torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
...@@ -102,20 +154,29 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) { ...@@ -102,20 +154,29 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
jpeg_read_header(&cinfo, TRUE); jpeg_read_header(&cinfo, TRUE);
int channels = cinfo.num_components; int channels = cinfo.num_components;
bool cmyk_to_rgb_or_gray = false;
if (mode != IMAGE_READ_MODE_UNCHANGED) { if (mode != IMAGE_READ_MODE_UNCHANGED) {
switch (mode) { switch (mode) {
case IMAGE_READ_MODE_GRAY: case IMAGE_READ_MODE_GRAY:
if (cinfo.jpeg_color_space != JCS_GRAYSCALE) { if (cinfo.jpeg_color_space == JCS_CMYK ||
cinfo.jpeg_color_space == JCS_YCCK) {
cinfo.out_color_space = JCS_CMYK;
cmyk_to_rgb_or_gray = true;
} else {
cinfo.out_color_space = JCS_GRAYSCALE; cinfo.out_color_space = JCS_GRAYSCALE;
channels = 1;
} }
channels = 1;
break; break;
case IMAGE_READ_MODE_RGB: case IMAGE_READ_MODE_RGB:
if (cinfo.jpeg_color_space != JCS_RGB) { if (cinfo.jpeg_color_space == JCS_CMYK ||
cinfo.jpeg_color_space == JCS_YCCK) {
cinfo.out_color_space = JCS_CMYK;
cmyk_to_rgb_or_gray = true;
} else {
cinfo.out_color_space = JCS_RGB; cinfo.out_color_space = JCS_RGB;
channels = 3;
} }
channels = 3;
break; break;
/* /*
* Libjpeg does not support converting from CMYK to grayscale etc. There * Libjpeg does not support converting from CMYK to grayscale etc. There
...@@ -139,12 +200,28 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) { ...@@ -139,12 +200,28 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
auto tensor = auto tensor =
torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8); torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8);
auto ptr = tensor.data_ptr<uint8_t>(); auto ptr = tensor.data_ptr<uint8_t>();
torch::Tensor cmyk_line_tensor;
if (cmyk_to_rgb_or_gray) {
cmyk_line_tensor = torch::empty({int64_t(width), 4}, torch::kU8);
}
while (cinfo.output_scanline < cinfo.output_height) { while (cinfo.output_scanline < cinfo.output_height) {
/* jpeg_read_scanlines expects an array of pointers to scanlines. /* jpeg_read_scanlines expects an array of pointers to scanlines.
* Here the array is only one element long, but you could ask for * Here the array is only one element long, but you could ask for
* more than one scanline at a time if that's more convenient. * more than one scanline at a time if that's more convenient.
*/ */
jpeg_read_scanlines(&cinfo, &ptr, 1); if (cmyk_to_rgb_or_gray) {
auto cmyk_line_ptr = cmyk_line_tensor.data_ptr<uint8_t>();
jpeg_read_scanlines(&cinfo, &cmyk_line_ptr, 1);
if (channels == 3) {
convert_line_cmyk_to_rgb(&cinfo, cmyk_line_ptr, ptr);
} else if (channels == 1) {
convert_line_cmyk_to_gray(&cinfo, cmyk_line_ptr, ptr);
}
} else {
jpeg_read_scanlines(&cinfo, &ptr, 1);
}
ptr += stride; ptr += stride;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment