Commit 27a7ad86 authored by luopl

update to v0.9.1

parent 731cf9b8
@@ -16,6 +16,7 @@ import base64
import io
import json
import os
+import re
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple

@@ -51,9 +52,8 @@ if is_requests_available():
if TYPE_CHECKING:
-    from numpy.typing import NDArray
    from ..chat import ChatModel
+    from ..data.mm_plugin import ImageInput
    from .protocol import ChatCompletionRequest, ScoreEvaluationRequest

@@ -69,7 +69,7 @@ ROLE_MAPPING = {
def _process_request(
    request: "ChatCompletionRequest",
-) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
+) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
    logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))

    if len(request.messages) == 0:

@@ -104,15 +104,14 @@ def _process_request(
                    input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
                else:
                    image_url = input_item.image_url.url
-                    if image_url.startswith("data:image"):  # base64 image
-                        image_data = base64.b64decode(image_url.split(",", maxsplit=1)[1])
-                        image_path = io.BytesIO(image_data)
+                    if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):  # base64 image
+                        image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
                    elif os.path.isfile(image_url):  # local file
-                        image_path = open(image_url, "rb")
+                        image_stream = open(image_url, "rb")
                    else:  # web uri
-                        image_path = requests.get(image_url, stream=True).raw
+                        image_stream = requests.get(image_url, stream=True).raw

-                    image = Image.open(image_path).convert("RGB")
+                    image = Image.open(image_stream).convert("RGB")
        else:
            input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})

@@ -230,8 +229,9 @@ async def create_stream_chat_completion_response(
async def create_score_evaluation_response(
    request: "ScoreEvaluationRequest", chat_model: "ChatModel"
) -> "ScoreEvaluationResponse":
+    score_id = "scoreval-{}".format(uuid.uuid4().hex)
    if len(request.messages) == 0:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")

    scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
-    return ScoreEvaluationResponse(model=request.model, scores=scores)
+    return ScoreEvaluationResponse(id=score_id, model=request.model, scores=scores)
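The rewritten branch above resolves an image_url from one of three sources: a base64 data URI (now validated with a regex instead of a bare prefix check), a local file, or a web URL. A standalone sketch of the same resolution logic, with load_image as a hypothetical helper name, not project code:

import base64
import io
import os
import re

import requests
from PIL import Image


def load_image(image_url: str) -> "Image.Image":
    # hypothetical helper mirroring the branch above
    if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):  # base64 image
        image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
    elif os.path.isfile(image_url):  # local file
        image_stream = open(image_url, "rb")
    else:  # web uri
        image_stream = requests.get(image_url, stream=True).raw

    return Image.open(image_stream).convert("RGB")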
@@ -18,11 +18,11 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Opti
if TYPE_CHECKING:
-    from numpy.typing import NDArray
    from transformers import PreTrainedModel, PreTrainedTokenizer
    from vllm import AsyncLLMEngine

    from ..data import Template
+    from ..data.mm_plugin import ImageInput, VideoInput
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments

@@ -35,6 +35,12 @@ class Response:
class BaseEngine(ABC):
+    r"""
+    Base class for inference engine of chat models.
+
+    Must implement the async methods: chat(), stream_chat() and get_scores().
+    """
+
    model: Union["PreTrainedModel", "AsyncLLMEngine"]
    tokenizer: "PreTrainedTokenizer"
    can_generate: bool

@@ -48,7 +54,11 @@ class BaseEngine(ABC):
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
-    ) -> None: ...
+    ) -> None:
+        r"""
+        Initializes an inference engine.
+        """
+        ...

    @abstractmethod
    async def chat(

@@ -56,9 +66,14 @@ class BaseEngine(ABC):
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
        **input_kwargs,
-    ) -> List["Response"]: ...
+    ) -> List["Response"]:
+        r"""
+        Gets a list of responses of the chat model.
+        """
+        ...

    @abstractmethod
    async def stream_chat(

@@ -66,13 +81,22 @@ class BaseEngine(ABC):
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
        **input_kwargs,
-    ) -> AsyncGenerator[str, None]: ...
+    ) -> AsyncGenerator[str, None]:
+        r"""
+        Gets the response token-by-token of the chat model.
+        """
+        ...

    @abstractmethod
    async def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
-    ) -> List[float]: ...
+    ) -> List[float]:
+        r"""
+        Gets a list of scores of the reward model.
+        """
+        ...
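Concrete engines implement the three async methods with the signatures above. A toy stand-in that mirrors the required shapes (EchoEngine is illustrative only, and returns plain strings instead of the project's Response objects):

from typing import AsyncGenerator, Dict, List, Sequence


class EchoEngine:
    # illustrative only: same method names and parameters as BaseEngine
    async def chat(self, messages: Sequence[Dict[str, str]], system=None, tools=None,
                   image=None, video=None, **input_kwargs) -> List[str]:
        return [messages[-1]["content"]]  # echo the last message back

    async def stream_chat(self, messages: Sequence[Dict[str, str]], system=None, tools=None,
                          image=None, video=None, **input_kwargs) -> AsyncGenerator[str, None]:
        for token in messages[-1]["content"].split():
            yield token

    async def get_scores(self, batch_input: List[str], **input_kwargs) -> List[float]:
        return [float(len(text)) for text in batch_input]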
@@ -27,8 +27,7 @@ from .vllm_engine import VllmEngine
if TYPE_CHECKING:
-    from numpy.typing import NDArray
+    from ..data.mm_plugin import ImageInput, VideoInput
    from .base_engine import BaseEngine, Response

@@ -38,8 +37,17 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
class ChatModel:
+    r"""
+    General class for chat models. Backed by huggingface or vllm engines.
+
+    Supports both sync and async methods.
+    Sync methods: chat(), stream_chat() and get_scores().
+    Async methods: achat(), astream_chat() and aget_scores().
+    """
+
    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
        model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
+        self.engine_type = model_args.infer_backend
        if model_args.infer_backend == "huggingface":
            self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
        elif model_args.infer_backend == "vllm":

@@ -56,10 +64,16 @@ class ChatModel:
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
        **input_kwargs,
    ) -> List["Response"]:
-        task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
+        r"""
+        Gets a list of responses of the chat model.
+        """
+        task = asyncio.run_coroutine_threadsafe(
+            self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
+        )
        return task.result()

    async def achat(

@@ -67,20 +81,28 @@ class ChatModel:
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
        **input_kwargs,
    ) -> List["Response"]:
-        return await self.engine.chat(messages, system, tools, image, **input_kwargs)
+        r"""
+        Asynchronously gets a list of responses of the chat model.
+        """
+        return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)

    def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
        **input_kwargs,
    ) -> Generator[str, None, None]:
-        generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
+        r"""
+        Gets the response token-by-token of the chat model.
+        """
+        generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
        while True:
            try:
                task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)

@@ -93,10 +115,14 @@ class ChatModel:
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
-        async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
+        r"""
+        Asynchronously gets the response token-by-token of the chat model.
+        """
+        async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
            yield new_token

    def get_scores(

@@ -104,6 +130,9 @@ class ChatModel:
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
+        r"""
+        Gets a list of scores of the reward model.
+        """
        task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
        return task.result()

@@ -112,6 +141,9 @@ class ChatModel:
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
+        r"""
+        Asynchronously gets a list of scores of the reward model.
+        """
        return await self.engine.get_scores(batch_input, **input_kwargs)
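ChatModel's sync wrappers all use the same pattern: a daemon thread runs a private event loop, and each sync call submits the corresponding coroutine with asyncio.run_coroutine_threadsafe and blocks on the returned future. A self-contained sketch of that pattern (the coroutine name is made up):

import asyncio
import threading


async def agreet(name: str) -> str:
    await asyncio.sleep(0.1)  # stand-in for real async work
    return "hello, {}".format(name)


loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

# submit the coroutine to the background loop and block until it finishes
task = asyncio.run_coroutine_threadsafe(agreet("world"), loop)
print(task.result())  # hello, world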
@@ -20,8 +20,10 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Opt
import torch
from transformers import GenerationConfig, TextIteratorStreamer
+from typing_extensions import override

from ..data import get_template_and_fix_tokenizer
+from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.logging import get_logger
from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer

@@ -29,12 +31,11 @@ from .base_engine import BaseEngine, Response
if TYPE_CHECKING:
-    from numpy.typing import NDArray
    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
-    from transformers.image_processing_utils import BaseImageProcessor
    from trl import PreTrainedModelWrapper

    from ..data import Template
+    from ..data.mm_plugin import ImageInput, VideoInput
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments

@@ -54,7 +55,7 @@ class HuggingfaceEngine(BaseEngine):
        self.tokenizer = tokenizer_module["tokenizer"]
        self.processor = tokenizer_module["processor"]
        self.tokenizer.padding_side = "left" if self.can_generate else "right"
-        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
        self.model = load_model(
            self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
        )  # must after fixing tokenizer to resize vocab

@@ -78,31 +79,30 @@ class HuggingfaceEngine(BaseEngine):
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> Tuple[Dict[str, Any], int]:
-        if (
-            processor is not None
-            and image is not None
-            and not hasattr(processor, "image_seq_length")
-            and template.image_token not in messages[0]["content"]
-        ):  # llava-like models
-            messages[0]["content"] = template.image_token + messages[0]["content"]
+        mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
+        if image is not None:
+            mm_input_dict.update({"images": [image], "imglens": [1]})
+            if IMAGE_PLACEHOLDER not in messages[0]["content"]:
+                messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]

+        if video is not None:
+            mm_input_dict.update({"videos": [video], "vidlens": [1]})
+            if VIDEO_PLACEHOLDER not in messages[0]["content"]:
+                messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]

+        messages = template.mm_plugin.process_messages(
+            messages, mm_input_dict["images"], mm_input_dict["videos"], processor
+        )

        paired_messages = messages + [{"role": "assistant", "content": ""}]
        system = system or generating_args["default_system"]
-        pixel_values = None
-        prompt_ids, _ = template.encode_oneturn(
-            tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
-        )
-        if processor is not None and image is not None:  # add image features
-            image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-            batch_feature = image_processor(image, return_tensors="pt")
-            pixel_values = batch_feature.to(model.device)["pixel_values"]  # shape (B, C, H, W)
-            if hasattr(processor, "image_seq_length"):  # paligemma models
-                image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
-                prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
+        prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
+        prompt_ids, _ = template.mm_plugin.process_token_ids(
+            prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], tokenizer, processor
+        )

        prompt_length = len(prompt_ids)
        inputs = torch.tensor([prompt_ids], device=model.device)
        attention_mask = torch.ones_like(inputs, dtype=torch.bool)

@@ -164,8 +164,10 @@ class HuggingfaceEngine(BaseEngine):
            logits_processor=get_logits_processor(),
        )

-        if pixel_values is not None:
-            gen_kwargs["pixel_values"] = pixel_values
+        mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
+        for key, value in mm_inputs.items():
+            value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
+            gen_kwargs[key] = value.to(model.device)

        return gen_kwargs, prompt_length
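The plugin's get_mm_inputs may return plain lists (for example, grid shapes) as well as tensors, so each value is coerced to a tensor and moved to the model device before being handed to generate. The coercion loop in isolation, with made-up values:

import torch

gen_kwargs = {}
device = "cpu"  # stand-in for model.device
mm_inputs = {"pixel_values": torch.randn(1, 3, 224, 224), "image_grid_thw": [[1, 16, 16]]}

for key, value in mm_inputs.items():
    # lists become tensors; tensors pass through unchanged
    value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
    gen_kwargs[key] = value.to(device)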
@@ -180,11 +182,12 @@ class HuggingfaceEngine(BaseEngine):
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> List["Response"]:
        gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
-            model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
+            model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
        )
        generate_output = model.generate(**gen_kwargs)
        response_ids = generate_output[:, prompt_length:]

@@ -215,11 +218,12 @@ class HuggingfaceEngine(BaseEngine):
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> Callable[[], str]:
        gen_kwargs, _ = HuggingfaceEngine._process_args(
-            model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
+            model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
        )
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs["streamer"] = streamer

@@ -242,37 +246,28 @@ class HuggingfaceEngine(BaseEngine):
        batch_input: List[str],
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> List[float]:
-        max_length = input_kwargs.pop("max_length", None)
+        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        device = getattr(model.pretrained_model, "device", "cuda")
-        inputs = tokenizer(
+        inputs: Dict[str, "torch.Tensor"] = tokenizer(
            batch_input,
            padding=True,
            truncation=True,
            max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
            return_tensors="pt",
-            add_special_tokens=True,
+            add_special_tokens=False,
        ).to(device)
-
-        input_ids: torch.Tensor = inputs["input_ids"]
-        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
-
-        if getattr(model.config, "model_type", None) == "chatglm":
-            values = torch.transpose(values, 0, 1)
-
-        scores = []
-        for i in range(input_ids.size(0)):
-            end_indexes = (input_ids[i] != tokenizer.pad_token_id).nonzero()
-            end_index = end_indexes[-1].item() if len(end_indexes) else 0
-            scores.append(values[i, end_index].nan_to_num().item())
-
+        values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1]
+        scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
        return scores
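The rewritten get_scores replaces the per-row pad scan with a single gather: with right padding, attention_mask.sum(-1) - 1 is the index of each sequence's last real token, which is where the value head's score for the whole sequence is read. A worked toy example:

import torch

values = torch.tensor([[0.1, 0.5, 0.9, 0.0],   # last real token at index 2
                       [0.2, 0.7, 0.0, 0.0]])  # last real token at index 1
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

end_indices = attention_mask.sum(dim=-1, keepdim=True) - 1  # [[2], [1]]
scores = values.gather(dim=-1, index=end_indices)
print(scores)  # tensor([[0.9000], [0.7000]])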
+    @override
    async def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
        **input_kwargs,
    ) -> List["Response"]:
        if not self.can_generate:

@@ -289,18 +284,21 @@ class HuggingfaceEngine(BaseEngine):
            system,
            tools,
            image,
+            video,
            input_kwargs,
        )
        async with self.semaphore:
            with concurrent.futures.ThreadPoolExecutor() as pool:
                return await loop.run_in_executor(pool, self._chat, *input_args)

+    @override
    async def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        if not self.can_generate:

@@ -317,6 +315,7 @@ class HuggingfaceEngine(BaseEngine):
            system,
            tools,
            image,
+            video,
            input_kwargs,
        )
        async with self.semaphore:

@@ -328,6 +327,7 @@ class HuggingfaceEngine(BaseEngine):
            except StopAsyncIteration:
                break

+    @override
    async def get_scores(
        self,
        batch_input: List[str],
@@ -15,32 +15,31 @@
import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union

+from typing_extensions import override

from ..data import get_template_and_fix_tokenizer
+from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.logging import get_logger
from ..extras.misc import get_device_count
-from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5, is_vllm_version_greater_than_0_5_1
+from ..extras.packages import is_pillow_available, is_vllm_available
from ..model import load_config, load_tokenizer
from ..model.model_utils.quantization import QuantizationMethod
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
from .base_engine import BaseEngine, Response

+if is_pillow_available():
+    from PIL import Image
+    from PIL.Image import Image as ImageObject

if is_vllm_available():
    from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
    from vllm.lora.request import LoRARequest

-    if is_vllm_version_greater_than_0_5_1():
-        pass
-    elif is_vllm_version_greater_than_0_5():
-        from vllm.multimodal.image import ImagePixelData
-    else:
-        from vllm.sequence import MultiModalData

if TYPE_CHECKING:
-    from numpy.typing import NDArray
-    from transformers.image_processing_utils import BaseImageProcessor
+    from ..data.mm_plugin import ImageInput, VideoInput
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments

@@ -67,7 +66,7 @@ class VllmEngine(BaseEngine):
        self.tokenizer = tokenizer_module["tokenizer"]
        self.processor = tokenizer_module["processor"]
        self.tokenizer.padding_side = "left"
-        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
        self.generating_args = generating_args.to_dict()

        engine_args = {

@@ -85,19 +84,11 @@ class VllmEngine(BaseEngine):
            "max_lora_rank": model_args.vllm_max_lora_rank,
        }

-        if model_args.visual_inputs:
-            image_size = config.vision_config.image_size
-            patch_size = config.vision_config.patch_size
-            self.image_feature_size = (image_size // patch_size) ** 2
-            engine_args["image_input_type"] = "pixel_values"
-            engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(self.template.image_token)
-            engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
-            engine_args["image_feature_size"] = self.image_feature_size
-            if getattr(config, "is_yi_vl_derived_model", None):
-                import vllm.model_executor.models.llava
-                logger.info("Detected Yi-VL model, applying projector patch.")
-                vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
+        if getattr(config, "is_yi_vl_derived_model", None):
+            import vllm.model_executor.models.llava
+
+            logger.info("Detected Yi-VL model, applying projector patch.")
+            vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM

        self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
        if model_args.adapter_name_or_path is not None:

@@ -110,37 +101,18 @@ class VllmEngine(BaseEngine):
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
        **input_kwargs,
    ) -> AsyncIterator["RequestOutput"]:
        request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-        if (
-            self.processor is not None
-            and image is not None
-            and not hasattr(self.processor, "image_seq_length")
-            and self.template.image_token not in messages[0]["content"]
-        ):  # llava-like models (TODO: paligemma models)
-            messages[0]["content"] = self.template.image_token * self.image_feature_size + messages[0]["content"]
+        if image is not None:
+            if IMAGE_PLACEHOLDER not in messages[0]["content"]:
+                messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]

        paired_messages = messages + [{"role": "assistant", "content": ""}]
        system = system or self.generating_args["default_system"]
-        prompt_ids, _ = self.template.encode_oneturn(
-            tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
-        )
-        if self.processor is not None and image is not None:  # add image features
-            image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
-            pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
-            if is_vllm_version_greater_than_0_5_1():
-                multi_modal_data = {"image": pixel_values}
-            elif is_vllm_version_greater_than_0_5():
-                multi_modal_data = ImagePixelData(image=pixel_values)
-            else:  # TODO: remove vllm 0.4.3 support
-                multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
-        else:
-            multi_modal_data = None
+        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
        prompt_length = len(prompt_ids)

        use_beam_search: bool = self.generating_args["num_beams"] > 1

@@ -185,6 +157,17 @@ class VllmEngine(BaseEngine):
            skip_special_tokens=True,
        )

+        if image is not None:  # add image features
+            if not isinstance(image, (str, ImageObject)):
+                raise ValueError("Expected image input is a path or PIL.Image, but got {}.".format(type(image)))
+
+            if isinstance(image, str):
+                image = Image.open(image).convert("RGB")
+
+            multi_modal_data = {"image": image}
+        else:
+            multi_modal_data = None
+
        result_generator = self.model.generate(
            inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
            sampling_params=sampling_params,

@@ -193,16 +176,18 @@ class VllmEngine(BaseEngine):
        )
        return result_generator
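With the newer vLLM multimodal API, multi_modal_data takes the raw PIL image rather than precomputed pixel values, so the new gate only validates and normalizes the input. The same check in isolation (normalize_image is a hypothetical name, not project code):

from PIL import Image
from PIL.Image import Image as ImageObject


def normalize_image(image) -> "ImageObject":
    # hypothetical helper mirroring the gate above
    if not isinstance(image, (str, ImageObject)):
        raise ValueError("Expected image input is a path or PIL.Image, but got {}.".format(type(image)))

    if isinstance(image, str):
        image = Image.open(image).convert("RGB")

    return image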
+    @override
    async def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
        **input_kwargs,
    ) -> List["Response"]:
        final_output = None
-        generator = await self._generate(messages, system, tools, image, **input_kwargs)
+        generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
        async for request_output in generator:
            final_output = request_output

@@ -219,21 +204,24 @@ class VllmEngine(BaseEngine):
        return results

+    @override
    async def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        generated_text = ""
-        generator = await self._generate(messages, system, tools, image, **input_kwargs)
+        generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
        async for result in generator:
            delta_text = result.outputs[0].text[len(generated_text) :]
            generated_text = result.outputs[0].text
            yield delta_text

+    @override
    async def get_scores(
        self,
        batch_input: List[str],
@@ -118,4 +118,4 @@ def main():
    elif command == Command.HELP:
        print(USAGE)
    else:
-        raise NotImplementedError("Unknown command: {}".format(command))
+        raise NotImplementedError("Unknown command: {}.".format(command))
@@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding, SFTDataCollatorWith4DAttentionMask
+from .collator import (
+    KTODataCollatorWithPadding,
+    MultiModalDataCollatorForSeq2Seq,
+    PairwiseDataCollatorWithPadding,
+    SFTDataCollatorWith4DAttentionMask,
+)
from .data_utils import Role, split_dataset
from .loader import get_dataset
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer

@@ -20,6 +25,7 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
__all__ = [
    "KTODataCollatorWithPadding",
+    "MultiModalDataCollatorForSeq2Seq",
    "PairwiseDataCollatorWithPadding",
    "SFTDataCollatorWith4DAttentionMask",
    "Role",
@@ -14,9 +14,7 @@
import os
from functools import partial
-from typing import TYPE_CHECKING, Any, Dict, List, Union
-
-from datasets import Features
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union

from ..extras.logging import get_logger
from .data_utils import Role

@@ -27,88 +25,117 @@ if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments

    from ..hparams import DataArguments
+    from .mm_plugin import ImageInput, VideoInput
    from .parser import DatasetAttr

logger = get_logger(__name__)
-def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
+def _convert_images(
+    images: Sequence["ImageInput"],
+    dataset_attr: "DatasetAttr",
+    data_args: "DataArguments",
+) -> Optional[List["ImageInput"]]:
    r"""
    Optionally concatenates image path to dataset dir when loading from local disk.
    """
-    outputs = []
-    if dataset_attr.load_from in ["script", "file"]:
-        for image in images:
-            if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)):
-                outputs.append(os.path.join(data_args.dataset_dir, image))
-            else:
-                outputs.append(image)
-
-    return outputs
+    if len(images) == 0:
+        return None
+
+    images = images[:]
+    if dataset_attr.load_from in ["script", "file"]:
+        for i in range(len(images)):
+            if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, images[i])):
+                images[i] = os.path.join(data_args.dataset_dir, images[i])
+
+    return images
+
+
+def _convert_videos(
+    videos: Sequence["VideoInput"],
+    dataset_attr: "DatasetAttr",
+    data_args: "DataArguments",
+) -> Optional[List["VideoInput"]]:
+    r"""
+    Optionally concatenates video path to dataset dir when loading from local disk.
+    """
+    if len(videos) == 0:
+        return None
+
+    videos = videos[:]
+    if dataset_attr.load_from in ["script", "file"]:
+        for i in range(len(videos)):
+            if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, videos[i])):
+                videos[i] = os.path.join(data_args.dataset_dir, videos[i])
+
+    return videos
def convert_alpaca(
-    examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
-) -> Dict[str, List[Any]]:
+    example: Dict[str, Any],
+    dataset_attr: "DatasetAttr",
+    data_args: "DataArguments",
+) -> Dict[str, Any]:
    r"""
    Converts alpaca format dataset to the standard format.
    """
-    outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
-    convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
-    for i in range(len(examples[dataset_attr.prompt])):
-        prompt = []
-        if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
-            for old_prompt, old_response in examples[dataset_attr.history][i]:
-                prompt.append({"role": Role.USER.value, "content": old_prompt})
-                prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
-
-        content = []
-        if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
-            content.append(examples[dataset_attr.prompt][i])
-
-        if dataset_attr.query and examples[dataset_attr.query][i]:
-            content.append(examples[dataset_attr.query][i])
-
-        prompt.append({"role": Role.USER.value, "content": "\n".join(content)})  # "prompt\nquery"
-
-        if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool):  # kto example
-            response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
-            if examples[dataset_attr.kto_tag][i]:
-                response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
-            else:
-                response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
-        elif (
-            dataset_attr.ranking
-            and isinstance(examples[dataset_attr.chosen][i], str)
-            and isinstance(examples[dataset_attr.rejected][i], str)
-        ):  # pairwise example
-            response = [
-                {"role": Role.ASSISTANT.value, "content": examples[dataset_attr.chosen][i]},
-                {"role": Role.ASSISTANT.value, "content": examples[dataset_attr.rejected][i]},
-            ]
-        elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):  # normal example
-            response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
-        else:  # unsupervised
-            response = []
-
-        outputs["prompt"].append(prompt)
-        outputs["response"].append(response)
-        outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
-        outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
-        outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
-
-    return outputs
+    prompt = []
+    if dataset_attr.history and isinstance(example[dataset_attr.history], list):
+        for old_prompt, old_response in example[dataset_attr.history]:
+            prompt.append({"role": Role.USER.value, "content": old_prompt})
+            prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
+
+    query = []
+    if dataset_attr.prompt and example[dataset_attr.prompt]:
+        query.append(example[dataset_attr.prompt])
+
+    if dataset_attr.query and example[dataset_attr.query]:
+        query.append(example[dataset_attr.query])
+
+    prompt.append({"role": Role.USER.value, "content": "\n".join(query)})  # "prompt\nquery"
+
+    if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool):  # kto example
+        response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
+        if example[dataset_attr.kto_tag]:
+            response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
+        else:
+            response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
+    elif (
+        dataset_attr.ranking
+        and isinstance(example[dataset_attr.chosen], str)
+        and isinstance(example[dataset_attr.rejected], str)
+    ):  # pairwise example
+        response = [
+            {"role": Role.ASSISTANT.value, "content": example[dataset_attr.chosen]},
+            {"role": Role.ASSISTANT.value, "content": example[dataset_attr.rejected]},
+        ]
+    elif dataset_attr.response and isinstance(example[dataset_attr.response], str):  # normal example
+        response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
+    else:  # unsupervised
+        response = []
+
+    convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
+    convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
+    output = {
+        "_prompt": prompt,
+        "_response": response,
+        "_system": example[dataset_attr.system] if dataset_attr.system else "",
+        "_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
+        "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
+        "_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
+    }
+    return output
def convert_sharegpt(
-    examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
-) -> Dict[str, List[Any]]:
+    example: Dict[str, Any],
+    dataset_attr: "DatasetAttr",
+    data_args: "DataArguments",
+) -> Dict[str, Any]:
    r"""
    Converts sharegpt format dataset to the standard format.
    """
-    outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
-    convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
    tag_mapping = {
        dataset_attr.user_tag: Role.USER.value,
        dataset_attr.assistant_tag: Role.ASSISTANT.value,

@@ -119,74 +146,79 @@ def convert_sharegpt(
    odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
    even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
    accept_tags = (odd_tags, even_tags)
-    for i, messages in enumerate(examples[dataset_attr.messages]):
-        if len(messages) == 0:
-            continue
-
-        if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
-            system = messages[0][dataset_attr.content_tag]
-            messages = messages[1:]
-        else:
-            system = examples[dataset_attr.system][i] if dataset_attr.system else ""
+    messages = example[dataset_attr.messages]
+    if (
+        dataset_attr.system_tag
+        and len(messages) != 0
+        and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
+    ):
+        system = messages[0][dataset_attr.content_tag]
+        messages = messages[1:]
+    else:
+        system = example[dataset_attr.system] if dataset_attr.system else ""

    aligned_messages = []
    broken_data = False
    for turn_idx, message in enumerate(messages):
        if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
            logger.warning("Invalid role tag in {}.".format(messages))
            broken_data = True

        aligned_messages.append(
            {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
        )

    if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
        dataset_attr.ranking and len(aligned_messages) % 2 == 0
    ):
        logger.warning("Invalid message count in {}.".format(messages))
        broken_data = True

-        if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool):  # kto example
-            prompt = aligned_messages[:-1]
-            response = aligned_messages[-1:]
-            if examples[dataset_attr.kto_tag][i]:
-                response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
-            else:
-                response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
-        elif (
-            dataset_attr.ranking
-            and isinstance(examples[dataset_attr.chosen][i], dict)
-            and isinstance(examples[dataset_attr.rejected][i], dict)
-        ):  # pairwise example
-            chosen = examples[dataset_attr.chosen][i]
-            rejected = examples[dataset_attr.rejected][i]
-            if (
-                chosen[dataset_attr.role_tag] not in accept_tags[-1]
-                or rejected[dataset_attr.role_tag] not in accept_tags[-1]
-            ):
-                logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
-                broken_data = True
-
-            prompt = aligned_messages
-            response = [
-                {"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
-                {"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
-            ]
-        else:  # normal example
-            prompt = aligned_messages[:-1]
-            response = aligned_messages[-1:]
-
-        if broken_data:
-            logger.warning("Skipping this abnormal example.")
-            continue
-
-        outputs["prompt"].append(prompt)
-        outputs["response"].append(response)
-        outputs["system"].append(system)
-        outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
-        outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
-
-    return outputs
+    if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool):  # kto example
+        prompt = aligned_messages[:-1]
+        response = aligned_messages[-1:]
+        if example[dataset_attr.kto_tag]:
+            response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
+        else:
+            response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
+    elif (
+        dataset_attr.ranking
+        and isinstance(example[dataset_attr.chosen], dict)
+        and isinstance(example[dataset_attr.rejected], dict)
+    ):  # pairwise example
+        chosen = example[dataset_attr.chosen]
+        rejected = example[dataset_attr.rejected]
+        if (
+            chosen[dataset_attr.role_tag] not in accept_tags[-1]
+            or rejected[dataset_attr.role_tag] not in accept_tags[-1]
+        ):
+            logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
+            broken_data = True
+
+        prompt = aligned_messages
+        response = [
+            {"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
+            {"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
+        ]
+    else:  # normal example
+        prompt = aligned_messages[:-1]
+        response = aligned_messages[-1:]
+
+    if broken_data:
+        logger.warning("Skipping this abnormal example.")
+        prompt, response = [], []
+
+    convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
+    convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
+    output = {
+        "_prompt": prompt,
+        "_response": response,
+        "_system": system,
+        "_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
+        "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
+        "_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
+    }
+    return output
def align_dataset(

@@ -197,11 +229,12 @@ def align_dataset(
) -> Union["Dataset", "IterableDataset"]:
    r"""
    Aligned dataset:
-    prompt: [{"role": "user", "content": "..."}] * (2T - 1)
-    response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
-    system: "..."
-    tools: "...",
-    images: [],
+    _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
+    _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
+    _system: "..."
+    _tools: "...",
+    _images: [],
+    _videos: [],
    """
    if dataset_attr.formatting == "alpaca":
        convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)

@@ -209,19 +242,6 @@ def align_dataset(
        convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)

    column_names = list(next(iter(dataset)).keys())
-    features = Features.from_dict(
-        {
-            "prompt": [
-                {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
-            ],
-            "response": [
-                {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
-            ],
-            "system": {"dtype": "string", "_type": "Value"},
-            "tools": {"dtype": "string", "_type": "Value"},
-            "images": [{"_type": "Image"}],
-        }
-    )
    kwargs = {}
    if not data_args.streaming:
        kwargs = dict(

@@ -232,8 +252,7 @@ align_dataset(
    return dataset.map(
        convert_func,
-        batched=True,
+        batched=False,
        remove_columns=column_names,
-        features=features,
        **kwargs,
    )
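With batched=False, the datasets library feeds convert_func one example dict at a time, and dropping the fixed Features schema lets column types be inferred per example (which is what allows _images and _videos to be either None or lists of paths). A minimal sketch of the call shape, assuming the datasets library is installed:

from datasets import Dataset

dataset = Dataset.from_list([{"instruction": "Say hi.", "output": "Hi!"}])
dataset = dataset.map(
    lambda example: {"_prompt": example["instruction"], "_response": example["output"]},
    batched=False,
    remove_columns=["instruction", "output"],
)
print(dataset[0])  # {'_prompt': 'Say hi.', '_response': 'Hi!'}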
@@ -16,12 +16,18 @@
# limitations under the License.

from dataclasses import dataclass
-from typing import Any, Dict, Literal, Sequence
+from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence

import torch
from transformers import DataCollatorForSeq2Seq

+if TYPE_CHECKING:
+    from transformers import ProcessorMixin
+
+    from .template import Template

def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
    r"""
    Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
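A sketch of what that expansion means for packed sequences: the 2D mask holds segment indices rather than 0/1 values, and the expanded 4D mask is causal within a segment while blocking attention across segments (illustrative values, not the function's code):

import torch

# two length-2 sequences packed into one row; 1 and 2 are segment indices
attention_mask_with_indices = torch.tensor([[1, 1, 2, 2]])

# expected (1, 1, 4, 4) mask: position i may attend to position j
# iff both share the same index and j <= i:
# [[1, 0, 0, 0],
#  [1, 1, 0, 0],
#  [0, 0, 1, 0],
#  [0, 0, 1, 1]]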
@@ -62,7 +68,42 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
@dataclass
-class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
+class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
+    r"""
+    Data collator that supports VLMs.
+
+    Features should contain input_ids, attention_mask, labels and images.
+    """
+
+    template: Optional["Template"] = None
+    processor: Optional["ProcessorMixin"] = None
+
+    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
+        batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens = [], [], [], [], []
+        for feature in features:
+            images = feature.pop("images", None) or []
+            videos = feature.pop("videos", None) or []
+            batch_images.extend(images)
+            batch_videos.extend(videos)
+            batch_imglens.append(len(images))
+            batch_vidlens.append(len(videos))
+            batch_seqlens.append(len(feature["input_ids"]))
+
+        mm_inputs = self.template.mm_plugin.get_mm_inputs(
+            batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens, self.processor
+        )
+        if "token_type_ids" in mm_inputs:
+            token_type_ids = mm_inputs.pop("token_type_ids")
+            for i, feature in enumerate(features):
+                feature["token_type_ids"] = token_type_ids[i]
+
+        features: Dict[str, "torch.Tensor"] = super().__call__(features)
+        features.update(mm_inputs)
+        return features
+
+
+@dataclass
+class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
    r"""
    Data collator for 4d attention mask.
    """
@@ -80,7 +121,7 @@ class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
@dataclass
-class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
+class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
    r"""
    Data collator for pairwise data.
    """

@@ -99,20 +140,16 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
                "input_ids": feature["{}_input_ids".format(key)],
                "attention_mask": feature["{}_attention_mask".format(key)],
                "labels": feature["{}_labels".format(key)],
+                "images": feature["images"],
+                "videos": feature["videos"],
            }
-            if "pixel_values" in feature:
-                target_feature["pixel_values"] = feature["pixel_values"]
-
-            if "{}_token_type_ids".format(key) in feature:
-                target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
-
            concatenated_features.append(target_feature)

        return super().__call__(concatenated_features)

@dataclass
-class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
+class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
    r"""
    Data collator for KTO data.
    """

@@ -126,19 +163,16 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
                "input_ids": feature["input_ids"],
                "attention_mask": feature["attention_mask"],
                "labels": feature["labels"],
+                "images": feature["images"],
+                "videos": feature["videos"],
            }
            kl_feature = {
                "input_ids": feature["kl_input_ids"],
                "attention_mask": feature["kl_attention_mask"],
                "labels": feature["kl_labels"],
+                "images": feature["images"],
+                "videos": feature["videos"],
            }
-            if "pixel_values" in feature:
-                target_feature["pixel_values"] = feature["pixel_values"]
-
-            if "token_type_ids" in feature:
-                target_feature["token_type_ids"] = feature["token_type_ids"]
-                kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
-
            target_features.append(target_feature)
            kl_features.append(kl_feature)
            kto_tags.append(feature["kto_tags"])

@@ -148,7 +182,7 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
        batch["kl_input_ids"] = kl_batch["input_ids"]
        batch["kl_attention_mask"] = kl_batch["attention_mask"]
        batch["kl_labels"] = kl_batch["labels"]
-        if "token_type_ids" in batch:
+        if "token_type_ids" in kl_batch:
            batch["kl_token_type_ids"] = kl_batch["token_type_ids"]

        batch["kto_tags"] = torch.tensor(kto_tags)
...@@ -49,6 +49,9 @@ class DatasetModule(TypedDict): ...@@ -49,6 +49,9 @@ class DatasetModule(TypedDict):
def merge_dataset( def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
r"""
Merges multiple datasets into a unified dataset.

"""
if len(all_datasets) == 1: if len(all_datasets) == 1:
return all_datasets[0] return all_datasets[0]
elif data_args.mix_strategy == "concat": elif data_args.mix_strategy == "concat":
...@@ -67,14 +70,16 @@ def merge_dataset( ...@@ -67,14 +70,16 @@ def merge_dataset(
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted", stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
) )
else: else:
raise ValueError("Unknown mixing strategy.") raise ValueError("Unknown mixing strategy: {}.".format(data_args.mix_strategy))
def split_dataset( def split_dataset(
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
) -> "DatasetDict": ) -> "DatasetDict":
r""" r"""
Splits the dataset and returns a dataset dict containing train set (required) and validation set (optional). Splits the dataset and returns a dataset dict containing train set and validation set.
Supports both map dataset and iterable dataset.
""" """
if data_args.streaming: if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed) dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
......
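A hedged sketch of the mixing-strategy dispatch that merge_dataset implements, using the datasets library directly; the toy datasets, probabilities, and seed are illustrative, not project defaults:

```python
from datasets import Dataset, concatenate_datasets, interleave_datasets

ds_a = Dataset.from_dict({"text": ["a1", "a2", "a3"]})
ds_b = Dataset.from_dict({"text": ["b1"]})

mix_strategy = "interleave_under"  # or "concat", "interleave_over"
if mix_strategy == "concat":
    merged = concatenate_datasets([ds_a, ds_b])
elif mix_strategy.startswith("interleave"):
    merged = interleave_datasets(
        [ds_a, ds_b],
        probabilities=[0.5, 0.5],
        seed=42,
        # "under" stops at the smallest dataset, "over" oversamples the rest
        stopping_strategy="first_exhausted" if mix_strategy.endswith("under") else "all_exhausted",
    )
else:
    raise ValueError("Unknown mixing strategy: {}.".format(mix_strategy))
```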
...@@ -16,21 +16,36 @@ import json ...@@ -16,21 +16,36 @@ import json
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Literal, Optional, Tuple, Union from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing_extensions import override
from .data_utils import SLOTS from .data_utils import SLOTS
from .tool_utils import DefaultToolUtils, GLM4ToolUtils from .tool_utils import get_tool_utils
if TYPE_CHECKING:
from .tool_utils import FunctionCall
@dataclass @dataclass
class Formatter(ABC): class Formatter(ABC):
slots: SLOTS = field(default_factory=list) slots: SLOTS = field(default_factory=list)
tool_format: Optional[Literal["default", "glm4"]] = None tool_format: Optional[str] = None
@abstractmethod @abstractmethod
def apply(self, **kwargs) -> SLOTS: ... def apply(self, **kwargs) -> SLOTS:
r"""
Forms a list of slots according to the inputs to encode.
"""
...
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extracts a list of tuples from the response message if using tools.
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]: Each tuple consists of function name and function arguments.
"""
raise NotImplementedError raise NotImplementedError
...@@ -45,6 +60,7 @@ class EmptyFormatter(Formatter): ...@@ -45,6 +60,7 @@ class EmptyFormatter(Formatter):
if has_placeholder: if has_placeholder:
raise ValueError("Empty formatter should not contain any placeholder.") raise ValueError("Empty formatter should not contain any placeholder.")
@override
def apply(self, **kwargs) -> SLOTS: def apply(self, **kwargs) -> SLOTS:
return self.slots return self.slots
...@@ -60,6 +76,7 @@ class StringFormatter(Formatter): ...@@ -60,6 +76,7 @@ class StringFormatter(Formatter):
if not has_placeholder: if not has_placeholder:
raise ValueError("A placeholder is required in the string formatter.") raise ValueError("A placeholder is required in the string formatter.")
@override
def apply(self, **kwargs) -> SLOTS: def apply(self, **kwargs) -> SLOTS:
elements = [] elements = []
for slot in self.slots: for slot in self.slots:
...@@ -81,13 +98,9 @@ class StringFormatter(Formatter): ...@@ -81,13 +98,9 @@ class StringFormatter(Formatter):
@dataclass @dataclass
class FunctionFormatter(Formatter): class FunctionFormatter(Formatter):
def __post_init__(self): def __post_init__(self):
if self.tool_format == "default": self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
self.slots = DefaultToolUtils.get_function_slots() + self.slots
elif self.tool_format == "glm4":
self.slots = GLM4ToolUtils.get_function_slots() + self.slots
else:
raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
@override
def apply(self, **kwargs) -> SLOTS: def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content") content = kwargs.pop("content")
functions: List[Tuple[str, str]] = [] functions: List[Tuple[str, str]] = []
...@@ -100,7 +113,7 @@ class FunctionFormatter(Formatter): ...@@ -100,7 +113,7 @@ class FunctionFormatter(Formatter):
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))) functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
except json.JSONDecodeError: except json.JSONDecodeError:
functions = [] raise RuntimeError("Invalid JSON format in function message: {}".format(str([content]))) # flat string
elements = [] elements = []
for name, arguments in functions: for name, arguments in functions:
...@@ -119,22 +132,17 @@ class FunctionFormatter(Formatter): ...@@ -119,22 +132,17 @@ class FunctionFormatter(Formatter):
@dataclass @dataclass
class ToolFormatter(Formatter): class ToolFormatter(Formatter):
def __post_init__(self): def __post_init__(self):
if self.tool_format == "default": self.tool_utils = get_tool_utils(self.tool_format)
self._tool_formatter = DefaultToolUtils.tool_formatter
self._tool_extractor = DefaultToolUtils.tool_extractor
elif self.tool_format == "glm4":
self._tool_formatter = GLM4ToolUtils.tool_formatter
self._tool_extractor = GLM4ToolUtils.tool_extractor
else:
raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
@override
def apply(self, **kwargs) -> SLOTS: def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content") content = kwargs.pop("content")
try: try:
tools = json.loads(content) tools = json.loads(content)
return [self._tool_formatter(tools) if len(tools) != 0 else ""] return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
except json.JSONDecodeError: except json.JSONDecodeError:
return [""] raise RuntimeError("Invalid JSON format in tool description: {}".format(str([content]))) # flat string
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]: @override
return self._tool_extractor(content) def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
return self.tool_utils.tool_extractor(content)
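The formatter hunks above replace silent fallbacks with loud failures on malformed JSON. A standalone toy that mirrors that behavior change (parse_function_calls is a hypothetical helper, not the project's Formatter classes):

```python
import json
from typing import List, Tuple


def parse_function_calls(content: str) -> List[Tuple[str, str]]:
    try:
        tool_calls = json.loads(content)
        if not isinstance(tool_calls, list):
            tool_calls = [tool_calls]
        return [
            (call["name"], json.dumps(call["arguments"], ensure_ascii=False))
            for call in tool_calls
        ]
    except json.JSONDecodeError:
        # old behavior: return []  (errors were swallowed)
        raise RuntimeError("Invalid JSON format in function message: {}".format(str([content])))


print(parse_function_calls('{"name": "get_weather", "arguments": {"city": "Paris"}}'))
```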
...@@ -27,7 +27,6 @@ from .aligner import align_dataset ...@@ -27,7 +27,6 @@ from .aligner import align_dataset
from .data_utils import merge_dataset, split_dataset from .data_utils import merge_dataset, split_dataset
from .parser import get_dataset_list from .parser import get_dataset_list
from .preprocess import get_preprocess_and_print_func from .preprocess import get_preprocess_and_print_func
from .template import get_template_and_fix_tokenizer
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -49,6 +48,9 @@ def _load_single_dataset( ...@@ -49,6 +48,9 @@ def _load_single_dataset(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
r"""
Loads a single dataset and aligns it to the standard format.
"""
logger.info("Loading dataset {}...".format(dataset_attr)) logger.info("Loading dataset {}...".format(dataset_attr))
data_path, data_name, data_dir, data_files = None, None, None, None data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub"]: if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
...@@ -118,7 +120,7 @@ def _load_single_dataset( ...@@ -118,7 +120,7 @@ def _load_single_dataset(
if dataset_attr.num_samples is not None and not data_args.streaming: if dataset_attr.num_samples is not None and not data_args.streaming:
target_num = dataset_attr.num_samples target_num = dataset_attr.num_samples
indexes = np.random.permutation(len(dataset))[:target_num] indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included
target_num -= len(indexes) target_num -= len(indexes)
if target_num > 0: if target_num > 0:
expand_indexes = np.random.choice(len(dataset), target_num) expand_indexes = np.random.choice(len(dataset), target_num)
...@@ -142,6 +144,9 @@ def _get_merged_dataset( ...@@ -142,6 +144,9 @@ def _get_merged_dataset(
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"], stage: Literal["pt", "sft", "rm", "ppo", "kto"],
) -> Optional[Union["Dataset", "IterableDataset"]]: ) -> Optional[Union["Dataset", "IterableDataset"]]:
r"""
Gets the merged datasets in the standard format.
"""
if dataset_names is None: if dataset_names is None:
return None return None
...@@ -165,6 +170,9 @@ def _get_preprocessed_dataset( ...@@ -165,6 +170,9 @@ def _get_preprocessed_dataset(
processor: Optional["ProcessorMixin"] = None, processor: Optional["ProcessorMixin"] = None,
is_eval: bool = False, is_eval: bool = False,
) -> Optional[Union["Dataset", "IterableDataset"]]: ) -> Optional[Union["Dataset", "IterableDataset"]]:
r"""
Preprocesses the dataset, including format checking and tokenization.
"""
if dataset is None: if dataset is None:
return None return None
...@@ -180,7 +188,13 @@ def _get_preprocessed_dataset( ...@@ -180,7 +188,13 @@ def _get_preprocessed_dataset(
desc="Running tokenizer on dataset", desc="Running tokenizer on dataset",
) )
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs) dataset = dataset.map(
preprocess_func,
batched=True,
batch_size=data_args.preprocessing_batch_size,
remove_columns=column_names,
**kwargs,
)
if training_args.should_log: if training_args.should_log:
try: try:
...@@ -196,6 +210,7 @@ def _get_preprocessed_dataset( ...@@ -196,6 +210,7 @@ def _get_preprocessed_dataset(
def get_dataset( def get_dataset(
template: "Template",
model_args: "ModelArguments", model_args: "ModelArguments",
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
...@@ -203,10 +218,9 @@ def get_dataset( ...@@ -203,10 +218,9 @@ def get_dataset(
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None, processor: Optional["ProcessorMixin"] = None,
) -> "DatasetModule": ) -> "DatasetModule":
template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format) r"""
if data_args.train_on_prompt and template.efficient_eos: Gets the train dataset and optionally gets the evaluation dataset.
raise ValueError("Current template does not support `train_on_prompt`.") """
# Load tokenized dataset # Load tokenized dataset
if data_args.tokenized_path is not None: if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path): if has_tokenized_data(data_args.tokenized_path):
...@@ -217,6 +231,7 @@ def get_dataset( ...@@ -217,6 +231,7 @@ def get_dataset(
dataset_module: Dict[str, "Dataset"] = {} dataset_module: Dict[str, "Dataset"] = {}
if "train" in dataset_dict: if "train" in dataset_dict:
dataset_module["train_dataset"] = dataset_dict["train"] dataset_module["train_dataset"] = dataset_dict["train"]
if "validation" in dataset_dict: if "validation" in dataset_dict:
dataset_module["eval_dataset"] = dataset_dict["validation"] dataset_module["eval_dataset"] = dataset_dict["validation"]
...@@ -270,6 +285,7 @@ def get_dataset( ...@@ -270,6 +285,7 @@ def get_dataset(
dataset_module = {} dataset_module = {}
if "train" in dataset_dict: if "train" in dataset_dict:
dataset_module["train_dataset"] = dataset_dict["train"] dataset_module["train_dataset"] = dataset_dict["train"]
if "validation" in dataset_dict: if "validation" in dataset_dict:
dataset_module["eval_dataset"] = dataset_dict["validation"] dataset_module["eval_dataset"] = dataset_dict["validation"]
......
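The loader now exposes the map batch size via data_args.preprocessing_batch_size. A minimal sketch of the equivalent call against a toy dataset (preprocess_func here is a stand-in, not the project's tokenizing function):

```python
from datasets import Dataset

dataset = Dataset.from_dict({"text": ["hello world", "foo bar"]})


def preprocess_func(batch):
    # stand-in for the real batched tokenization
    return {"input_ids": [[ord(c) for c in t] for t in batch["text"]]}


dataset = dataset.map(
    preprocess_func,
    batched=True,
    batch_size=1000,          # now configurable via data_args.preprocessing_batch_size
    remove_columns=["text"],  # drop raw columns after tokenization
)
print(dataset[0]["input_ids"][:5])
```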
...@@ -43,6 +43,7 @@ class DatasetAttr: ...@@ -43,6 +43,7 @@ class DatasetAttr:
system: Optional[str] = None system: Optional[str] = None
tools: Optional[str] = None tools: Optional[str] = None
images: Optional[str] = None images: Optional[str] = None
videos: Optional[str] = None
# rlhf columns # rlhf columns
chosen: Optional[str] = None chosen: Optional[str] = None
rejected: Optional[str] = None rejected: Optional[str] = None
...@@ -126,7 +127,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) - ...@@ -126,7 +127,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
dataset_attr.set_attr("num_samples", dataset_info[name]) dataset_attr.set_attr("num_samples", dataset_info[name])
if "columns" in dataset_info[name]: if "columns" in dataset_info[name]:
column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"] column_names = ["system", "tools", "images", "videos", "chosen", "rejected", "kto_tag"]
if dataset_attr.formatting == "alpaca": if dataset_attr.formatting == "alpaca":
column_names.extend(["prompt", "query", "response", "history"]) column_names.extend(["prompt", "query", "response", "history"])
else: else:
......
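With the new videos attribute, a dataset description can map a video field the same way it already maps images. A hypothetical entry (dataset and file names are illustrative, written as a Python dict for brevity):

```python
# Hypothetical dataset_info entry showing the new "videos" column.
dataset_info = {
    "my_mm_dataset": {
        "file_name": "my_mm_dataset.json",
        "formatting": "sharegpt",
        "columns": {
            "messages": "conversations",
            "images": "images",
            "videos": "videos",  # newly supported media column
        },
    }
}
```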
...@@ -50,7 +50,7 @@ def get_preprocess_and_print_func( ...@@ -50,7 +50,7 @@ def get_preprocess_and_print_func(
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer) print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not do_generate: elif stage == "sft" and not do_generate:
if data_args.packing: if data_args.packing:
if data_args.neat_packing: if data_args.neat_packing: # hack datasets to have int32 attention mask
from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
def __init__(self, data, **kwargs): def __init__(self, data, **kwargs):
...@@ -67,6 +67,7 @@ def get_preprocess_and_print_func( ...@@ -67,6 +67,7 @@ def get_preprocess_and_print_func(
preprocess_packed_supervised_dataset, preprocess_packed_supervised_dataset,
template=template, template=template,
tokenizer=tokenizer, tokenizer=tokenizer,
processor=processor,
data_args=data_args, data_args=data_args,
) )
else: else:
......
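The neat_packing comment refers to attention masks that carry per-segment ids rather than 0/1 flags, so attention can stay block-diagonal across packed sequences; such masks no longer fit datasets' default type optimization, hence the int32 patch on OptimizedTypedSequence. A toy illustration of such a mask (hypothetical helper, not the project's code):

```python
from typing import List


def neat_attention_mask(segment_lengths: List[int]) -> List[int]:
    # segment ids 1, 2, 3, ... mark which packed sequence each token belongs to
    mask: List[int] = []
    for seg_id, length in enumerate(segment_lengths, start=1):
        mask.extend([seg_id] * length)
    return mask


print(neat_attention_mask([2, 3, 1]))  # [1, 1, 2, 2, 2, 3]
```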
...@@ -12,17 +12,19 @@ ...@@ -12,17 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen from .processor_utils import infer_seqlen
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments from ...hparams import DataArguments
from ..mm_plugin import ImageInput, VideoInput
from ..template import Template from ..template import Template
...@@ -35,14 +37,13 @@ def _encode_feedback_example( ...@@ -35,14 +37,13 @@ def _encode_feedback_example(
kl_response: Sequence[Dict[str, str]], kl_response: Sequence[Dict[str, str]],
system: Optional[str], system: Optional[str],
tools: Optional[str], tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
template: "Template", template: "Template",
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
cutoff_len: int, cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int], bool]: ) -> Tuple[List[int], List[int], List[int], List[int], bool]:
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
if response[0]["content"]: # desired example if response[0]["content"]: # desired example
kto_tag = True kto_tag = True
messages = prompt + [response[0]] messages = prompt + [response[0]]
...@@ -55,6 +56,8 @@ def _encode_feedback_example( ...@@ -55,6 +56,8 @@ def _encode_feedback_example(
else: else:
kl_messages = prompt + [kl_response[1]] kl_messages = prompt + [kl_response[1]]
messages = template.mm_plugin.process_messages(messages, images, videos, processor)
kl_messages = template.mm_plugin.process_messages(kl_messages, images, videos, processor)
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools) prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools) kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
...@@ -62,10 +65,8 @@ def _encode_feedback_example( ...@@ -62,10 +65,8 @@ def _encode_feedback_example(
response_ids += [tokenizer.eos_token_id] response_ids += [tokenizer.eos_token_id]
kl_response_ids += [tokenizer.eos_token_id] kl_response_ids += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, images, videos, tokenizer, processor)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
kl_prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + kl_prompt_ids
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len) source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
prompt_ids = prompt_ids[:source_len] prompt_ids = prompt_ids[:source_len]
...@@ -78,7 +79,6 @@ def _encode_feedback_example( ...@@ -78,7 +79,6 @@ def _encode_feedback_example(
labels = [IGNORE_INDEX] * source_len + response_ids labels = [IGNORE_INDEX] * source_len + response_ids
kl_input_ids = kl_prompt_ids + kl_response_ids kl_input_ids = kl_prompt_ids + kl_response_ids
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
return input_ids, labels, kl_input_ids, kl_labels, kto_tag return input_ids, labels, kl_input_ids, kl_labels, kto_tag
...@@ -88,35 +88,23 @@ def preprocess_feedback_dataset( ...@@ -88,35 +88,23 @@ def preprocess_feedback_dataset(
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
data_args: "DataArguments", data_args: "DataArguments",
) -> Dict[str, List[List[int]]]: ) -> Dict[str, List[Any]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["response"][::-1] kl_response = examples["_response"][::-1]
model_inputs = { model_inputs = defaultdict(list)
"input_ids": [], for i in range(len(examples["_prompt"])):
"attention_mask": [], if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
"labels": [], logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
"kl_input_ids": [],
"kl_attention_mask": [],
"kl_labels": [],
"kto_tags": [],
}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
model_inputs["kl_token_type_ids"] = []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue continue
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example( input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
prompt=examples["prompt"][i], prompt=examples["_prompt"][i],
response=examples["response"][i], response=examples["_response"][i],
kl_response=kl_response[i], kl_response=kl_response[i],
system=examples["system"][i], system=examples["_system"][i],
tools=examples["tools"][i], tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template, template=template,
tokenizer=tokenizer, tokenizer=tokenizer,
processor=processor, processor=processor,
...@@ -129,11 +117,8 @@ def preprocess_feedback_dataset( ...@@ -129,11 +117,8 @@ def preprocess_feedback_dataset(
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids)) model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels) model_inputs["kl_labels"].append(kl_labels)
model_inputs["kto_tags"].append(kto_tag) model_inputs["kto_tags"].append(kto_tag)
if processor is not None: model_inputs["images"].append(examples["_images"][i])
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor)) model_inputs["videos"].append(examples["_videos"][i])
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor))
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag]) desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
......
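The KL inputs in the feedback (KTO) preprocessor come from reversing the response column, so each prompt is paired with an unrelated answer. In miniature:

```python
# Reversing the response column yields the "unrelated" pairs that
# KTO's KL term is estimated from.
prompts = ["p0", "p1", "p2", "p3"]
responses = ["r0", "r1", "r2", "r3"]

kl_responses = responses[::-1]
for prompt, response, kl_response in zip(prompts, responses, kl_responses):
    print(prompt, "->", response, "| KL pair:", kl_response)
# p0 -> r0 | KL pair: r3
# p1 -> r1 | KL pair: r2
# ...
```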
...@@ -12,17 +12,19 @@ ...@@ -12,17 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen from .processor_utils import infer_seqlen
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments from ...hparams import DataArguments
from ..mm_plugin import ImageInput, VideoInput
from ..template import Template from ..template import Template
...@@ -34,16 +36,15 @@ def _encode_pairwise_example( ...@@ -34,16 +36,15 @@ def _encode_pairwise_example(
response: Sequence[Dict[str, str]], response: Sequence[Dict[str, str]],
system: Optional[str], system: Optional[str],
tools: Optional[str], tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
template: "Template", template: "Template",
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
cutoff_len: int, cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int]]: ) -> Tuple[List[int], List[int], List[int], List[int]]:
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, videos, processor)
prompt[0]["content"] = template.image_token + prompt[0]["content"] rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, videos, processor)
chosen_messages = prompt + [response[0]]
rejected_messages = prompt + [response[1]]
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools) prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
_, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools) _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
...@@ -51,10 +52,7 @@ def _encode_pairwise_example( ...@@ -51,10 +52,7 @@ def _encode_pairwise_example(
chosen_ids += [tokenizer.eos_token_id] chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id] rejected_ids += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
# consider the response is more important # consider the response is more important
source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len) source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
prompt_ids = prompt_ids[:source_len] prompt_ids = prompt_ids[:source_len]
...@@ -65,7 +63,6 @@ def _encode_pairwise_example( ...@@ -65,7 +63,6 @@ def _encode_pairwise_example(
chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
rejected_input_ids = prompt_ids + rejected_ids rejected_input_ids = prompt_ids + rejected_ids
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
...@@ -75,32 +72,21 @@ def preprocess_pairwise_dataset( ...@@ -75,32 +72,21 @@ def preprocess_pairwise_dataset(
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
data_args: "DataArguments", data_args: "DataArguments",
) -> Dict[str, List[List[int]]]: ) -> Dict[str, List[Any]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>` # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = { model_inputs = defaultdict(list)
"chosen_input_ids": [], for i in range(len(examples["_prompt"])):
"chosen_attention_mask": [], if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
"chosen_labels": [], logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
"rejected_input_ids": [],
"rejected_attention_mask": [],
"rejected_labels": [],
}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["chosen_token_type_ids"] = []
model_inputs["rejected_token_type_ids"] = []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue continue
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example( chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
prompt=examples["prompt"][i], prompt=examples["_prompt"][i],
response=examples["response"][i], response=examples["_response"][i],
system=examples["system"][i], system=examples["_system"][i],
tools=examples["tools"][i], tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template, template=template,
tokenizer=tokenizer, tokenizer=tokenizer,
processor=processor, processor=processor,
...@@ -112,15 +98,8 @@ def preprocess_pairwise_dataset( ...@@ -112,15 +98,8 @@ def preprocess_pairwise_dataset(
model_inputs["rejected_input_ids"].append(rejected_input_ids) model_inputs["rejected_input_ids"].append(rejected_input_ids)
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids)) model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
model_inputs["rejected_labels"].append(rejected_labels) model_inputs["rejected_labels"].append(rejected_labels)
if processor is not None: model_inputs["images"].append(examples["_images"][i])
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor)) model_inputs["videos"].append(examples["_videos"][i])
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["chosen_token_type_ids"].append(
get_paligemma_token_type_ids(len(chosen_input_ids), processor)
)
model_inputs["rejected_token_type_ids"].append(
get_paligemma_token_type_ids(len(rejected_input_ids), processor)
)
return model_inputs return model_inputs
......
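A minimal sketch of how the pairwise encoder assembles labels from a shared prompt, assuming IGNORE_INDEX = -100 as in transformers (token ids are toy values):

```python
IGNORE_INDEX = -100

prompt_ids = [1, 2, 3]
chosen_ids = [10, 11, 2]   # ends with eos
rejected_ids = [20, 2]

# prompt tokens are masked out of the loss for both branches
chosen_input_ids = prompt_ids + chosen_ids
chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids

rejected_input_ids = prompt_ids + rejected_ids
rejected_labels = [IGNORE_INDEX] * len(prompt_ids) + rejected_ids

print(chosen_labels)    # [-100, -100, -100, 10, 11, 2]
print(rejected_labels)  # [-100, -100, -100, 20, 2]
```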
...@@ -27,16 +27,16 @@ if TYPE_CHECKING: ...@@ -27,16 +27,16 @@ if TYPE_CHECKING:
def preprocess_pretrain_dataset( def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments" examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]: ) -> Dict[str, List[Any]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]] text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
if not data_args.packing: if not data_args.packing:
if data_args.template == "gemma": if data_args.template == "gemma":
text_examples = [tokenizer.bos_token + example for example in text_examples] text_examples = [tokenizer.bos_token + example for example in text_examples]
result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len, truncation=True) result = tokenizer(text_examples, add_special_tokens=False, truncation=True, max_length=data_args.cutoff_len)
else: else:
tokenized_examples = tokenizer(text_examples, add_special_tokens=False) tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
......
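For the packing branch, tokenized examples are concatenated and re-chunked into fixed-length blocks, as in this toy reduction (cutoff_len and the token ids are illustrative):

```python
from itertools import chain

tokenized = {"input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9]]}
cutoff_len = 4

# concatenate all examples, then cut into cutoff_len-sized blocks,
# dropping the incomplete tail
concatenated = {k: list(chain(*tokenized[k])) for k in tokenized}
total_len = (len(concatenated["input_ids"]) // cutoff_len) * cutoff_len
blocks = {
    k: [v[i : i + cutoff_len] for i in range(0, total_len, cutoff_len)]
    for k, v in concatenated.items()
}
print(blocks["input_ids"])  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```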