Commit 27a7ad86 authored by luopl

update to v0.9.1

parent 731cf9b8
```diff
@@ -12,17 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

 from ...extras.logging import get_logger
 from ..data_utils import Role
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
+from .processor_utils import infer_seqlen


 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin

     from ...hparams import DataArguments
+    from ..mm_plugin import ImageInput, VideoInput
     from ..template import Template

@@ -34,27 +36,24 @@ def _encode_unsupervised_example(
     response: Sequence[Dict[str, str]],
     system: Optional[str],
     tools: Optional[str],
+    images: Sequence["ImageInput"],
+    videos: Sequence["VideoInput"],
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
     cutoff_len: int,
 ) -> Tuple[List[int], List[int]]:
-    if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
-        prompt[0]["content"] = template.image_token + prompt[0]["content"]
-
     if len(response) == 1:
         messages = prompt + response
     else:
         messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]

+    messages = template.mm_plugin.process_messages(messages, images, videos, processor)
     input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
     if template.efficient_eos:
         labels += [tokenizer.eos_token_id]

-    if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
-        image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
-        input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
-
+    input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, videos, tokenizer, processor)
     source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
     input_ids = input_ids[:source_len]
     labels = labels[:target_len]
```
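For context: this hunk replaces the hard-coded LLaVA and PaliGemma branches with two calls into `template.mm_plugin`, so model-specific multimodal handling lives behind one interface. Below is a minimal stand-in sketching the contract those two call sites assume; the class and method bodies are hypothetical illustrations, only the method signatures are taken from the calls in the diff.

```python
# Hypothetical stand-in for the mm_plugin contract used above.
# Real plugins in the repository are model-specific; this text-only
# version just shows the shape of the two calls from the diff.
from typing import Any, Dict, List, Optional, Sequence, Tuple


class PassthroughPlugin:
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence[Any],
        videos: Sequence[Any],
        processor: Optional[Any],
    ) -> List[Dict[str, str]]:
        # A model-specific plugin would expand image/video placeholder
        # tokens inside message contents (what the removed llava branch
        # did inline); a text-only plugin returns messages unchanged.
        return list(messages)

    def process_token_ids(
        self,
        input_ids: List[int],
        labels: Optional[List[int]],
        images: Sequence[Any],
        videos: Sequence[Any],
        tokenizer: Any,
        processor: Optional[Any],
    ) -> Tuple[List[int], Optional[List[int]]]:
        # A model-specific plugin may prepend image tokens to input_ids
        # (what the removed paligemma branch did inline); a text-only
        # plugin is a no-op.
        return input_ids, labels
```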
```diff
@@ -67,24 +66,21 @@ def preprocess_unsupervised_dataset(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
     data_args: "DataArguments",
-) -> Dict[str, List[List[int]]]:
+) -> Dict[str, List[Any]]:
     # build inputs with format `<bos> X` and labels with format `Y <eos>`
-    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
-    if processor is not None:
-        model_inputs["pixel_values"] = []
-        if hasattr(processor, "image_seq_length"):  # paligemma models
-            model_inputs["token_type_ids"] = []
-
-    for i in range(len(examples["prompt"])):
-        if len(examples["prompt"][i]) % 2 != 1:
-            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
+    model_inputs = defaultdict(list)
+    for i in range(len(examples["_prompt"])):
+        if len(examples["_prompt"][i]) % 2 != 1:
+            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
             continue

         input_ids, labels = _encode_unsupervised_example(
-            prompt=examples["prompt"][i],
-            response=examples["response"][i],
-            system=examples["system"][i],
-            tools=examples["tools"][i],
+            prompt=examples["_prompt"][i],
+            response=examples["_response"][i],
+            system=examples["_system"][i],
+            tools=examples["_tools"][i],
+            images=examples["_images"][i] or [],
+            videos=examples["_videos"][i] or [],
             template=template,
             tokenizer=tokenizer,
             processor=processor,

@@ -93,10 +89,8 @@ def preprocess_unsupervised_dataset(
         model_inputs["input_ids"].append(input_ids)
         model_inputs["attention_mask"].append([1] * len(input_ids))
         model_inputs["labels"].append(labels)
-        if processor is not None:
-            model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
-            if hasattr(processor, "image_seq_length"):  # paligemma models
-                model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
+        model_inputs["images"].append(examples["_images"][i])
+        model_inputs["videos"].append(examples["_videos"][i])

     return model_inputs
```
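These two hunks change the preprocessing output in two ways: the batch dict becomes a `defaultdict(list)` so multimodal keys appear only when appended to, and raw `_images`/`_videos` are stored per example instead of eagerly computed features (the removed `get_pixel_values` / `get_paligemma_token_type_ids` calls), which is why the return type loosens to `Dict[str, List[Any]]`; feature extraction evidently happens later in the pipeline. A minimal sketch of that accumulation pattern, with made-up example data rather than the repository's real records:

```python
# Minimal sketch of the defaultdict(list) batch-accumulation pattern
# above; the example records are illustrative, not from the repository.
from collections import defaultdict
from typing import Any, Dict, List

examples = {
    "_prompt": [[{"role": "user", "content": "hi"}]],
    "_images": [None],  # None is normalized to [] before encoding
    "_videos": [None],
}

model_inputs: Dict[str, List[Any]] = defaultdict(list)
for i in range(len(examples["_prompt"])):
    input_ids = [1, 2, 3]  # stand-in for the encoded prompt tokens
    # Keys spring into existence on first append, so no upfront
    # branching on whether the batch is multimodal is needed.
    model_inputs["input_ids"].append(input_ids)
    model_inputs["attention_mask"].append([1] * len(input_ids))
    model_inputs["images"].append(examples["_images"][i])
    model_inputs["videos"].append(examples["_videos"][i])

print(dict(model_inputs))
```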
(Diffs for the 16 remaining changed files in this commit are collapsed.)