Commit 2778a3d0 authored by luopl's avatar luopl
Browse files

updata to v0.9.1_stable

parent e92143e3
......@@ -20,7 +20,7 @@ from setuptools import find_packages, setup
def get_version() -> str:
with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
with open(os.path.join("src", "llamafactory", "extras", "env.py"), encoding="utf-8") as f:
file_content = f.read()
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
(version,) = re.findall(pattern, file_content)
......@@ -28,7 +28,7 @@ def get_version() -> str:
def get_requires() -> List[str]:
with open("requirements.txt", "r", encoding="utf-8") as f:
with open("requirements.txt", encoding="utf-8") as f:
file_content = f.read()
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
return lines
......@@ -54,13 +54,14 @@ extra_require = {
"gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
"awq": ["autoawq"],
"aqlm": ["aqlm[gpu]>=1.1.0"],
"vllm": ["vllm>=0.4.3,<=0.6.2"],
"vllm": ["vllm>=0.4.3,<0.6.4"],
"galore": ["galore-torch"],
"badam": ["badam>=1.2.1"],
"adam-mini": ["adam-mini"],
"qwen": ["transformers_stream_generator"],
"modelscope": ["modelscope"],
"dev": ["ruff", "pytest"],
"openmind": ["openmind"],
"dev": ["pre-commit", "ruff", "pytest"],
}
......@@ -71,7 +72,7 @@ def main():
author="hiyouga",
author_email="hiyouga" "@" "buaa.edu.cn",
description="Easy-to-use LLM fine-tuning framework",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
license="Apache 2.0 License",
......
......@@ -23,9 +23,9 @@ from llamafactory.chat import ChatModel
def main():
chat_model = ChatModel()
app = create_app(chat_model)
api_host = os.environ.get("API_HOST", "0.0.0.0")
api_port = int(os.environ.get("API_PORT", "8000"))
print("Visit http://localhost:{}/docs for API document.".format(api_port))
api_host = os.getenv("API_HOST", "0.0.0.0")
api_port = int(os.getenv("API_PORT", "8000"))
print(f"Visit http://localhost:{api_port}/docs for API document.")
uvicorn.run(app, host=api_host, port=api_port)
......
......@@ -20,17 +20,17 @@ Level:
Dependency graph:
main:
transformers>=4.41.2,<=4.45.2
datasets>=2.16.0,<=2.21.0
accelerate>=0.30.1,<=0.34.2
transformers>=4.41.2,<=4.46.1
datasets>=2.16.0,<=3.1.0
accelerate>=0.34.0,<=1.0.1
peft>=0.11.1,<=0.12.0
trl>=0.8.6,<=0.9.6
attention:
transformers>=4.42.4 (gemma+fa2)
longlora:
transformers>=4.41.2,<=4.45.2
transformers>=4.41.2,<=4.46.1
packing:
transformers>=4.41.2,<=4.45.2
transformers>=4.41.2,<=4.46.1
Disable version checking: DISABLE_VERSION_CHECK=1
Enable VRAM recording: RECORD_VRAM=1
......@@ -38,6 +38,7 @@ Force check imports: FORCE_CHECK_IMPORTS=1
Force using torchrun: FORCE_TORCHRUN=1
Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
Use modelscope: USE_MODELSCOPE_HUB=1
Use openmind: USE_OPENMIND_HUB=1
"""
from .extras.env import VERSION
......
......@@ -68,7 +68,7 @@ async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU mem
def create_app(chat_model: "ChatModel") -> "FastAPI":
root_path = os.environ.get("FASTAPI_ROOT_PATH", "")
root_path = os.getenv("FASTAPI_ROOT_PATH", "")
app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
app.add_middleware(
CORSMiddleware,
......@@ -77,7 +77,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
allow_methods=["*"],
allow_headers=["*"],
)
api_key = os.environ.get("API_KEY", None)
api_key = os.getenv("API_KEY")
security = HTTPBearer(auto_error=False)
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
......@@ -91,7 +91,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
dependencies=[Depends(verify_api_key)],
)
async def list_models():
model_card = ModelCard(id=os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo"))
model_card = ModelCard(id=os.getenv("API_MODEL_NAME", "gpt-3.5-turbo"))
return ModelList(data=[model_card])
@app.post(
......@@ -128,7 +128,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
def run_api() -> None:
chat_model = ChatModel()
app = create_app(chat_model)
api_host = os.environ.get("API_HOST", "0.0.0.0")
api_port = int(os.environ.get("API_PORT", "8000"))
print("Visit http://localhost:{}/docs for API document.".format(api_port))
api_host = os.getenv("API_HOST", "0.0.0.0")
api_port = int(os.getenv("API_PORT", "8000"))
print(f"Visit http://localhost:{api_port}/docs for API document.")
uvicorn.run(app, host=api_host, port=api_port)
......@@ -21,7 +21,7 @@ import uuid
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
from ..data import Role as DataRole
from ..extras.logging import get_logger
from ..extras import logging
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
from .common import dictify, jsonify
from .protocol import (
......@@ -57,7 +57,7 @@ if TYPE_CHECKING:
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
ROLE_MAPPING = {
Role.USER: DataRole.USER.value,
Role.ASSISTANT: DataRole.ASSISTANT.value,
......@@ -69,8 +69,8 @@ ROLE_MAPPING = {
def _process_request(
request: "ChatCompletionRequest",
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
......@@ -84,7 +84,7 @@ def _process_request(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
input_messages = []
image = None
images = []
for i, message in enumerate(request.messages):
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
......@@ -111,7 +111,7 @@ def _process_request(
else: # web uri
image_stream = requests.get(image_url, stream=True).raw
image = Image.open(image_stream).convert("RGB")
images.append(Image.open(image_stream).convert("RGB"))
else:
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
......@@ -124,7 +124,7 @@ def _process_request(
else:
tools = None
return input_messages, system, tools, image
return input_messages, system, tools, images or None
def _create_stream_chat_completion_chunk(
......@@ -142,13 +142,13 @@ def _create_stream_chat_completion_chunk(
async def create_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> "ChatCompletionResponse":
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
input_messages, system, tools, image = _process_request(request)
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
input_messages, system, tools, images = _process_request(request)
responses = await chat_model.achat(
input_messages,
system,
tools,
image,
images,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
......@@ -169,7 +169,7 @@ async def create_chat_completion_response(
tool_calls = []
for tool in result:
function = Function(name=tool[0], arguments=tool[1])
tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
finish_reason = Finish.TOOL
......@@ -193,8 +193,8 @@ async def create_chat_completion_response(
async def create_stream_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> AsyncGenerator[str, None]:
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
input_messages, system, tools, image = _process_request(request)
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
input_messages, system, tools, images = _process_request(request)
if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
......@@ -208,7 +208,7 @@ async def create_stream_chat_completion_response(
input_messages,
system,
tools,
image,
images,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
......@@ -229,7 +229,7 @@ async def create_stream_chat_completion_response(
async def create_score_evaluation_response(
request: "ScoreEvaluationRequest", chat_model: "ChatModel"
) -> "ScoreEvaluationResponse":
score_id = "scoreval-{}".format(uuid.uuid4().hex)
score_id = f"scoreval-{uuid.uuid4().hex}"
if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
......
......@@ -66,8 +66,8 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs,
) -> List["Response"]:
r"""
......@@ -81,8 +81,8 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
r"""
......
......@@ -53,7 +53,7 @@ class ChatModel:
elif model_args.infer_backend == "vllm":
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
else:
raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
self._loop = asyncio.new_event_loop()
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
......@@ -64,15 +64,15 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs,
) -> List["Response"]:
r"""
Gets a list of responses of the chat model.
"""
task = asyncio.run_coroutine_threadsafe(
self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
self.achat(messages, system, tools, images, videos, **input_kwargs), self._loop
)
return task.result()
......@@ -81,28 +81,28 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs,
) -> List["Response"]:
r"""
Asynchronously gets a list of responses of the chat model.
"""
return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
return await self.engine.chat(messages, system, tools, images, videos, **input_kwargs)
def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs,
) -> Generator[str, None, None]:
r"""
Gets the response token-by-token of the chat model.
"""
generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
generator = self.astream_chat(messages, system, tools, images, videos, **input_kwargs)
while True:
try:
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
......@@ -115,14 +115,14 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
r"""
Asynchronously gets the response token-by-token of the chat model.
"""
async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
async for new_token in self.engine.stream_chat(messages, system, tools, images, videos, **input_kwargs):
yield new_token
def get_scores(
......
......@@ -23,8 +23,8 @@ from transformers import GenerationConfig, TextIteratorStreamer
from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.logging import get_logger
from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response
......@@ -39,7 +39,7 @@ if TYPE_CHECKING:
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
class HuggingfaceEngine(BaseEngine):
......@@ -63,11 +63,11 @@ class HuggingfaceEngine(BaseEngine):
try:
asyncio.get_event_loop()
except RuntimeError:
logger.warning("There is no current event loop, creating a new one.")
logger.warning_once("There is no current event loop, creating a new one.")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
@staticmethod
def _process_args(
......@@ -79,20 +79,20 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]:
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
if image is not None:
mm_input_dict.update({"images": [image], "imglens": [1]})
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
if images is not None:
mm_input_dict.update({"images": images, "imglens": [len(images)]})
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
if video is not None:
mm_input_dict.update({"videos": [video], "vidlens": [1]})
if VIDEO_PLACEHOLDER not in messages[0]["content"]:
messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]
if videos is not None:
mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
messages = template.mm_plugin.process_messages(
messages, mm_input_dict["images"], mm_input_dict["videos"], processor
......@@ -119,7 +119,7 @@ class HuggingfaceEngine(BaseEngine):
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
if stop is not None:
logger.warning("Stop parameter is not supported by the huggingface engine yet.")
logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")
generating_args = generating_args.copy()
generating_args.update(
......@@ -164,9 +164,13 @@ class HuggingfaceEngine(BaseEngine):
logits_processor=get_logits_processor(),
)
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
for key, value in mm_inputs.items():
value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value): # for pixtral inputs
value = torch.stack(value) # assume they have same sizes
elif not isinstance(value, torch.Tensor):
value = torch.tensor(value)
gen_kwargs[key] = value.to(model.device)
return gen_kwargs, prompt_length
......@@ -182,12 +186,22 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
model,
tokenizer,
processor,
template,
generating_args,
messages,
system,
tools,
images,
videos,
input_kwargs,
)
generate_output = model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:]
......@@ -218,12 +232,22 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args(
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
model,
tokenizer,
processor,
template,
generating_args,
messages,
system,
tools,
images,
videos,
input_kwargs,
)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer
......@@ -266,8 +290,8 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs,
) -> List["Response"]:
if not self.can_generate:
......@@ -283,8 +307,8 @@ class HuggingfaceEngine(BaseEngine):
messages,
system,
tools,
image,
video,
images,
videos,
input_kwargs,
)
async with self.semaphore:
......@@ -297,8 +321,8 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
if not self.can_generate:
......@@ -314,8 +338,8 @@ class HuggingfaceEngine(BaseEngine):
messages,
system,
tools,
image,
video,
images,
videos,
input_kwargs,
)
async with self.semaphore:
......
......@@ -18,8 +18,8 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List
from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.logging import get_logger
from ..extras.misc import get_device_count
from ..extras.packages import is_pillow_available, is_vllm_available
from ..model import load_config, load_tokenizer
......@@ -43,7 +43,7 @@ if TYPE_CHECKING:
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
class VllmEngine(BaseEngine):
......@@ -83,11 +83,13 @@ class VllmEngine(BaseEngine):
"enable_lora": model_args.adapter_name_or_path is not None,
"max_lora_rank": model_args.vllm_max_lora_rank,
}
if isinstance(model_args.vllm_config, dict):
engine_args.update(model_args.vllm_config)
if getattr(config, "is_yi_vl_derived_model", None):
import vllm.model_executor.models.llava
logger.info("Detected Yi-VL model, applying projector patch.")
logger.info_rank0("Detected Yi-VL model, applying projector patch.")
vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
......@@ -101,21 +103,28 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs,
) -> AsyncIterator["RequestOutput"]:
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
if image is not None:
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
request_id = f"chatcmpl-{uuid.uuid4().hex}"
if images is not None:
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}]
if self.template.mm_plugin.__class__.__name__ == "Qwen2vlPlugin": # temporary solution
image_str = f"<|vision_start|>{self.template.mm_plugin.image_token}<|vision_end|>"
else:
image_str = self.template.mm_plugin.image_token or ""
paired_messages = [
{"role": message["role"], "content": message["content"].replace(IMAGE_PLACEHOLDER, image_str)}
for message in messages
] + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
prompt_length = len(prompt_ids)
use_beam_search: bool = self.generating_args["num_beams"] > 1
temperature: Optional[float] = input_kwargs.pop("temperature", None)
top_p: Optional[float] = input_kwargs.pop("top_p", None)
top_k: Optional[float] = input_kwargs.pop("top_k", None)
......@@ -126,6 +135,9 @@ class VllmEngine(BaseEngine):
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
if length_penalty is not None:
logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
if "max_new_tokens" in self.generating_args:
max_tokens = self.generating_args["max_new_tokens"]
elif "max_length" in self.generating_args:
......@@ -149,27 +161,29 @@ class VllmEngine(BaseEngine):
temperature=temperature if temperature is not None else self.generating_args["temperature"],
top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
top_k=top_k if top_k is not None else self.generating_args["top_k"],
use_beam_search=use_beam_search,
length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
stop=stop,
stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
max_tokens=max_tokens,
skip_special_tokens=True,
)
if image is not None: # add image features
if not isinstance(image, (str, ImageObject)):
raise ValueError("Expected image input is a path or PIL.Image, but got {}.".format(type(image)))
if images is not None: # add image features
image_data = []
for image in images:
if not isinstance(image, (str, ImageObject)):
raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
if isinstance(image, str):
image = Image.open(image).convert("RGB")
if isinstance(image, str):
image = Image.open(image).convert("RGB")
image_data.append(image)
multi_modal_data = {"image": image}
multi_modal_data = {"image": image_data}
else:
multi_modal_data = None
result_generator = self.model.generate(
inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
{"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
sampling_params=sampling_params,
request_id=request_id,
lora_request=self.lora_request,
......@@ -182,12 +196,12 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs,
) -> List["Response"]:
final_output = None
generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
async for request_output in generator:
final_output = request_output
......@@ -210,12 +224,12 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
generated_text = ""
generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
async for result in generator:
delta_text = result.outputs[0].text[len(generated_text) :]
generated_text = result.outputs[0].text
......
......@@ -22,8 +22,8 @@ from . import launcher
from .api.app import run_api
from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .extras import logging
from .extras.env import VERSION, print_env
from .extras.logging import get_logger
from .extras.misc import get_device_count
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui
......@@ -47,7 +47,7 @@ USAGE = (
WELCOME = (
"-" * 58
+ "\n"
+ "| Welcome to LLaMA Factory, version {}".format(VERSION)
+ f"| Welcome to LLaMA Factory, version {VERSION}"
+ " " * (21 - len(VERSION))
+ "|\n|"
+ " " * 56
......@@ -56,7 +56,7 @@ WELCOME = (
+ "-" * 58
)
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
@unique
......@@ -86,25 +86,26 @@ def main():
elif command == Command.EXPORT:
export_model()
elif command == Command.TRAIN:
force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
if force_torchrun or get_device_count() > 1:
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
process = subprocess.run(
(
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
).format(
nnodes=os.environ.get("NNODES", "1"),
node_rank=os.environ.get("RANK", "0"),
nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
)
.format(
nnodes=os.getenv("NNODES", "1"),
node_rank=os.getenv("NODE_RANK", "0"),
nproc_per_node=os.getenv("NPROC_PER_NODE", str(get_device_count())),
master_addr=master_addr,
master_port=master_port,
file_name=launcher.__file__,
args=" ".join(sys.argv[1:]),
),
shell=True,
)
.split()
)
sys.exit(process.returncode)
else:
......@@ -118,4 +119,4 @@ def main():
elif command == Command.HELP:
print(USAGE)
else:
raise NotImplementedError("Unknown command: {}.".format(command))
raise NotImplementedError(f"Unknown command: {command}.")
......@@ -16,7 +16,7 @@ import os
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from ..extras.logging import get_logger
from ..extras import logging
from .data_utils import Role
......@@ -29,45 +29,51 @@ if TYPE_CHECKING:
from .parser import DatasetAttr
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def _convert_images(
images: Sequence["ImageInput"],
images: Union["ImageInput", Sequence["ImageInput"]],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Optional[List["ImageInput"]]:
r"""
Optionally concatenates image path to dataset dir when loading from local disk.
"""
if len(images) == 0:
if not isinstance(images, list):
images = [images]
elif len(images) == 0:
return None
else:
images = images[:]
images = images[:]
if dataset_attr.load_from in ["script", "file"]:
for i in range(len(images)):
if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, images[i])):
images[i] = os.path.join(data_args.dataset_dir, images[i])
if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.image_dir, images[i])):
images[i] = os.path.join(data_args.image_dir, images[i])
return images
def _convert_videos(
videos: Sequence["VideoInput"],
videos: Union["VideoInput", Sequence["VideoInput"]],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Optional[List["VideoInput"]]:
r"""
Optionally concatenates video path to dataset dir when loading from local disk.
"""
if len(videos) == 0:
if not isinstance(videos, list):
videos = [videos]
elif len(videos) == 0:
return None
else:
videos = videos[:]
videos = videos[:]
if dataset_attr.load_from in ["script", "file"]:
for i in range(len(videos)):
if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, videos[i])):
videos[i] = os.path.join(data_args.dataset_dir, videos[i])
if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.image_dir, videos[i])):
videos[i] = os.path.join(data_args.image_dir, videos[i])
return videos
......@@ -161,7 +167,7 @@ def convert_sharegpt(
broken_data = False
for turn_idx, message in enumerate(messages):
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
logger.warning("Invalid role tag in {}.".format(messages))
logger.warning_rank0(f"Invalid role tag in {messages}.")
broken_data = True
aligned_messages.append(
......@@ -171,7 +177,7 @@ def convert_sharegpt(
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
dataset_attr.ranking and len(aligned_messages) % 2 == 0
):
logger.warning("Invalid message count in {}.".format(messages))
logger.warning_rank0(f"Invalid message count in {messages}.")
broken_data = True
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
......@@ -192,7 +198,7 @@ def convert_sharegpt(
chosen[dataset_attr.role_tag] not in accept_tags[-1]
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
):
logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
broken_data = True
prompt = aligned_messages
......@@ -205,7 +211,7 @@ def convert_sharegpt(
response = aligned_messages[-1:]
if broken_data:
logger.warning("Skipping this abnormal example.")
logger.warning_rank0("Skipping this abnormal example.")
prompt, response = [], []
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
......
......@@ -79,7 +79,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
processor: Optional["ProcessorMixin"] = None
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens = [], [], [], [], []
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
for feature in features:
images = feature.pop("images", None) or []
videos = feature.pop("videos", None) or []
......@@ -87,10 +87,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
batch_videos.extend(videos)
batch_imglens.append(len(images))
batch_vidlens.append(len(videos))
batch_seqlens.append(len(feature["input_ids"]))
batch_input_ids.append(feature["input_ids"])
mm_inputs = self.template.mm_plugin.get_mm_inputs(
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens, self.processor
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
)
if "token_type_ids" in mm_inputs:
token_type_ids = mm_inputs.pop("token_type_ids")
......@@ -99,6 +99,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
features: Dict[str, "torch.Tensor"] = super().__call__(features)
features.update(mm_inputs)
if isinstance(features.get("pixel_values"), list): # for pixtral inputs
features = features.data # use default_collate() instead of BatchEncoding.to()
return features
......@@ -137,9 +140,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
for key in ("chosen", "rejected"):
for feature in features:
target_feature = {
"input_ids": feature["{}_input_ids".format(key)],
"attention_mask": feature["{}_attention_mask".format(key)],
"labels": feature["{}_labels".format(key)],
"input_ids": feature[f"{key}_input_ids"],
"attention_mask": feature[f"{key}_attention_mask"],
"labels": feature[f"{key}_labels"],
"images": feature["images"],
"videos": feature["videos"],
}
......
......@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict
from datasets import DatasetDict, concatenate_datasets, interleave_datasets
from ..extras.logging import get_logger
from ..extras import logging
if TYPE_CHECKING:
......@@ -26,7 +26,7 @@ if TYPE_CHECKING:
from ..hparams import DataArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
......@@ -56,12 +56,12 @@ def merge_dataset(
return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")
return concatenate_datasets(all_datasets)
elif data_args.mix_strategy.startswith("interleave"):
if not data_args.streaming:
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
return interleave_datasets(
datasets=all_datasets,
......@@ -70,7 +70,7 @@ def merge_dataset(
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
)
else:
raise ValueError("Unknown mixing strategy: {}.".format(data_args.mix_strategy))
raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")
def split_dataset(
......
......@@ -83,14 +83,14 @@ class StringFormatter(Formatter):
if isinstance(slot, str):
for name, value in kwargs.items():
if not isinstance(value, str):
raise RuntimeError("Expected a string, got {}".format(value))
raise RuntimeError(f"Expected a string, got {value}")
slot = slot.replace("{{" + name + "}}", value, 1)
elements.append(slot)
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
return elements
......@@ -113,7 +113,7 @@ class FunctionFormatter(Formatter):
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
except json.JSONDecodeError:
raise RuntimeError("Invalid JSON format in function message: {}".format(str([content]))) # flat string
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
elements = []
for name, arguments in functions:
......@@ -124,7 +124,7 @@ class FunctionFormatter(Formatter):
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
return elements
......@@ -141,7 +141,7 @@ class ToolFormatter(Formatter):
tools = json.loads(content)
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
except json.JSONDecodeError:
raise RuntimeError("Invalid JSON format in tool description: {}".format(str([content]))) # flat string
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}") # flat string
@override
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
......
......@@ -20,8 +20,8 @@ import numpy as np
from datasets import DatasetDict, load_dataset, load_from_disk
from transformers.utils.versions import require_version
from ..extras import logging
from ..extras.constants import FILEEXT2TYPE
from ..extras.logging import get_logger
from ..extras.misc import has_tokenized_data
from .aligner import align_dataset
from .data_utils import merge_dataset, split_dataset
......@@ -39,7 +39,7 @@ if TYPE_CHECKING:
from .template import Template
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def _load_single_dataset(
......@@ -51,9 +51,9 @@ def _load_single_dataset(
r"""
Loads a single dataset and aligns it to the standard format.
"""
logger.info("Loading dataset {}...".format(dataset_attr))
logger.info_rank0(f"Loading dataset {dataset_attr}...")
data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
data_path = dataset_attr.dataset_name
data_name = dataset_attr.subset
data_dir = dataset_attr.folder
......@@ -69,25 +69,24 @@ def _load_single_dataset(
if os.path.isdir(local_path): # is directory
for file_name in os.listdir(local_path):
data_files.append(os.path.join(local_path, file_name))
if data_path is None:
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
raise ValueError("File types should be identical.")
elif os.path.isfile(local_path): # is file
data_files.append(local_path)
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
else:
raise ValueError("File {} not found.".format(local_path))
raise ValueError(f"File {local_path} not found.")
data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
if data_path is None:
raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files):
raise ValueError("File types should be identical.")
else:
raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
if dataset_attr.load_from == "ms_hub":
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
from modelscope import MsDataset
from modelscope.utils.config_ds import MS_DATASETS_CACHE
from modelscope import MsDataset # type: ignore
from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
dataset = MsDataset.load(
......@@ -98,10 +97,27 @@ def _load_single_dataset(
split=dataset_attr.split,
cache_dir=cache_dir,
token=model_args.ms_hub_token,
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
use_streaming=data_args.streaming,
)
if isinstance(dataset, MsDataset):
dataset = dataset.to_hf_dataset()
elif dataset_attr.load_from == "om_hub":
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
from openmind import OmDataset # type: ignore
from openmind.utils.hub import OM_DATASETS_CACHE # type: ignore
cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
dataset = OmDataset.load_dataset(
path=data_path,
name=data_name,
data_dir=data_dir,
data_files=data_files,
split=dataset_attr.split,
cache_dir=cache_dir,
token=model_args.om_hub_token,
streaming=data_args.streaming,
)
else:
dataset = load_dataset(
path=data_path,
......@@ -111,13 +127,10 @@ def _load_single_dataset(
split=dataset_attr.split,
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
streaming=data_args.streaming,
trust_remote_code=True,
)
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
if dataset_attr.num_samples is not None and not data_args.streaming:
target_num = dataset_attr.num_samples
indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included
......@@ -128,7 +141,7 @@ def _load_single_dataset(
assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
dataset = dataset.select(indexes)
logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
logger.info_rank0(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")
if data_args.max_samples is not None: # truncate dataset
max_samples = min(data_args.max_samples, len(dataset))
......@@ -224,9 +237,9 @@ def get_dataset(
# Load tokenized dataset
if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path):
logger.warning("Loading dataset from disk will ignore other data arguments.")
logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
dataset_module: Dict[str, "Dataset"] = {}
if "train" in dataset_dict:
......@@ -277,8 +290,8 @@ def get_dataset(
if data_args.tokenized_path is not None:
if training_args.should_save:
dataset_dict.save_to_disk(data_args.tokenized_path)
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.")
logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")
sys.exit(0)
......
This diff is collapsed.
......@@ -20,7 +20,7 @@ from typing import Any, Dict, List, Literal, Optional, Sequence
from transformers.utils import cached_file
from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope
from ..extras.misc import use_modelscope, use_openmind
@dataclass
......@@ -30,7 +30,7 @@ class DatasetAttr:
"""
# basic configs
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
dataset_name: str
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
ranking: bool = False
......@@ -87,31 +87,39 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
config_path = os.path.join(dataset_dir, DATA_CONFIG)
try:
with open(config_path, "r") as f:
with open(config_path) as f:
dataset_info = json.load(f)
except Exception as err:
if len(dataset_names) != 0:
raise ValueError("Cannot open {} due to {}.".format(config_path, str(err)))
raise ValueError(f"Cannot open {config_path} due to {str(err)}.")
dataset_info = None
dataset_list: List["DatasetAttr"] = []
for name in dataset_names:
if dataset_info is None: # dataset_dir is ONLINE
load_from = "ms_hub" if use_modelscope() else "hf_hub"
if use_modelscope():
load_from = "ms_hub"
elif use_openmind():
load_from = "om_hub"
else:
load_from = "hf_hub"
dataset_attr = DatasetAttr(load_from, dataset_name=name)
dataset_list.append(dataset_attr)
continue
if name not in dataset_info:
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
raise ValueError(f"Undefined dataset {name} in {DATA_CONFIG}.")
has_hf_url = "hf_hub_url" in dataset_info[name]
has_ms_url = "ms_hub_url" in dataset_info[name]
has_om_url = "om_hub_url" in dataset_info[name]
if has_hf_url or has_ms_url:
if (use_modelscope() and has_ms_url) or (not has_hf_url):
if has_hf_url or has_ms_url or has_om_url:
if has_ms_url and (use_modelscope() or not has_hf_url):
dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
elif has_om_url and (use_openmind() or not has_hf_url):
dataset_attr = DatasetAttr("om_hub", dataset_name=dataset_info[name]["om_hub_url"])
else:
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
elif "script_url" in dataset_info[name]:
......
......@@ -15,8 +15,8 @@
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import infer_seqlen
......@@ -28,7 +28,7 @@ if TYPE_CHECKING:
from ..template import Template
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def _encode_feedback_example(
......@@ -94,7 +94,9 @@ def preprocess_feedback_dataset(
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
......@@ -123,6 +125,6 @@ def preprocess_feedback_dataset(
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
if desirable_num == 0 or undesirable_num == 0:
logger.warning("Your dataset only has one preference type.")
logger.warning_rank0("Your dataset only has one preference type.")
return model_inputs
......@@ -15,8 +15,8 @@
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import infer_seqlen
......@@ -28,7 +28,7 @@ if TYPE_CHECKING:
from ..template import Template
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def _encode_pairwise_example(
......@@ -77,7 +77,9 @@ def preprocess_pairwise_dataset(
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
......@@ -110,8 +112,8 @@ def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "Pr
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)))
print(f"chosen_labels:\n{tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)))
print(f"rejected_labels:\n{tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")
......@@ -15,8 +15,8 @@
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import greedy_knapsack, infer_seqlen
......@@ -28,7 +28,7 @@ if TYPE_CHECKING:
from ..template import Template
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def _encode_supervised_example(
......@@ -99,7 +99,9 @@ def preprocess_supervised_dataset(
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels = _encode_supervised_example(
......@@ -141,7 +143,9 @@ def preprocess_packed_supervised_dataset(
length2indexes = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels = _encode_supervised_example(
......@@ -160,7 +164,7 @@ def preprocess_packed_supervised_dataset(
)
length = len(input_ids)
if length > data_args.cutoff_len:
logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
logger.warning_rank0(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
else:
lengths.append(length)
length2indexes[length].append(valid_num)
......@@ -212,4 +216,4 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
print(f"labels:\n{tokenizer.decode(valid_labels, skip_special_tokens=False)}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment