Commit c1951aa2 authored by alexeib's avatar alexeib Committed by Facebook Github Bot
Browse files

add missing colorize dataset

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/851

Differential Revision: D17145769

Pulled By: alexeib

fbshipit-source-id: 9dd26799d044ae5386e8204a129b5e3fc66d6e85
parent 4a7cd582
...@@ -11,6 +11,7 @@ from .base_wrapper_dataset import BaseWrapperDataset ...@@ -11,6 +11,7 @@ from .base_wrapper_dataset import BaseWrapperDataset
from .audio.raw_audio_dataset import FileAudioDataset from .audio.raw_audio_dataset import FileAudioDataset
from .backtranslation_dataset import BacktranslationDataset from .backtranslation_dataset import BacktranslationDataset
from .colorize_dataset import ColorizeDataset
from .concat_dataset import ConcatDataset from .concat_dataset import ConcatDataset
from .concat_sentences_dataset import ConcatSentencesDataset from .concat_sentences_dataset import ConcatSentencesDataset
from .id_dataset import IdDataset from .id_dataset import IdDataset
...@@ -51,6 +52,7 @@ from .iterators import ( ...@@ -51,6 +52,7 @@ from .iterators import (
__all__ = [ __all__ = [
'BacktranslationDataset', 'BacktranslationDataset',
'BaseWrapperDataset', 'BaseWrapperDataset',
'ColorizeDataset',
'ConcatDataset', 'ConcatDataset',
'ConcatSentencesDataset', 'ConcatSentencesDataset',
'CountingIterator', 'CountingIterator',
......
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import BaseWrapperDataset
class ColorizeDataset(BaseWrapperDataset):
""" Adds 'colors' property to net input that is obtained from the provided color getter for use by models """
def __init__(self, dataset, color_getter):
super().__init__(dataset)
self.color_getter = color_getter
def collater(self, samples):
base_collate = super().collater(samples)
if len(base_collate) > 0:
base_collate["net_input"]["colors"] = torch.tensor(
list(self.color_getter(self.dataset, s["id"]) for s in samples),
dtype=torch.long,
)
return base_collate
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment