Commit 581d366d authored by chenych

Support GLM-4/GLM-4-0414/GLM-Z1

parent 428c5813
@@ -26,6 +26,7 @@ save_steps: 500
 plot_loss: true
 overwrite_output_dir: true
 save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]

 ### train
 per_device_train_batch_size: 1
...
@@ -26,14 +26,21 @@ save_steps: 500
 plot_loss: true
 overwrite_output_dir: true
 save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]

 ### ray
 ray_run_name: llama3_8b_sft_lora
 ray_storage_path: ./saves
-ray_num_workers: 4 # number of GPUs to use
+ray_num_workers: 4 # Number of GPUs to use.
+placement_strategy: PACK
 resources_per_worker:
   GPU: 1
-placement_strategy: PACK
+# ray_init_kwargs:
+#   runtime_env:
+#     env_vars:
+#       <YOUR-ENV-VAR-HERE>: "<YOUR-ENV-VAR-HERE>"
+#     pip:
+#       - emoji

 ### train
 per_device_train_batch_size: 1
...
@@ -28,10 +28,11 @@ save_steps: 500
 plot_loss: true
 overwrite_output_dir: true
 save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]

 ### train
 per_device_train_batch_size: 1
-gradient_accumulation_steps: 8
+gradient_accumulation_steps: 2
 learning_rate: 1.0e-4
 num_train_epochs: 3.0
 lr_scheduler_type: cosine
...
@@ -25,6 +25,7 @@ save_steps: 500
 plot_loss: true
 overwrite_output_dir: true
 save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]

 ### train
 per_device_train_batch_size: 1
...
@@ -29,6 +29,7 @@ save_steps: 500
 plot_loss: true
 overwrite_output_dir: true
 save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]

 ### train
 per_device_train_batch_size: 1
...
@@ -27,6 +27,7 @@ save_steps: 500
 plot_loss: true
 overwrite_output_dir: true
 save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]

 ### train
 per_device_train_batch_size: 1
...
@@ -16,6 +16,7 @@ cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
+dataloader_num_workers: 4

 ### output
 output_dir: saves/llama3-8b/lora/sft
@@ -23,6 +24,8 @@ logging_steps: 10
 save_steps: 500
 plot_loss: true
 overwrite_output_dir: true
+save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]

 ### train
 per_device_train_batch_size: 1
...
@@ -16,6 +16,7 @@ cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
+dataloader_num_workers: 4

 ### output
 output_dir: saves/llama3-8b/lora/sft
@@ -23,6 +24,8 @@ logging_steps: 10
 save_steps: 500
 plot_loss: true
 overwrite_output_dir: true
+save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]

 ### train
 per_device_train_batch_size: 1
...
 ### model
 model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
 quantization_bit: 4
-quantization_method: bitsandbytes
+quantization_method: bnb
 double_quantization: false
 trust_remote_code: true
@@ -19,6 +19,7 @@ cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
+dataloader_num_workers: 4

 ### output
 output_dir: saves/llama3-8b/lora/sft
@@ -26,6 +27,8 @@ logging_steps: 10
 save_steps: 500
 plot_loss: true
 overwrite_output_dir: true
+save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]

 ### train
 per_device_train_batch_size: 1
...
@@ -16,6 +16,7 @@ cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
+dataloader_num_workers: 4

 ### output
 output_dir: saves/llama3-8b/lora/sft
@@ -23,6 +24,8 @@ logging_steps: 10
 save_steps: 500
 plot_loss: true
 overwrite_output_dir: true
+save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]

 ### train
 per_device_train_batch_size: 1
...
 ### model
 model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
-quantization_bit: 4
-quantization_method: bitsandbytes # choices: [bitsandbytes (4/8), hqq (2/3/4/5/6/8), eetq (8)]
+quantization_bit: 4 # choices: [8 (bnb/hqq/eetq), 4 (bnb/hqq), 3 (hqq), 2 (hqq)]
+quantization_method: bnb # choices: [bnb, hqq, eetq]
 trust_remote_code: true

 ### method
@@ -18,6 +18,7 @@ cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
+dataloader_num_workers: 4

 ### output
 output_dir: saves/llama3-8b/lora/sft
@@ -25,6 +26,8 @@ logging_steps: 10
 save_steps: 500
 plot_loss: true
 overwrite_output_dir: true
+save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]

 ### train
 per_device_train_batch_size: 1
...
@@ -19,13 +19,36 @@ dynamic = [
 ]

 [tool.ruff]
-target-version = "py38"
+target-version = "py39"
 line-length = 119
 indent-width = 4

 [tool.ruff.lint]
-ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
-select = ["C", "E", "F", "I", "W"]
+ignore = [
+    "C408", # collection
+    "C901", # complex
+    "E501", # line too long
+    "E731", # lambda function
+    "E741", # ambiguous var name
+    "D100", # no doc public module
+    "D101", # no doc public class
+    "D102", # no doc public method
+    "D103", # no doc public function
+    "D104", # no doc public package
+    "D105", # no doc magic method
+    "D107", # no doc __init__
+]
+extend-select = [
+    "C", # complexity
+    "E", # error
+    "F", # pyflakes
+    "I", # isort
+    "W", # warning
+    "UP", # pyupgrade
+    "D", # pydocstyle
+    "PT009", # pytest assert
+    "RUF022", # sort __all__
+]

 [tool.ruff.lint.isort]
 lines-after-imports = 2
@@ -38,9 +61,12 @@ known-third-party = [
     "peft",
     "torch",
     "transformers",
-    "trl"
+    "trl",
 ]

+[tool.ruff.lint.pydocstyle]
+convention = "google"
+
 [tool.ruff.format]
 quote-style = "double"
 indent-style = "space"
@@ -61,5 +87,9 @@ conflicts = [
     [
         { extra = "torch-npu" },
         { extra = "vllm" },
-    ]
+    ],
+    [
+        { extra = "sglang" },
+        { extra = "minicpm_v" },
+    ],
 ]
-transformers>=4.41.2,<=4.51.0,!=4.46.*,!=4.47.*,!=4.48.0
-datasets>=2.16.0,<=3.4.1
-accelerate>=0.34.0,<=1.5.2
-peft>=0.14.0,<=0.15.0
+transformers>=4.41.2,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0
+datasets>=2.16.0,<=3.5.0
+accelerate>=0.34.0,<=1.6.0
+peft>=0.14.0,<=0.15.1
 trl>=0.8.6,<=0.9.6
-tokenizers>=0.19.0,<=0.21.0
-gradio>=4.38.0,<=5.21.0
-pandas>=2.0.0
+tokenizers>=0.19.0,<=0.21.1
+gradio>=4.38.0,<=5.25.0
 scipy
 einops
 sentencepiece
 tiktoken
 protobuf
 uvicorn
-pydantic
 fastapi
 sse-starlette
 matplotlib>=3.7.0
@@ -21,6 +19,7 @@ packaging
 pyyaml
 numpy<2.0.0
 pydantic<=2.10.6
+pandas>=2.0.0
 av
 librosa
 tyro<0.9.0
@@ -21,9 +21,9 @@ from datasets import load_dataset
 try:
-    import jieba
-    from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
-    from rouge_chinese import Rouge
+    import jieba  # type: ignore
+    from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu  # type: ignore
+    from rouge_chinese import Rouge  # type: ignore

     jieba.setLogLevel(logging.CRITICAL)
     jieba.initialize()
@@ -52,6 +52,7 @@ def compute_metrics(sample):
     metric_result = {}
     for k, v in result.items():
         metric_result[k] = round(v["f"] * 100, 4)
+
     metric_result["bleu-4"] = round(bleu_score * 100, 4)

     return metric_result
...
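For reference, the scoring pattern behind these imports can be sketched in isolation as follows; the input sentences are made-up placeholders and the method3 smoothing choice is an assumption, not something shown in this diff.

import jieba
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from rouge_chinese import Rouge

# Tokenize a prediction and a reference with jieba (placeholder sentences).
hypothesis = list(jieba.cut("今天天气很好"))
reference = list(jieba.cut("今天天气不错"))

# rouge-chinese expects space-joined token strings and returns f/p/r per metric.
result = Rouge().get_scores(" ".join(hypothesis), " ".join(reference))[0]
metric_result = {k: round(v["f"] * 100, 4) for k, v in result.items()}

# Sentence-level BLEU-4 with smoothing (assumed smoothing method).
bleu_score = sentence_bleu([reference], hypothesis, smoothing_function=SmoothingFunction().method3)
metric_result["bleu-4"] = round(bleu_score * 100, 4)
print(metric_result)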
-# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
-#
-# This code is based on the HuggingFace's PEFT library.
-# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,12 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import os
 import shutil

 import fire
 from peft import PeftModel
-from transformers import AutoModel, AutoProcessor, AutoTokenizer, Qwen2_5OmniThinkerForConditionalGeneration
+from transformers import AutoModel, AutoProcessor, Qwen2_5OmniThinkerForConditionalGeneration  # type: ignore


 def merge_lora(
@@ -41,20 +39,14 @@ def merge_lora(
         save_path (str): Directory where the merged model and configurations will be saved.
     """
     # 1. Load the original model, tokenizer, and processor
-    model = AutoModel.from_pretrained(base_model_path)
-    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
-    print("Successfully loaded the original model and tokenizer.")
-
-    try:
-        processor = AutoProcessor.from_pretrained(base_model_path)
-    except Exception:
-        print("Processor configuration not found, skipping processor load.")
-        processor = None
-
-    print("Successfully loaded the original model, tokenizer, and processor (if available).")
+    model = AutoModel.from_pretrained(base_model_path, torch_dtype="auto", device_map="cpu")
+    processor = AutoProcessor.from_pretrained(base_model_path)

     # 2. Extract the submodule to be merged (e.g., model.thinker)
     if not hasattr(model, submodule_name):
         raise AttributeError(f"The model does not have a submodule named '{submodule_name}'.")

     base_submodule = getattr(model, submodule_name)
     print(f"Successfully extracted submodule: {submodule_name}.")
@@ -71,11 +63,8 @@ def merge_lora(
     # 6. Save the final merged model along with the tokenizer and processor configuration
     model.save_pretrained(save_path)
-    tokenizer.save_pretrained(save_path)
-    if processor is not None:
-        processor.save_pretrained(save_path)
-    print(f"Merged model and configuration saved to {save_path}.")
+    processor.save_pretrained(save_path)
+    print(f"Merged model and tokenizer saved to {save_path}.")

     source_file = os.path.join(base_model_path, extra_file)
     target_file = os.path.join(save_path, extra_file)
@@ -89,7 +78,7 @@ def merge_lora(
 def save_full_model(
     saved_thinker_path: str,
     base_model_path: str,
-    save_path: str,
+    save_path: str = "./merged_model_checkpoint",
     extra_file: str = "spk_dict.pt",
 ):
     """Load the saved thinker module and the original model, replace the thinker in the original model.
@@ -99,26 +88,23 @@ def save_full_model(
     Args:
         saved_thinker_path (str): Path to the saved thinker weights.
         base_model_path (str): Directory path of the original model.
-        save_path (str): Directory where the final complete model will be saved.
+        save_path (str): Directory where the merged model and configurations will be saved.
         extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
     """
-    # Load the thinker module
-    thinker = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(saved_thinker_path, device_map="cpu")
-    # Load the original model
-    base_model = AutoModel.from_pretrained(base_model_path, device_map="cpu")
-    # Replace the thinker module in the original model
+    # 1. Load the saved thinker module and the original model
+    thinker = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
+        saved_thinker_path, torch_dtype="auto", device_map="cpu"
+    )
+    base_model = AutoModel.from_pretrained(base_model_path, torch_dtype="auto", device_map="cpu")
     base_model.thinker = thinker

-    # Load the processor and tokenizer
-    processor = AutoProcessor.from_pretrained(base_model_path, trust_remote_code=True)
-    tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
-
-    # Save the complete model along with its configurations
+    # 2. Save the complete model along with its tokenizer and processor configuration
+    processor = AutoProcessor.from_pretrained(base_model_path)
     base_model.save_pretrained(save_path)
-    tokenizer.save_pretrained(save_path)
     processor.save_pretrained(save_path)
-    print(f"Complete model, tokenizer, and processor configuration have been saved to {save_path}.")
+    print(f"Merged model and tokenizer saved to {save_path}.")

+    # 3. Copy the extra file from the base model directory to the save_path
     source_file = os.path.join(base_model_path, extra_file)
     target_file = os.path.join(save_path, extra_file)
     if os.path.exists(source_file):
...
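The middle of merge_lora() (steps 3 to 5) is collapsed in this diff. A hedged sketch of what merging a LoRA adapter into the extracted submodule typically looks like, using placeholder paths and names rather than the script's actual arguments, is:

from peft import PeftModel
from transformers import AutoModel, AutoProcessor

base_model_path = "Qwen/Qwen2.5-Omni-7B"  # placeholder
lora_checkpoint_path = "saves/qwen2_5_omni/lora/sft"  # placeholder
save_path = "./merged_model_checkpoint"

model = AutoModel.from_pretrained(base_model_path, torch_dtype="auto", device_map="cpu")
processor = AutoProcessor.from_pretrained(base_model_path)

# Wrap only the target submodule (e.g. model.thinker) with the LoRA adapter,
# merge the adapter weights into it, and put the merged module back.
base_submodule = getattr(model, "thinker")
lora_submodule = PeftModel.from_pretrained(base_submodule, lora_checkpoint_path)
merged_submodule = lora_submodule.merge_and_unload()
setattr(model, "thinker", merged_submodule)

model.save_pretrained(save_path)
processor.save_pretrained(save_path)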
@@ -20,7 +20,7 @@ from transformers import Seq2SeqTrainingArguments
 from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
 from llamafactory.extras.constants import IGNORE_INDEX
-from llamafactory.extras.misc import check_version, get_device_count
+from llamafactory.extras.misc import get_device_count
 from llamafactory.extras.packages import is_vllm_available
 from llamafactory.hparams import get_infer_args
 from llamafactory.model import load_tokenizer
@@ -56,7 +56,6 @@ def vllm_infer(
     Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
     """
-    check_version("vllm>=0.4.3,<=0.8.2")
     if pipeline_parallel_size > get_device_count():
         raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")
...
@@ -45,7 +45,7 @@ extra_require = {
     "torch": ["torch>=1.13.1"],
     "torch-npu": ["torch==2.4.0", "torch-npu==2.4.0.post2", "decorator"],
     "metrics": ["nltk", "jieba", "rouge-chinese"],
-    "deepspeed": ["deepspeed>=0.10.0,<=0.16.4"],
+    "deepspeed": ["deepspeed>=0.10.0,<=0.16.5"],
     "liger-kernel": ["liger-kernel>=0.5.5"],
     "bitsandbytes": ["bitsandbytes>=0.39.0"],
     "hqq": ["hqq"],
@@ -53,8 +53,8 @@ extra_require = {
     "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
     "awq": ["autoawq"],
     "aqlm": ["aqlm[gpu]>=1.1.0"],
-    "vllm": ["vllm>=0.4.3,<=0.8.2"],
-    "sglang": ["sglang[srt]>=0.4.4", "transformers==4.48.3"],
+    "vllm": ["vllm>=0.4.3,<=0.8.4"],
+    "sglang": ["sglang[srt]>=0.4.5", "transformers==4.51.1"],
     "galore": ["galore-torch"],
     "apollo": ["apollo-torch"],
     "badam": ["badam>=1.2.1"],
@@ -74,7 +74,7 @@ extra_require = {
     "modelscope": ["modelscope"],
     "openmind": ["openmind"],
     "swanlab": ["swanlab"],
-    "dev": ["pre-commit", "ruff", "pytest"],
+    "dev": ["pre-commit", "ruff", "pytest", "build"],
 }
...
@@ -19,10 +19,10 @@ Level:

 Dependency graph:
   main:
-    transformers>=4.41.2,<=4.51.0,!=4.46.*,!=4.47.*,!=4.48.0
-    datasets>=2.16.0,<=3.4.1
-    accelerate>=0.34.0,<=1.5.2
-    peft>=0.14.0,<=0.15.0
+    transformers>=4.41.2,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0
+    datasets>=2.16.0,<=3.5.0
+    accelerate>=0.34.0,<=1.6.0
+    peft>=0.14.0,<=0.15.1
     trl>=0.8.6,<=0.9.6
   attention:
     transformers>=4.42.4 (gemma+fa2)
...
@@ -87,7 +87,8 @@ def _process_request(
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")

     if request.messages[0].role == Role.SYSTEM:
-        system = request.messages.pop(0).content
+        content = request.messages.pop(0).content
+        system = content[0].text if isinstance(content, list) else content
     else:
         system = None
@@ -128,7 +129,9 @@ def _process_request(
             elif input_item.type == "video_url":
                 text_content += VIDEO_PLACEHOLDER
                 video_url = input_item.video_url.url
-                if os.path.isfile(video_url):  # local file
+                if re.match(r"^data:video\/(mp4|mkv|avi|mov);base64,(.+)$", video_url):  # base64 video
+                    video_stream = io.BytesIO(base64.b64decode(video_url.split(",", maxsplit=1)[1]))
+                elif os.path.isfile(video_url):  # local file
                     video_stream = open(video_url, "rb")
                 else:  # web uri
                     video_stream = requests.get(video_url, stream=True).raw
@@ -137,7 +140,9 @@ def _process_request(
             elif input_item.type == "audio_url":
                 text_content += AUDIO_PLACEHOLDER
                 audio_url = input_item.audio_url.url
-                if os.path.isfile(audio_url):  # local file
+                if re.match(r"^data:audio\/(mpeg|mp3|wav|ogg);base64,(.+)$", audio_url):  # base64 audio
+                    audio_stream = io.BytesIO(base64.b64decode(audio_url.split(",", maxsplit=1)[1]))
+                elif os.path.isfile(audio_url):  # local file
                     audio_stream = open(audio_url, "rb")
                 else:  # web uri
                     audio_stream = requests.get(audio_url, stream=True).raw
...
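The new branches above follow the standard data-URI pattern. A self-contained sketch of the video case, with the surrounding request objects stripped away (the function name is illustrative, the regex and variables mirror the diff), looks like this:

import base64
import io
import os
import re

import requests


def open_video_stream(video_url: str):
    """Return a readable byte stream for a base64 data URI, a local file, or a web URI."""
    if re.match(r"^data:video\/(mp4|mkv|avi|mov);base64,(.+)$", video_url):  # base64 video
        return io.BytesIO(base64.b64decode(video_url.split(",", maxsplit=1)[1]))
    elif os.path.isfile(video_url):  # local file
        return open(video_url, "rb")
    else:  # web uri
        return requests.get(video_url, stream=True).raw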
@@ -15,6 +15,7 @@
 import os
 import subprocess
 import sys
+from copy import deepcopy
 from enum import Enum, unique

 from . import launcher
@@ -96,6 +97,13 @@ def main():
         if int(nnodes) > 1:
             print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")

+        env = deepcopy(os.environ)
+        if is_env_enabled("OPTIM_TORCH", "1"):
+            # optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
+            env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+            env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+        # NOTE: DO NOT USE shell=True to avoid security risk
         process = subprocess.run(
             (
                 "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
@@ -110,7 +118,9 @@ def main():
                 file_name=launcher.__file__,
                 args=" ".join(sys.argv[1:]),
             )
-            .split()
+            .split(),
+            env=env,
+            check=True,
         )
         sys.exit(process.returncode)
     else:
...
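The launcher change can be reproduced in isolation roughly as follows; is_env_enabled is replaced here by a simplified local stand-in, and the torchrun command is a placeholder rather than the project's real command template.

import os
import subprocess
import sys
from copy import deepcopy


def is_env_enabled(name: str, default: str = "0") -> bool:
    # Simplified stand-in for the project's env-flag helper.
    return os.getenv(name, default).lower() in ("true", "y", "1")


env = deepcopy(os.environ)
if is_env_enabled("OPTIM_TORCH", "1"):
    # optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
    env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

# NOTE: DO NOT USE shell=True to avoid security risk
process = subprocess.run(
    "torchrun --nnodes 1 --node_rank 0 --nproc_per_node 1 train.py".split(),  # placeholder command
    env=env,
    check=True,
)
sys.exit(process.returncode)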