Commit 8293100a authored by luopl

update to 0.9.2.dev0

parent 2778a3d0
@@ -168,7 +168,7 @@ async def create_chat_completion_response(
if isinstance(result, list):
tool_calls = []
for tool in result:
- function = Function(name=tool[0], arguments=tool[1])
+ function = Function(name=tool.name, arguments=tool.arguments)
tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
...
@@ -63,7 +63,7 @@ class HuggingfaceEngine(BaseEngine):
try:
asyncio.get_event_loop()
except RuntimeError:
- logger.warning_once("There is no current event loop, creating a new one.")
+ logger.warning_rank0_once("There is no current event loop, creating a new one.")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
@@ -133,7 +133,7 @@ class HuggingfaceEngine(BaseEngine):
if repetition_penalty is not None
else generating_args["repetition_penalty"],
length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
- eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
+ eos_token_id=template.get_stop_token_ids(tokenizer),
pad_token_id=tokenizer.pad_token_id,
)
)
@@ -168,11 +168,21 @@ class HuggingfaceEngine(BaseEngine):
for key, value in mm_inputs.items():
if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):  # for pixtral inputs
value = torch.stack(value)  # assume they have same sizes
+ elif isinstance(value, list) and all(isinstance(v, list) for v in value):  # for minicpmv inputs
+ value = torch.stack([torch.stack(v) for v in value])
elif not isinstance(value, torch.Tensor):
value = torch.tensor(value)
+ if torch.is_floating_point(value):  # cast data dtype for paligemma
+ value = value.to(model.dtype)
gen_kwargs[key] = value.to(model.device)
+ if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
+ gen_kwargs["input_ids"] = inputs
+ del gen_kwargs["image_sizes"]
+ gen_kwargs["tokenizer"] = tokenizer
return gen_kwargs, prompt_length
@staticmethod
@@ -204,8 +214,13 @@ class HuggingfaceEngine(BaseEngine):
input_kwargs,
)
generate_output = model.generate(**gen_kwargs)
+ if isinstance(generate_output, tuple):
+ generate_output = generate_output[1][0]  # post-process the minicpm_o output
response_ids = generate_output[:, prompt_length:]
- response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+ response = tokenizer.batch_decode(
+ response_ids, skip_special_tokens=generating_args["skip_special_tokens"], clean_up_tokenization_spaces=True
+ )
results = []
for i in range(len(response)):
eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
@@ -249,7 +264,9 @@ class HuggingfaceEngine(BaseEngine):
videos,
input_kwargs,
)
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+ streamer = TextIteratorStreamer(
+ tokenizer, skip_prompt=True, skip_special_tokens=generating_args["skip_special_tokens"]
+ )
gen_kwargs["streamer"] = streamer
thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
thread.start()
...
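Editor's note: the new minicpmv branch above stacks a nested list of tensors into one batched tensor before moving it to the model device. A minimal, self-contained sketch of that stacking step (plain PyTorch, no LLaMA-Factory imports; the shapes below are made up for illustration):

import torch

# a nested list such as mm_inputs might hold: 2 samples, each with 3 image slices of shape (4, 5)
value = [[torch.randn(4, 5) for _ in range(3)] for _ in range(2)]

if isinstance(value, list) and all(isinstance(v, list) for v in value):
    # inner stack -> (3, 4, 5) per sample, outer stack -> (2, 3, 4, 5) for the batch
    value = torch.stack([torch.stack(v) for v in value])

print(value.shape)  # torch.Size([2, 3, 4, 5])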
@@ -19,7 +19,7 @@ from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
- from ..extras.constants import IMAGE_PLACEHOLDER
+ from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.misc import get_device_count
from ..extras.packages import is_pillow_available, is_vllm_available
from ..model import load_config, load_tokenizer
@@ -67,11 +67,12 @@ class VllmEngine(BaseEngine):
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
+ self.template.mm_plugin.expand_mm_tokens = False  # for vllm generate
self.generating_args = generating_args.to_dict()
engine_args = {
"model": model_args.model_name_or_path,
- "trust_remote_code": True,
+ "trust_remote_code": model_args.trust_remote_code,
"download_dir": model_args.cache_dir,
"dtype": model_args.infer_dtype,
"max_model_len": model_args.vllm_maxlen,
@@ -83,6 +84,9 @@ class VllmEngine(BaseEngine):
"enable_lora": model_args.adapter_name_or_path is not None,
"max_lora_rank": model_args.vllm_max_lora_rank,
}
+ if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
+ engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
if isinstance(model_args.vllm_config, dict):
engine_args.update(model_args.vllm_config)
@@ -108,19 +112,21 @@ class VllmEngine(BaseEngine):
**input_kwargs,
) -> AsyncIterator["RequestOutput"]:
request_id = f"chatcmpl-{uuid.uuid4().hex}"
+ mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
if images is not None:
+ mm_input_dict.update({"images": images, "imglens": [len(images)]})
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
- if self.template.mm_plugin.__class__.__name__ == "Qwen2vlPlugin":  # temporary solution
- image_str = f"<|vision_start|>{self.template.mm_plugin.image_token}<|vision_end|>"
- else:
- image_str = self.template.mm_plugin.image_token or ""
- paired_messages = [
- {"role": message["role"], "content": message["content"].replace(IMAGE_PLACEHOLDER, image_str)}
- for message in messages
- ] + [{"role": "assistant", "content": ""}]
+ if videos is not None:
+ mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
+ if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
+ messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
+ messages = self.template.mm_plugin.process_messages(
+ messages, mm_input_dict["images"], mm_input_dict["videos"], self.processor
+ )
+ paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
prompt_length = len(prompt_ids)
@@ -162,13 +168,13 @@ class VllmEngine(BaseEngine):
top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must > 0
top_k=top_k if top_k is not None else self.generating_args["top_k"],
stop=stop,
- stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
+ stop_token_ids=self.template.get_stop_token_ids(self.tokenizer),
max_tokens=max_tokens,
- skip_special_tokens=True,
+ skip_special_tokens=self.generating_args["skip_special_tokens"],
)
if images is not None:  # add image features
- image_data = []
+ multi_modal_data = {"image": []}
for image in images:
if not isinstance(image, (str, ImageObject)):
raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
@@ -176,9 +182,7 @@ class VllmEngine(BaseEngine):
if isinstance(image, str):
image = Image.open(image).convert("RGB")
- image_data.append(image)
- multi_modal_data = {"image": image_data}
+ multi_modal_data["image"].append(image)
else:
multi_modal_data = None
...
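Editor's note: the rewritten image branch accumulates PIL images directly into a multi_modal_data dict instead of building an intermediate list. A standalone sketch of that accumulation logic (the helper name and file path are hypothetical; no vLLM or LLaMA-Factory imports are needed):

from PIL import Image
from PIL.Image import Image as ImageObject

def build_multi_modal_data(images):
    # mirrors the diff's logic: accept paths or PIL images, normalize to RGB
    multi_modal_data = {"image": []}
    for image in images:
        if not isinstance(image, (str, ImageObject)):
            raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")

        if isinstance(image, str):
            image = Image.open(image).convert("RGB")

        multi_modal_data["image"].append(image)

    return multi_modal_data

# usage (assumes the file exists on disk):
# payload = build_multi_modal_data(["example.jpg"])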
@@ -24,7 +24,7 @@ from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .extras import logging
from .extras.env import VERSION, print_env
- from .extras.misc import get_device_count
+ from .extras.misc import get_device_count, use_ray
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui
@@ -87,7 +87,7 @@ def main():
export_model()
elif command == Command.TRAIN:
force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
- if force_torchrun or get_device_count() > 1:
+ if force_torchrun or (get_device_count() > 1 and not use_ray()):
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
@@ -120,3 +120,7 @@ def main():
print(USAGE)
else:
raise NotImplementedError(f"Unknown command: {command}.")
+ if __name__ == "__main__":
+ main()
@@ -19,8 +19,16 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
import torch
+ import torch.nn.functional as F
from transformers import DataCollatorForSeq2Seq
+ from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
+ from ..extras.packages import is_pillow_available
+ if is_pillow_available():
+ from PIL import Image
if TYPE_CHECKING:
from transformers import ProcessorMixin
@@ -72,12 +80,16 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
r"""
Data collator that supports VLMs.
- Features should contain input_ids, attention_mask, labels and images.
+ Features should contain input_ids, attention_mask, labels, and optionally contain images and videos.
"""
template: Optional["Template"] = None
processor: Optional["ProcessorMixin"] = None
+ def __post_init__(self):
+ if self.template is None:
+ raise ValueError("Template is required for MultiModalDataCollator.")
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
for feature in features:
@@ -89,6 +101,29 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
batch_vidlens.append(len(videos))
batch_input_ids.append(feature["input_ids"])
+ if (
+ self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
+ ):  # avoid process hanging in zero3/fsdp case
+ fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
+ fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
+ fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
+ fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
+ fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
+ fake_input_ids, None, fake_images, [], self.tokenizer, self.processor
+ )
+ if self.tokenizer.padding_side == "right":
+ features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
+ features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
+ features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
+ else:
+ features[0]["input_ids"] = fake_input_ids + features[0]["input_ids"]
+ features[0]["attention_mask"] = [0] * len(fake_input_ids) + features[0]["attention_mask"]
+ features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"]
+ batch_images = fake_images
+ batch_imglens[0] = 1
+ batch_input_ids[0] = features[0]["input_ids"]
mm_inputs = self.template.mm_plugin.get_mm_inputs(
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
)
@@ -98,10 +133,30 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
feature["token_type_ids"] = token_type_ids[i]
features: Dict[str, "torch.Tensor"] = super().__call__(features)
+ if self.model is not None and hasattr(self.model, "get_rope_index"):  # for qwen2vl mrope
+ features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(
+ input_ids=features["input_ids"],
+ image_grid_thw=mm_inputs.get("image_grid_thw", None),
+ video_grid_thw=mm_inputs.get("video_grid_thw", None),
+ attention_mask=features["attention_mask"],
+ )
+ if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled
+ cross_attention_mask = mm_inputs.pop("cross_attention_mask")
+ seq_len = features["input_ids"].size(1)
+ orig_len = cross_attention_mask.size(1)
+ mm_inputs["cross_attention_mask"] = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))
features.update(mm_inputs)
if isinstance(features.get("pixel_values"), list):  # for pixtral inputs
features = features.data  # use default_collate() instead of BatchEncoding.to()
+ if "image_bound" in features:  # for minicpmv inputs
+ bsz, seq_length = features["input_ids"].shape
+ features["position_ids"] = torch.arange(seq_length).long().repeat(bsz, 1)
+ return {"data": features, "input_ids": features["input_ids"], "labels": features["labels"]}
return features
@@ -120,6 +175,10 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
+ for key, value in features.items():  # cast data dtype for paligemma
+ if torch.is_tensor(value) and torch.is_floating_point(value):
+ features[key] = value.to(self.compute_dtype)
return features
...
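Editor's note: the mllama branch pads the cross-attention mask along the sequence dimension so it matches input_ids after pad_to_multiple_of lengthens the batch. A small self-contained sketch of that F.pad call, assuming the mask has shape (batch, seq_len, num_images, num_tiles) as the diff suggests:

import torch
import torch.nn.functional as F

cross_attention_mask = torch.ones(2, 10, 1, 4)  # (batch, seq_len, num_images, num_tiles)
seq_len, orig_len = 16, cross_attention_mask.size(1)

# F.pad consumes the pad tuple from the last dim backwards: (0, 0) leaves num_tiles alone,
# the next (0, 0) leaves num_images alone, and (0, seq_len - orig_len) appends zeros
# on the right of the seq_len dimension
padded = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))
print(padded.shape)  # torch.Size([2, 16, 1, 4])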
@@ -56,12 +56,12 @@ def merge_dataset(
return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
- logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")
+ logger.warning_rank0_once("The samples between different datasets will not be mixed in streaming mode.")
return concatenate_datasets(all_datasets)
elif data_args.mix_strategy.startswith("interleave"):
if not data_args.streaming:
- logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
+ logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
return interleave_datasets(
datasets=all_datasets,
...
@@ -16,16 +16,12 @@ import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
- from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+ from typing import List, Optional, Union
from typing_extensions import override
from .data_utils import SLOTS
- from .tool_utils import get_tool_utils
- if TYPE_CHECKING:
- from .tool_utils import FunctionCall
+ from .tool_utils import FunctionCall, get_tool_utils
@dataclass
@@ -98,33 +94,31 @@ class StringFormatter(Formatter):
@dataclass
class FunctionFormatter(Formatter):
def __post_init__(self):
- self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
+ self.tool_utils = get_tool_utils(self.tool_format)
@override
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
- functions: List[Tuple[str, str]] = []
+ functions: List["FunctionCall"] = []
try:
tool_calls = json.loads(content)
if not isinstance(tool_calls, list):  # parallel function call
tool_calls = [tool_calls]
for tool_call in tool_calls:
- functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
+ functions.append(
+ FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))
+ )
except json.JSONDecodeError:
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}")  # flat string
elements = []
- for name, arguments in functions:
- for slot in self.slots:
- if isinstance(slot, str):
- slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
- elements.append(slot)
- elif isinstance(slot, (dict, set)):
- elements.append(slot)
- else:
- raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
+ for slot in self.slots:
+ if slot == "{{content}}":
+ elements += self.tool_utils.function_formatter(functions)
+ else:
+ elements.append(slot)
return elements
...
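Editor's note: FunctionFormatter now parses the function message into FunctionCall tuples and delegates rendering of the "{{content}}" slot to the per-template tool utils. A rough standalone sketch of that flow using the default "Action / Action Input" style defined later in this commit (the helper names here are illustrative, not LLaMA-Factory's API):

import json
from typing import List, NamedTuple

class FunctionCall(NamedTuple):
    name: str
    arguments: str

def default_function_formatter(functions: List[FunctionCall]) -> str:
    # mirrors DefaultToolUtils.function_formatter from the tool_utils diff below
    return "".join(f"Action: {name}\nAction Input: {arguments}\n" for name, arguments in functions)

content = '[{"name": "get_weather", "arguments": {"city": "Beijing"}}]'
calls = [
    FunctionCall(c["name"], json.dumps(c["arguments"], ensure_ascii=False))
    for c in json.loads(content)
]
print(default_function_formatter(calls))
# Action: get_weather
# Action Input: {"city": "Beijing"}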
@@ -18,11 +18,10 @@ from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
import numpy as np
from datasets import DatasetDict, load_dataset, load_from_disk
- from transformers.utils.versions import require_version
from ..extras import logging
from ..extras.constants import FILEEXT2TYPE
- from ..extras.misc import has_tokenized_data
+ from ..extras.misc import check_version, has_tokenized_data
from .aligner import align_dataset
from .data_utils import merge_dataset, split_dataset
from .parser import get_dataset_list
@@ -84,7 +83,7 @@ def _load_single_dataset(
raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
if dataset_attr.load_from == "ms_hub":
- require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
+ check_version("modelscope>=1.11.0", mandatory=True)
from modelscope import MsDataset  # type: ignore
from modelscope.utils.config_ds import MS_DATASETS_CACHE  # type: ignore
@@ -103,7 +102,7 @@ def _load_single_dataset(
dataset = dataset.to_hf_dataset()
elif dataset_attr.load_from == "om_hub":
- require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
+ check_version("openmind>=0.8.0", mandatory=True)
from openmind import OmDataset  # type: ignore
from openmind.utils.hub import OM_DATASETS_CACHE  # type: ignore
@@ -128,7 +127,8 @@ def _load_single_dataset(
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
streaming=data_args.streaming,
- trust_remote_code=True,
+ num_proc=data_args.preprocessing_num_workers,
+ trust_remote_code=model_args.trust_remote_code,
)
if dataset_attr.num_samples is not None and not data_args.streaming:
@@ -238,15 +238,19 @@ def get_dataset(
if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path):
logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
- dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
+ tokenized_data: Union["Dataset", "DatasetDict"] = load_from_disk(data_args.tokenized_path)
logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
dataset_module: Dict[str, "Dataset"] = {}
- if "train" in dataset_dict:
- dataset_module["train_dataset"] = dataset_dict["train"]
- if "validation" in dataset_dict:
- dataset_module["eval_dataset"] = dataset_dict["validation"]
+ if isinstance(tokenized_data, DatasetDict):
+ if "train" in tokenized_data:
+ dataset_module["train_dataset"] = tokenized_data["train"]
+ if "validation" in tokenized_data:
+ dataset_module["eval_dataset"] = tokenized_data["validation"]
+ else:  # Dataset
+ dataset_module["train_dataset"] = tokenized_data
if data_args.streaming:
dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
...
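Editor's note: the new isinstance check exists because datasets.load_from_disk returns a plain Dataset when the path was saved from a single split and a DatasetDict when it was saved with named splits. A minimal illustration (the save paths are throwaway local directories):

from datasets import Dataset, DatasetDict, load_from_disk

Dataset.from_dict({"input_ids": [[1, 2], [3, 4]]}).save_to_disk("tok_single")
DatasetDict({"train": Dataset.from_dict({"input_ids": [[1, 2]]})}).save_to_disk("tok_splits")

print(type(load_from_disk("tok_single")))   # -> Dataset
print(type(load_from_disk("tok_splits")))   # -> DatasetDict

This is why get_dataset now inspects the type before indexing into "train" and "validation".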
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
from .processors.feedback import preprocess_feedback_dataset
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
- from .processors.pretrain import preprocess_pretrain_dataset
+ from .processors.pretrain import preprocess_pretrain_dataset, print_pretrain_dataset_example
from .processors.supervised import (
preprocess_packed_supervised_dataset,
preprocess_supervised_dataset,
@@ -47,7 +47,7 @@ def get_preprocess_and_print_func(
tokenizer=tokenizer,
data_args=data_args,
)
- print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
+ print_function = partial(print_pretrain_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not do_generate:
if data_args.packing:
if data_args.neat_packing:  # hack datasets to have int32 attention mask
...
@@ -52,3 +52,8 @@ def preprocess_pretrain_dataset(
result["input_ids"][i][0] = tokenizer.bos_token_id
return result
+ def print_pretrain_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
+ print("input_ids:\n{}".format(example["input_ids"]))
+ print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
@@ -100,3 +100,5 @@ def preprocess_unsupervised_dataset(
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
+ print("label_ids:\n{}".format(example["labels"]))
+ print("labels:\n{}".format(tokenizer.decode(example["labels"], skip_special_tokens=False)))
@@ -15,15 +15,20 @@
import json
import re
from abc import ABC, abstractmethod
- from collections import namedtuple
from dataclasses import dataclass
- from typing import Any, Dict, List, Tuple, Union
+ from datetime import datetime
+ from typing import Any, Dict, List, NamedTuple, Tuple, Union
from typing_extensions import override
from .data_utils import SLOTS
+ class FunctionCall(NamedTuple):
+ name: str
+ arguments: str
DEFAULT_TOOL_PROMPT = (
"You have access to the following tools:\n{tool_text}"
"Use the following format if using a tool:\n"
@@ -34,14 +39,25 @@ DEFAULT_TOOL_PROMPT = (
"```\n"
)
GLM4_TOOL_PROMPT = (
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
)
+ LLAMA3_TOOL_PROMPT = (
+ "Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
+ "You have access to the following functions. To call a function, please respond with JSON for a function call. "
+ """Respond in the format {{"name": function name, "parameters": dictionary of argument name and its value}}. """
+ "Do not use variables.\n\n{tool_text}"
+ )
- FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
+ QWEN_TOOL_PROMPT = (
+ "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
+ "You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
+ "\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
+ """<tool_call></tool_call> XML tags:\n<tool_call>\n{{"name": <function-name>, """
+ """"arguments": <args-json-object>}}\n</tool_call><|im_end|>\n"""
+ )
@dataclass
@@ -52,17 +68,17 @@ class ToolUtils(ABC):
@staticmethod
@abstractmethod
- def get_function_slots() -> SLOTS:
+ def tool_formatter(tools: List[Dict[str, Any]]) -> str:
r"""
- Gets a list of slots corresponding to a single function call.
+ Generates the system message describing all the available tools.
"""
...
@staticmethod
@abstractmethod
- def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+ def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
r"""
- Generates the system message describing all the available tools.
+ Generates the assistant message including all the tool calls.
"""
...
@@ -70,16 +86,17 @@
@abstractmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
r"""
- Extracts all the function calls from the response message.
+ Extracts all the function calls from the assistant message.
+ It should be an inverse function of `function_formatter`.
"""
...
class DefaultToolUtils(ToolUtils):
- @override
- @staticmethod
- def get_function_slots() -> SLOTS:
- return ["Action: {{name}}\nAction Input: {{arguments}}\n"]
+ r"""
+ Default tool using template.
+ """
@override
@staticmethod
@@ -115,6 +132,15 @@ class DefaultToolUtils(ToolUtils):
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
+ @override
+ @staticmethod
+ def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+ function_text = ""
+ for name, arguments in functions:
+ function_text += f"Action: {name}\nAction Input: {arguments}\n"
+ return [function_text]
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
@@ -129,7 +155,7 @@ class DefaultToolUtils(ToolUtils):
tool_input = match[1].strip().strip('"').strip("```")
try:
arguments = json.loads(tool_input)
- results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
+ results.append(FunctionCall(tool_name, json.dumps(arguments, ensure_ascii=False)))
except json.JSONDecodeError:
return content
@@ -137,10 +163,9 @@
class GLM4ToolUtils(ToolUtils):
- @override
- @staticmethod
- def get_function_slots() -> SLOTS:
- return ["{{name}}\n{{arguments}}"]
+ r"""
+ GLM-4 tool using template.
+ """
@override
@staticmethod
@@ -153,6 +178,14 @@ class GLM4ToolUtils(ToolUtils):
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
+ @override
+ @staticmethod
+ def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+ if len(functions) > 1:
+ raise ValueError("GLM-4 does not support parallel functions.")
+ return [f"{functions[0].name}\n{functions[0].arguments}"]
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
@@ -161,16 +194,152 @@ class GLM4ToolUtils(ToolUtils):
tool_name, tool_input = content.split("\n", maxsplit=1)
try:
- arguments = json.loads(tool_input)
+ arguments = json.loads(tool_input.strip())
except json.JSONDecodeError:
return content
- return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
+ return [FunctionCall(tool_name, json.dumps(arguments, ensure_ascii=False))]
+ class Llama3ToolUtils(ToolUtils):
+ r"""
+ Llama 3.x tool using template with `tools_in_user_message=False`.
+ Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling
+ """
+ @override
+ @staticmethod
+ def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+ date = datetime.now().strftime("%d %b %Y")
+ tool_text = ""
+ for tool in tools:
+ wrapped_tool = {"type": "function", "function": tool}
+ tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n"
+ return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text)
+ @override
+ @staticmethod
+ def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+ if len(functions) > 1:
+ raise ValueError("Llama-3 does not support parallel functions.")
+ return [f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}']
+ @override
+ @staticmethod
+ def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
+ try:
+ tool = json.loads(content.strip())
+ except json.JSONDecodeError:
+ return content
+ if "name" not in tool or "parameters" not in tool:
+ return content
+ return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
+ class MistralToolUtils(ToolUtils):
+ r"""
+ Mistral v0.3 tool using template.
+ """
+ @override
+ @staticmethod
+ def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+ wrapped_tools = []
+ for tool in tools:
+ wrapped_tools.append({"type": "function", "function": tool})
+ return "[AVAILABLE_TOOLS] " + json.dumps(wrapped_tools, ensure_ascii=False) + "[/AVAILABLE_TOOLS]"
+ @override
+ @staticmethod
+ def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+ function_texts = []
+ for name, arguments in functions:
+ function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
+ return ["[" + ", ".join(function_texts) + "]"]
+ @override
+ @staticmethod
+ def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
+ try:
+ tools = json.loads(content.strip())
+ except json.JSONDecodeError:
+ return content
+ if not isinstance(tools, list):
+ tools = [tools]
+ results = []
+ for tool in tools:
+ if "name" not in tool or "arguments" not in tool:
+ return content
+ results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
+ return results
+ class QwenToolUtils(ToolUtils):
+ r"""
+ Qwen 2.5 tool using template.
+ """
+ @override
+ @staticmethod
+ def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+ tool_text = ""
+ for tool in tools:
+ wrapped_tool = {"type": "function", "function": tool}
+ tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)
+ return QWEN_TOOL_PROMPT.format(tool_text=tool_text)
+ @override
+ @staticmethod
+ def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+ function_texts = []
+ for name, arguments in functions:
+ function_texts.append(
+ "<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
+ )
+ return ["\n".join(function_texts)]
+ @override
+ @staticmethod
+ def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
+ regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
+ tool_match: List[str] = re.findall(regex, content)
+ if not tool_match:
+ return content
+ results = []
+ for tool in tool_match:
+ try:
+ tool = json.loads(tool.strip())
+ except json.JSONDecodeError:
+ return content
+ if "name" not in tool or "arguments" not in tool:
+ return content
+ results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
+ return results
TOOLS = {
"default": DefaultToolUtils(),
"glm4": GLM4ToolUtils(),
+ "llama3": Llama3ToolUtils(),
+ "mistral": MistralToolUtils(),
+ "qwen": QwenToolUtils(),
}
...
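Editor's note: the new utils are designed so that tool_extractor inverts function_formatter. For the Qwen style, calls are wrapped in <tool_call> blocks on the way out and recovered with a regex plus json.loads on the way back. A standalone round-trip sketch that reuses the same regex as the diff (it does not import LLaMA-Factory):

import json
import re

formatted = (
    '<tool_call>\n{"name": "get_weather", "arguments": {"city": "Beijing"}}\n</tool_call>\n'
    '<tool_call>\n{"name": "get_time", "arguments": {"tz": "UTC"}}\n</tool_call>'
)

regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
calls = []
for block in re.findall(regex, formatted):
    tool = json.loads(block.strip())
    calls.append((tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))

print(calls)
# [('get_weather', '{"city": "Beijing"}'), ('get_time', '{"tz": "UTC"}')]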
@@ -100,7 +100,7 @@ class Evaluator:
cache_dir=self.model_args.cache_dir,
download_mode=self.eval_args.download_mode,
token=self.model_args.hf_hub_token,
- trust_remote_code=True,
+ trust_remote_code=self.model_args.trust_remote_code,
)
pbar.set_postfix_str(categorys[subject]["name"])
inputs, outputs, labels = [], [], []
...
@@ -26,7 +26,7 @@ import trl
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
- VERSION = "0.9.1"
+ VERSION = "0.9.2.dev0"
def print_env() -> None:
...
@@ -68,7 +68,7 @@ class LoggerHandler(logging.Handler):
class _Logger(logging.Logger):
r"""
- A logger that supports info_rank0 and warning_once.
+ A logger that supports rank0 logging.
"""
def info_rank0(self, *args, **kwargs) -> None:
@@ -77,7 +77,7 @@ class _Logger(logging.Logger):
def warning_rank0(self, *args, **kwargs) -> None:
self.warning(*args, **kwargs)
- def warning_once(self, *args, **kwargs) -> None:
+ def warning_rank0_once(self, *args, **kwargs) -> None:
self.warning(*args, **kwargs)
@@ -163,11 +163,11 @@ def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
@lru_cache(None)
- def warning_once(self: "logging.Logger", *args, **kwargs) -> None:
+ def warning_rank0_once(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.warning(*args, **kwargs)
logging.Logger.info_rank0 = info_rank0
logging.Logger.warning_rank0 = warning_rank0
- logging.Logger.warning_once = warning_once
+ logging.Logger.warning_rank0_once = warning_rank0_once
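Editor's note: the renamed helper combines two ideas: functools.lru_cache(None) deduplicates repeated calls with the same message, and the LOCAL_RANK check keeps non-zero ranks quiet under torchrun. A self-contained sketch of the same pattern on a plain logging.Logger (the function name mirrors the diff, but this is only an illustration of the idiom):

import logging
import os
from functools import lru_cache

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("demo")

@lru_cache(None)
def warning_rank0_once(logger: "logging.Logger", *args) -> None:
    # only the first call with a given message logs, and only on local rank 0
    if int(os.getenv("LOCAL_RANK", "0")) == 0:
        logger.warning(*args)

warning_rank0_once(logger, "streaming mode: datasets will not be mixed")
warning_rank0_once(logger, "streaming mode: datasets will not be mixed")  # suppressed by the cache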
@@ -17,7 +17,7 @@
import gc
import os
- from typing import TYPE_CHECKING, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Tuple, Union
import torch
import torch.distributed as dist
@@ -73,18 +73,46 @@ class AverageMeter:
self.avg = self.sum / self.count
+ def check_version(requirement: str, mandatory: bool = False) -> None:
+ r"""
+ Optionally checks the package version.
+ """
+ if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"] and not mandatory:
+ logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
+ return
+ if mandatory:
+ hint = f"To fix: run `pip install {requirement}`."
+ else:
+ hint = f"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
+ require_version(requirement, hint)
def check_dependencies() -> None:
r"""
Checks the version of the required packages.
"""
- if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
- logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
- else:
- require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
- require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0")
- require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
- require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
- require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
+ check_version("transformers>=4.41.2,<=4.46.1")
+ check_version("datasets>=2.16.0,<=3.1.0")
+ check_version("accelerate>=0.34.0,<=1.0.1")
+ check_version("peft>=0.11.1,<=0.12.0")
+ check_version("trl>=0.8.6,<=0.9.6")
+ def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
+ r"""
+ Calculates effective tokens per second.
+ """
+ effective_token_num = 0
+ for data in dataset:
+ if stage == "sft":
+ effective_token_num += len(data["input_ids"])
+ elif stage == "rm":
+ effective_token_num += len(data["chosen_input_ids"]) + len(data["rejected_input_ids"])
+ result = effective_token_num * metrics["epoch"] / metrics["train_runtime"]
+ return result / dist.get_world_size() if dist.is_initialized() else result
def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
@@ -213,7 +241,7 @@ def skip_check_imports() -> None:
r"""
Avoids flash attention import error in custom model files.
"""
- if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
+ if os.getenv("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
transformers.dynamic_module_utils.check_imports = get_relative_imports
@@ -237,7 +265,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
return model_args.model_name_or_path
if use_modelscope():
- require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
+ check_version("modelscope>=1.11.0", mandatory=True)
from modelscope import snapshot_download  # type: ignore
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
@@ -248,7 +276,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
)
if use_openmind():
- require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
+ check_version("openmind>=0.8.0", mandatory=True)
from openmind.utils.hub import snapshot_download  # type: ignore
return snapshot_download(
@@ -259,16 +287,12 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
def use_modelscope() -> bool:
- return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
+ return os.getenv("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
def use_openmind() -> bool:
- return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
+ return os.getenv("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
- def cal_effective_tokens(effective_token_num, epoch, train_runtime) -> int:
- r"""
- calculate effective tokens.
- """
- result = effective_token_num * epoch / train_runtime
- return result / dist.get_world_size() if dist.is_initialized() else result
+ def use_ray() -> bool:
+ return os.getenv("USE_RAY", "0").lower() in ["true", "1"]
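Editor's note: to make the new throughput metric concrete, calculate_tps sums the non-padded tokens of the dataset, scales by the number of epochs actually run, divides by wall-clock training time, and divides by the world size when distributed training is initialized. A tiny worked example with made-up numbers (no distributed init, so the world-size division is skipped):

# 2 SFT samples with 512 and 256 effective tokens, trained for 3 epochs in 60 seconds
dataset = [{"input_ids": [0] * 512}, {"input_ids": [0] * 256}]
metrics = {"epoch": 3.0, "train_runtime": 60.0}

effective_token_num = sum(len(d["input_ids"]) for d in dataset)  # 768
tps = effective_token_num * metrics["epoch"] / metrics["train_runtime"]
print(tps)  # 38.4 effective tokens per second (divided by world size when distributed)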
@@ -50,6 +50,10 @@ def is_galore_available():
return _is_package_available("galore_torch")
+ def is_apollo_available():
+ return _is_package_available("apollo_torch")
def is_gradio_available():
return _is_package_available("gradio")
@@ -62,6 +66,10 @@ def is_pillow_available():
return _is_package_available("PIL")
+ def is_ray_available():
+ return _is_package_available("ray")
def is_requests_available():
return _is_package_available("requests")
...
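Editor's note: these one-liners follow the usual optional-dependency pattern: probe for the package once and let callers branch on a boolean instead of scattering try/except imports. A generic sketch of such a helper using the common importlib idiom (this is not necessarily the exact implementation of _is_package_available in this repo):

import importlib.util

def _is_package_available(name: str) -> bool:
    # True if the package can be imported in the current environment
    return importlib.util.find_spec(name) is not None

def is_ray_available() -> bool:
    return _is_package_available("ray")

if is_ray_available():
    import ray  # only imported when actually present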