Commit 1bfbcff0 authored by wanglch

Initial commit
import os
import shutil
import subprocess
import time
from collections import deque
from copy import deepcopy
from dataclasses import asdict, dataclass, field
from queue import Queue
from typing import Any, Dict, List
import json
import torch
from swift.llm import ExportArguments
from swift.utils import get_logger
from swift.utils.torch_utils import _find_free_port
logger = get_logger()
@dataclass
class Experiment:
name: str
cmd: str
group: str
requirements: Dict = field(default_factory=dict)
eval_requirements: Dict = field(default_factory=dict)
eval_dataset: List = field(default_factory=list)
args: Dict = field(default_factory=dict)
env: Dict = field(default_factory=dict)
record: Dict = field(default_factory=dict)
create_time: float = None
runtime: Dict = field(default_factory=dict)
input_args: Any = None
do_eval = False
def __init__(self,
name,
cmd,
group,
requirements=None,
eval_requirements=None,
eval_dataset=None,
args=None,
input_args=None,
**kwargs):
self.name = name
self.cmd = cmd
self.group = group
self.requirements = requirements or {}
self.args = args or {}
self.record = {}
self.env = {}
self.runtime = {}
self.input_args = input_args
self.eval_requirements = eval_requirements or {}
self.eval_dataset = eval_dataset or []
if self.cmd == 'eval':
self.do_eval = True
def load(self, _json):
self.name = _json['name']
self.cmd = _json['cmd']
self.requirements = _json['requirements']
self.args = _json['args']
self.record = _json['record']
self.env = _json['env']
self.create_time = _json['create_time']
@property
def priority(self):
return self.requirements.get('gpu', 0)
def to_dict(self):
_dict = asdict(self)
_dict.pop('runtime')
_dict.pop('input_args')
return _dict
class ExpManager:
RESULT_FILE = 'result.jsonl'
def __init__(self):
self.exps = []
def assert_gpu_not_overlap(self):
all_gpus = set()
for exp in self.exps:
gpus = exp.runtime['env']['CUDA_VISIBLE_DEVICES'].split(',')
if all_gpus & set(gpus):
raise ValueError(f'GPU overlap: {self.exps}!')
all_gpus.update(gpus)
def run(self, exp: Experiment):
if os.path.exists(os.path.join(exp.input_args.save_dir, exp.name + '.json')):
with open(os.path.join(exp.input_args.save_dir, exp.name + '.json'), 'r') as f:
_json = json.load(f)
if exp.eval_dataset and 'eval_result' not in _json['record']:
if not exp.do_eval:
logger.info(f'Experiment {exp.name} needs eval, loading from file.')
exp.load(_json)
exp.do_eval = True
else:
logger.warning(f'Experiment {exp.name} already done, skipping')
return
if exp.do_eval:
runtime = self._build_eval_cmd(exp)
exp.runtime = runtime
envs = deepcopy(runtime.get('env', {}))
envs.update(os.environ)
logger.info(f'Running cmd: {runtime["running_cmd"]}, env: {runtime.get("env", {})}')
os.makedirs('exp', exist_ok=True)
log_file = os.path.join('exp', f'{exp.name}.eval.log')
exp.handler = subprocess.Popen(runtime['running_cmd'] + f' > {log_file} 2>&1', env=envs, shell=True)
self.exps.append(exp)
self.assert_gpu_not_overlap()
return
if any([exp.name == e.name for e in self.exps]):
raise ValueError(f'Duplicate experiment name: {exp.name}')
elif exp.cmd == 'export' and any([exp.cmd == 'export' for exp in self.exps]): # noqa
raise AssertionError('Cannot run export tasks in parallel.')
else:
exp.create_time = time.time()
runtime = self._build_cmd(exp)
exp.runtime = runtime
envs = deepcopy(runtime.get('env', {}))
envs.update(os.environ)
logger.info(f'Running cmd: {runtime["running_cmd"]}, env: {runtime.get("env", {})}')
os.makedirs('exp', exist_ok=True)
log_file = os.path.join('exp', f'{exp.name}.{exp.cmd}.log')
exp.handler = subprocess.Popen(runtime['running_cmd'] + f' > {log_file} 2>&1', env=envs, shell=True)
self.exps.append(exp)
self.assert_gpu_not_overlap()
def _build_eval_cmd(self, exp: Experiment):
gpu = exp.eval_requirements.get('gpu', None)
env = {}
allocated = []
if gpu:
allocated = self._find_free_gpu(int(gpu))
assert allocated, 'No free gpu for now!'
allocated = [str(gpu) for gpu in allocated]
env['CUDA_VISIBLE_DEVICES'] = ','.join(allocated)
best_model_checkpoint = exp.record.get('best_model_checkpoint')
eval_dataset = exp.eval_dataset
if best_model_checkpoint is not None:
model_type_kwargs = ''
if not os.path.exists(os.path.join(best_model_checkpoint, 'sft_args.json')):
model_type = best_model_checkpoint[best_model_checkpoint.rfind(os.path.sep) + 1:]
model_type = '-'.join(model_type.split('-')[:-2])
model_type_kwargs = f'--model_type {model_type}'
cmd = f'swift eval {model_type_kwargs} --ckpt_dir {best_model_checkpoint} ' \
+ f'--infer_backend pt --sft_type full --name {exp.name} --eval_dataset {" ".join(eval_dataset)}'
else:
assert exp.args.get('model_type') is not None
cmd = f'swift eval --model_type {exp.args.get("model_type")} --infer_backend pt ' \
f'--name {exp.name} --eval_dataset {" ".join(eval_dataset)}'
return {
'running_cmd': cmd,
'gpu': allocated,
'env': env,
}
def _build_cmd(self, exp: Experiment):
gpu = exp.requirements.get('gpu', None)
env = {}
allocated = []
if gpu:
allocated = self._find_free_gpu(int(gpu))
assert allocated, 'No free gpu for now!'
allocated = [str(gpu) for gpu in allocated]
env['CUDA_VISIBLE_DEVICES'] = ','.join(allocated)
if int(exp.requirements.get('ddp', 1)) > 1:
env['NPROC_PER_NODE'] = str(exp.requirements.get('ddp'))
env['MASTER_PORT'] = str(_find_free_port())
if exp.cmd == 'sft':
from swift.llm import SftArguments
args = exp.args
sft_args = SftArguments(**args)
args['output_dir'] = sft_args.output_dir
args['logging_dir'] = sft_args.logging_dir
args['add_output_dir_suffix'] = False
os.makedirs(sft_args.output_dir, exist_ok=True)
os.makedirs(sft_args.logging_dir, exist_ok=True)
cmd = 'swift sft '
for key, value in args.items():
cmd += f' --{key} {value}'
elif exp.cmd == 'dpo':
from swift.llm import DPOArguments
args = exp.args
dpo_args = DPOArguments(**args)
args['output_dir'] = dpo_args.output_dir
args['logging_dir'] = dpo_args.logging_dir
args['add_output_dir_suffix'] = False
os.makedirs(dpo_args.output_dir, exist_ok=True)
os.makedirs(dpo_args.logging_dir, exist_ok=True)
cmd = 'swift dpo '
for key, value in args.items():
cmd += f' --{key} {value}'
elif exp.cmd == 'export':
args = exp.args
cmd = 'swift export '
for key, value in args.items():
cmd += f' --{key} {value}'
else:
raise ValueError(f'Unsupported cmd type: {exp.cmd}')
return {
'running_cmd': cmd,
'gpu': allocated,
'env': env,
'logging_dir': args.get('logging_dir'),
'output_dir': args.get('output_dir', args.get('ckpt_dir'))
}
def _find_free_gpu(self, n):
all_gpus = set()
for exp in self.exps:
all_gpus.update(exp.runtime.get('gpu', set()))
all_gpus = {int(g) for g in all_gpus}
free_gpu = set(range(torch.cuda.device_count())) - all_gpus
if len(free_gpu) < n:
return None
return list(free_gpu)[:n]
def prepare_experiments(self, args: Any):
experiments = []
for config_file in args.config:
with open(config_file, 'r') as f:
group = os.path.basename(config_file)
group = group[:-5]
content = json.load(f)
exps = content['experiment']
for exp in exps:
main_cfg = deepcopy(content)
name = exp['name']
cmd = main_cfg['cmd']
run_args = main_cfg['args']
env = main_cfg.get('env', {})
requirements = main_cfg.get('requirements', {})
eval_requirements = main_cfg.get('eval_requirements', {})
eval_dataset = main_cfg.get('eval_dataset', {})
if 'args' in exp:
run_args.update(exp['args'])
if 'requirements' in exp:
requirements.update(exp['requirements'])
if 'env' in exp:
env.update(exp['env'])
experiments.append(
Experiment(
group=group,
name=name,
cmd=cmd,
args=run_args,
env=env,
requirements=requirements,
eval_requirements=eval_requirements,
eval_dataset=eval_dataset,
input_args=args))
return experiments
@staticmethod
def _get_metric(exp: Experiment):
if exp.do_eval:
if os.path.isfile(os.path.join('exp', f'{exp.name}.eval.log')):
with open(os.path.join('exp', f'{exp.name}.eval.log'), 'r') as f:
for line in f.readlines():
if 'Final report:' in line:
return json.loads(line.split('Final report:')[1].replace('\'', '"'))
elif exp.cmd == 'export':
exp_args = ExportArguments(**exp.args)
if exp_args.quant_bits > 0:
if exp_args.ckpt_dir is None:
path = f'{exp_args.model_type}-{exp_args.quant_method}-int{exp_args.quant_bits}'
else:
ckpt_dir, ckpt_name = os.path.split(exp_args.ckpt_dir)
path = os.path.join(ckpt_dir, f'{ckpt_name}-{exp_args.quant_method}-int{exp_args.quant_bits}')
else:
ckpt_dir, ckpt_name = os.path.split(exp_args.ckpt_dir)
path = os.path.join(ckpt_dir, f'{ckpt_name}-merged')
if os.path.exists(path):
shutil.rmtree(exp.name, ignore_errors=True)
os.makedirs(exp.name, exist_ok=True)
shutil.move(path, os.path.join(exp.name, path))
return {
'best_model_checkpoint': os.path.join(exp.name, path),
}
else:
logging_dir = exp.runtime.get('logging_dir')
logging_file = os.path.join(logging_dir, '..', 'logging.jsonl')
if os.path.isfile(logging_file):
with open(logging_file, 'r') as f:
for line in f.readlines():
if 'model_info' in line:
return json.loads(line)
return None
@staticmethod
def write_record(exp: Experiment):
target_dir = exp.input_args.save_dir
file = os.path.join(target_dir, exp.name + '.json')
with open(file, 'w', encoding='utf-8') as f:
f.write(json.dumps(exp.to_dict()) + '\n')
def _poll(self):
while True:
time.sleep(5)
has_finished = False
for exp in self.exps:
rt = exp.handler.poll()
if rt is None:
continue
has_finished = True
if rt == 0:
if not exp.do_eval:
all_metric = self._get_metric(exp)
if all_metric:
exp.record.update(all_metric)
if exp.eval_dataset:
exp.do_eval = True
self.exp_queue.appendleft(exp)
self.write_record(exp)
else:
logger.error(f'Running {exp.name} task, but no result found')
else:
all_metric = self._get_metric(exp)
exp.record['eval_result'] = all_metric
if all_metric:
self.write_record(exp)
else:
logger.error(f'Running {exp.name} eval task, but no eval result found')
logger.info(f'Running {exp.name} finished with return code: {rt}')
if has_finished:
self.exps = [exp for exp in self.exps if exp.handler.poll() is None]
break
def begin(self, args: Any):
exps = self.prepare_experiments(args)
logger.info(f'all exps: {exps}')
exps.sort(key=lambda e: e.priority)
self.exp_queue = deque()
for exp in exps:
self.exp_queue.append(exp)
while len(self.exp_queue) or len(self.exps) > 0:
while len(self.exp_queue):
try:
logger.info(f'Running exp: {self.exp_queue[0].name}')
self.run(self.exp_queue[0])
except Exception as e:
if not isinstance(e, AssertionError):
logger.error(f'Adding exp {self.exp_queue[0].name} error because of:')
logger.error(e)
self.exp_queue.popleft()
else:
logger.info(f'Adding exp {self.exp_queue[0].name} error because of: {e}')
if 'no free gpu' in str(e).lower():
break
else:
continue
else:
self.exp_queue.popleft()
self._poll()
logger.info(f'All tasks finished; exp queue: {self.exp_queue}, running exps: {self.exps}')
def find_all_config(dir_or_file: str):
if os.path.isfile(dir_or_file):
return [dir_or_file]
else:
configs = []
for dirpath, dirnames, filenames in os.walk(dir_or_file):
for name in filenames:
if name.endswith('.json') and 'ipynb' not in dirpath:
configs.append(os.path.join(dirpath, name))
return configs
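# Illustrative only: a minimal sketch of the JSON config that prepare_experiments()
# above expects. Top-level keys ('cmd', 'args', 'env', 'requirements',
# 'eval_requirements', 'eval_dataset') provide defaults, and each entry under
# 'experiment' overrides them per run. All model/dataset/GPU values below are
# hypothetical placeholders, not values taken from this repository.
EXAMPLE_CONFIG = {
    'cmd': 'sft',
    'requirements': {'gpu': '1', 'ddp': '1'},
    'eval_requirements': {'gpu': '1'},
    'eval_dataset': ['ceval', 'gsm8k', 'arc'],
    'args': {'model_type': 'qwen-7b-chat', 'dataset': 'ms-bench'},
    'env': {},
    'experiment': [
        {'name': 'qwen-7b-chat-lora', 'args': {'sft_type': 'lora'}},
        {'name': 'qwen-7b-chat-full', 'args': {'sft_type': 'full'},
         'requirements': {'gpu': '2', 'ddp': '2'}},
    ],
}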
# Copyright (c) Alibaba, Inc. and its affiliates.
import dataclasses
import os
from dataclasses import dataclass
from typing import Any, Dict, List
import json
import numpy as np
from swift.utils.utils import split_str_parts_by
@dataclass
class ModelOutput:
group: str = None
name: str = None
cmd: str = None
requirements: Dict[str, str] = dataclasses.field(default_factory=dict)
args: Dict[str, Any] = dataclasses.field(default_factory=dict)
memory: str = None
train_time: float = None
train_samples: int = None
train_samples_per_second: float = None
last_model_checkpoint: str = None
best_model_checkpoint: str = None
best_metric: Any = None
global_step: int = None
num_total_parameters: float = None
num_trainable_parameters: float = None
num_buffers: float = None
trainable_parameters_percentage: float = None
train_dataset_info: str = None
val_dataset_info: str = None
train_create_time: float = None
eval_tokens: int = None
eval_time: float = None
reports: Dict[str, Any] = None
train_loss: float = None
@property
def tuner_hyper_params(self):
hyper_params = ''
args = self.args
if 'sft_type' not in args:
return ''
if args['sft_type'] in ('lora', 'adalora', 'longlora'):
if 'lora_rank' in args:
hyper_params += f'rank={args["lora_rank"]}/' \
f'target={args["lora_target_modules"]}/' \
f'alpha={args["lora_alpha"]}/' \
f'lr_ratio={args.get("lora_lr_ratio", None)}/' \
f'use_rslora={args.get("use_rslora", False)}/' \
f'use_dora={args.get("use_dora", False)}'
else:
hyper_params = ''
if args['sft_type'] == 'full':
if 'use_galore' in args and args['use_galore'] == 'true':
hyper_params += f'galore_rank={args["galore_rank"]}/' \
f'galore_per_parameter={args["galore_optim_per_parameter"]}/' \
f'galore_with_embedding={args["galore_with_embedding"]}/'
if args['sft_type'] == 'llamapro':
hyper_params += f'num_blocks={args["llamapro_num_new_blocks"]}/'
if 'neftune_noise_alpha' in args and args['neftune_noise_alpha']:
hyper_params += f'neftune_noise_alpha={args["neftune_noise_alpha"]}/'
if hyper_params.endswith('/'):
hyper_params = hyper_params[:-1]
return hyper_params
@property
def hyper_parameters(self):
if 'learning_rate' not in self.args:
return ''
return f'lr={self.args["learning_rate"]}/' \
f'epoch={self.args["num_train_epochs"]}'
@property
def train_speed(self):
if self.train_samples_per_second:
return f'{self.train_samples_per_second:.2f}({self.train_samples} samples/{self.train_time:.2f} seconds)'
else:
return ''
@property
def infer_speed(self):
if self.eval_tokens:
return f'{self.eval_tokens / self.eval_time:.2f}({self.eval_tokens} tokens/{self.eval_time:.2f} seconds)'
return ''
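# Illustrative only (hypothetical numbers): with args such as sft_type='lora',
# lora_rank=8, lora_target_modules='ALL', lora_alpha=32, the properties above
# render roughly as
#   tuner_hyper_params -> 'rank=8/target=ALL/alpha=32/lr_ratio=None/use_rslora=False/use_dora=False'
#   hyper_parameters   -> 'lr=0.0001/epoch=1'
#   train_speed        -> '12.34(1000 samples/81.03 seconds)'
#   infer_speed        -> '45.67(4567 tokens/100.00 seconds)'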
def generate_sft_report(outputs: List[ModelOutput]):
gsm8k_accs = []
arc_accs = []
ceval_accs = []
for output in outputs:
gsm8k_acc = None
arc_acc = None
ceval_acc = None
for report in (output.reports or []):
if report['name'] == 'gsm8k':
gsm8k_acc = report['score']
if report['name'] == 'arc':
arc_acc = report['score']
if report['name'] == 'ceval':
ceval_acc = report['score']
gsm8k_accs.append(gsm8k_acc)
arc_accs.append(arc_acc)
ceval_accs.append(ceval_acc)
tab = '| exp_name | model_type | dataset | ms-bench mix ratio | tuner | tuner_params | trainable params(M) | flash_attn | gradient_checkpointing | hypers | memory | train speed(samples/s) | infer speed(tokens/s) | train_loss | eval_loss | gsm8k weighted acc | arc weighted acc | ceval weighted acc |\n' \
'| -------- | ---------- | ------- | -------------------| ----- | ------------ | ------------------- | -----------| ---------------------- | ------ | ------ | ---------------------- | --------------------- | ---------- | --------- | ------------------ | ---------------- | ------------------ |\n' # noqa
min_best_metric = 999.
min_train_loss = 999.
if outputs:
min_best_metric = min([output.best_metric or 999. for output in outputs])
min_train_loss = min([output.train_loss or 999. for output in outputs])
max_gsm8k = 0.0
if gsm8k_accs:
max_gsm8k = max([gsm8k or 0. for gsm8k in gsm8k_accs])
max_arc = 0.0
if arc_accs:
max_arc = max([arc or 0. for arc in arc_accs])
max_ceval = 0.0
if ceval_accs:
max_ceval = max([ceval or 0. for ceval in ceval_accs])
for output, gsm8k_acc, arc_acc, ceval_acc in zip(outputs, gsm8k_accs, arc_accs, ceval_accs):
use_flash_attn = output.args.get('use_flash_attn', '')
use_gc = output.args.get('gradient_checkpointing', '')
memory = output.memory
train_speed = output.train_speed
infer_speed = output.infer_speed
is_best_metric = np.isclose(min_best_metric, output.best_metric or 999.0)
is_best_loss = np.isclose(min_train_loss, output.train_loss or 999.0)
is_best_gsm8k = np.isclose(max_gsm8k, gsm8k_acc or 0.0)
is_best_arc = np.isclose(max_arc, arc_acc or 0.0)
is_best_ceval = np.isclose(max_ceval, ceval_acc or 0.0)
if not is_best_metric:
best_metric = '' if not output.best_metric else f'{output.best_metric:.2f}'
else:
best_metric = '' if not output.best_metric else f'**{output.best_metric:.2f}**'
if not is_best_loss:
train_loss = '' if not output.train_loss else f'{output.train_loss:.2f}'
else:
train_loss = '' if not output.train_loss else f'**{output.train_loss:.2f}**'
if not is_best_gsm8k:
gsm8k_acc = '' if not gsm8k_acc else f'{gsm8k_acc:.3f}'
else:
gsm8k_acc = '' if not gsm8k_acc else f'**{gsm8k_acc:.3f}**'
if not is_best_arc:
arc_acc = '' if not arc_acc else f'{arc_acc:.3f}'
else:
arc_acc = '' if not arc_acc else f'**{arc_acc:.3f}**'
if not is_best_ceval:
ceval_acc = '' if not ceval_acc else f'{ceval_acc:.3f}'
else:
ceval_acc = '' if not ceval_acc else f'**{ceval_acc:.3f}**'
line = f'|{output.name}|' \
f'{output.args["model_type"]}|' \
f'{output.args.get("dataset")}|' \
f'{output.args.get("train_dataset_mix_ratio", 0.)}|' \
f'{output.args.get("sft_type")}|' \
f'{output.tuner_hyper_params}|' \
f'{output.num_trainable_parameters}({output.trainable_parameters_percentage})|' \
f'{use_flash_attn}|' \
f'{use_gc}|' \
f'{output.hyper_parameters}|' \
f'{memory}|' \
f'{train_speed}|' \
f'{infer_speed}|' \
f'{best_metric}|' \
f'{train_loss}|' \
f'{gsm8k_acc}|' \
f'{arc_acc}|' \
f'{ceval_acc}|\n'
tab += line
return tab
def generate_export_report(outputs: List[ModelOutput]):
tab = '| exp_name | model_type | calibration dataset | quantization method | quantization bits | infer speed(tokens/s) | gsm8k weighted acc | arc weighted acc | ceval weighted acc |\n' \
'| -------- | ---------- | ------------------- | ------------------- | ----------------- | --------------------- | ------------------ | ---------------- | ------------------ |\n' # noqa
gsm8k_accs = []
arc_accs = []
ceval_accs = []
for output in outputs:
gsm8k_acc = None
arc_acc = None
ceval_acc = None
for report in (output.reports or []):
if report['name'] == 'gsm8k':
gsm8k_acc = report['score']
if report['name'] == 'arc':
arc_acc = report['score']
if report['name'] == 'ceval':
ceval_acc = report['score']
gsm8k_accs.append(gsm8k_acc)
arc_accs.append(arc_acc)
ceval_accs.append(ceval_acc)
max_gsm8k = 0.0
if gsm8k_accs:
max_gsm8k = max([gsm8k or 0. for gsm8k in gsm8k_accs])
max_arc = 0.0
if arc_accs:
max_arc = max([arc or 0. for arc in arc_accs])
max_ceval = 0.0
if ceval_accs:
max_ceval = max([ceval or 0. for ceval in ceval_accs])
for output, gsm8k_acc, arc_acc, ceval_acc in zip(outputs, gsm8k_accs, arc_accs, ceval_accs):
infer_speed = output.infer_speed
is_best_gsm8k = np.isclose(max_gsm8k, gsm8k_acc or 0.0)
is_best_arc = np.isclose(max_arc, arc_acc or 0.0)
is_best_ceval = np.isclose(max_ceval, ceval_acc or 0.0)
if not is_best_gsm8k:
gsm8k_acc = '' if not gsm8k_acc else f'{gsm8k_acc:.3f}'
else:
gsm8k_acc = '' if not gsm8k_acc else f'**{gsm8k_acc:.3f}**'
if not is_best_arc:
arc_acc = '' if not arc_acc else f'{arc_acc:.3f}'
else:
arc_acc = '' if not arc_acc else f'**{arc_acc:.3f}**'
if not is_best_ceval:
ceval_acc = '' if not ceval_acc else f'{ceval_acc:.3f}'
else:
ceval_acc = '' if not ceval_acc else f'**{ceval_acc:.3f}**'
if output.train_dataset_info:
dataset_info = f'{output.args["dataset"]}/{output.train_dataset_info}'
else:
dataset_info = f'{output.args["dataset"]}'
line = f'|{output.name}|' \
f'{output.args["model_type"]}|' \
f'{dataset_info}|' \
f'{output.args["quant_method"]}|' \
f'{output.args["quant_bits"]}|' \
f'{infer_speed}|' \
f'{gsm8k_acc}|' \
f'{arc_acc}|' \
f'{ceval_acc}|\n'
tab += line
return tab
def parse_output(file):
with open(file, 'r') as f:
content = json.load(f)
name = content['name']
group = content['group']
cmd = content['cmd']
requirements = content['requirements']
args = content['args']
create_time = float(content.get('create_time') or 0)
content = content['record']
if cmd == 'export':
best_model_checkpoint = content['best_model_checkpoint']
eval_tokens = 0
eval_time = 0.0
eval_result = None
if 'eval_result' in content:
eval_result = content['eval_result']
eval_tokens = eval_result['generation_info']['tokens']
eval_time = eval_result['generation_info']['time']
eval_result = eval_result['report']
return ModelOutput(
group=group,
name=name,
cmd=cmd,
requirements=requirements,
args=args,
best_model_checkpoint=best_model_checkpoint,
eval_time=eval_time,
eval_tokens=eval_tokens,
reports=eval_result,
)
else:
memory = None
train_time = None
train_samples = None
train_samples_per_second = None
last_model_checkpoint = None
best_model_checkpoint = None
best_metric = None
global_step = None
train_dataset_info = None
val_dataset_info = None
num_trainable_parameters = None
num_buffers = None
trainable_parameters_percentage = None
num_total_parameters = None
train_loss = None
if 'memory' in content:
memory = content['memory']
memory = '/'.join(memory.values())
if 'train_time' in content:
train_time = content['train_time']['train_runtime']
train_samples = content['train_time']['n_train_samples']
train_samples_per_second = content['train_time']['train_samples_per_second']
if 'last_model_checkpoint' in content:
last_model_checkpoint = content['last_model_checkpoint']
if 'best_model_checkpoint' in content:
best_model_checkpoint = content['best_model_checkpoint']
if 'best_metric' in content:
best_metric = content['best_metric']
if 'log_history' in content:
train_loss = content['log_history'][-1]['train_loss']
if 'global_step' in content:
global_step = content['global_step']
if 'dataset_info' in content:
train_dataset_info = content['dataset_info'].get('train_dataset')
val_dataset_info = content['dataset_info'].get('val_dataset')
if 'model_info' in content:
# model_info like: SwiftModel: 6758.4041M Params (19.9885M Trainable [0.2958%]), 16.7793M Buffers.
str_dict = split_str_parts_by(content['model_info'], [
'SwiftModel:', 'CausalLM:', 'Seq2SeqLM:', 'LMHeadModel:', 'M Params (', 'M Trainable [', ']), ',
'M Buffers.'
])
str_dict = {c['key']: c['content'] for c in str_dict}
if 'SwiftModel:' in str_dict:
num_total_parameters = float(str_dict['SwiftModel:'])
elif 'CausalLM:' in str_dict:
num_total_parameters = float(str_dict['CausalLM:'])
elif 'Seq2SeqLM:' in str_dict:
num_total_parameters = float(str_dict['Seq2SeqLM:'])
elif 'LMHeadModel:' in str_dict:
num_total_parameters = float(str_dict['LMHeadModel:'])
num_trainable_parameters = float(str_dict['M Params ('])
num_buffers = float(str_dict[']), '])
trainable_parameters_percentage = str_dict['M Trainable [']
eval_tokens = 0
eval_time = 0.0
eval_result = None
if 'eval_result' in content:
eval_result = content['eval_result']
eval_tokens = eval_result['generation_info']['tokens']
eval_time = eval_result['generation_info']['time']
eval_result = eval_result['report']
return ModelOutput(
group=group,
name=name,
cmd=cmd,
requirements=requirements,
args=args,
memory=memory,
train_time=train_time,
train_samples=train_samples,
train_samples_per_second=train_samples_per_second,
last_model_checkpoint=last_model_checkpoint,
best_model_checkpoint=best_model_checkpoint,
best_metric=best_metric,
global_step=global_step,
train_dataset_info=train_dataset_info,
val_dataset_info=val_dataset_info,
train_create_time=create_time,
num_total_parameters=num_total_parameters,
num_trainable_parameters=num_trainable_parameters,
num_buffers=num_buffers,
trainable_parameters_percentage=trainable_parameters_percentage,
eval_time=eval_time,
eval_tokens=eval_tokens,
reports=eval_result,
train_loss=train_loss,
)
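# Illustrative only: a minimal sketch of the per-experiment record file that
# write_record() saves and parse_output() above reads back. Every value below is a
# hypothetical placeholder; real files carry whatever the training run produced, and
# fields not shown here are simply absent or empty.
EXAMPLE_RECORD = {
    'name': 'qwen-7b-chat-lora',
    'group': 'tuner',
    'cmd': 'sft',
    'requirements': {'gpu': '1'},
    'args': {'model_type': 'qwen-7b-chat', 'dataset': 'ms-bench', 'sft_type': 'lora'},
    'create_time': 1700000000.0,
    'record': {
        'memory': {'cuda:0': '20.5GiB'},
        'train_time': {'train_runtime': 3600.0, 'n_train_samples': 10000, 'train_samples_per_second': 2.78},
        'best_model_checkpoint': 'output/qwen-7b-chat/vx-xxx/checkpoint-100',
        'best_metric': 1.23,
        'global_step': 100,
        'log_history': [{'train_loss': 1.30}],
        'dataset_info': {'train_dataset': '10000 samples', 'val_dataset': None},
        'model_info': 'SwiftModel: 6758.4041M Params (19.9885M Trainable [0.2958%]), 16.7793M Buffers.',
        'eval_result': {'generation_info': {'tokens': 4567, 'time': 100.0},
                        'report': [{'name': 'gsm8k', 'score': 0.32}]},
    },
}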
def generate_reports():
outputs = []
for dirs, _, files in os.walk('./experiment'):
for file in files:
abs_file = os.path.join(dirs, file)
if not abs_file.endswith('.json') or 'ipynb' in abs_file:
continue
outputs.append(parse_output(abs_file))
all_groups = set([output.group for output in outputs])
for group in all_groups:
group_outputs = [output for output in outputs if output.group == group]
print(f'=================Printing the sft cmd result of exp {group}==================\n\n')
print(generate_sft_report([output for output in group_outputs if output.cmd in ('sft', 'eval')]))
# print(f'=================Printing the dpo result of exp {group}==================')
# print(generate_dpo_report([output for output in outputs if output.cmd == 'dpo']))
print(f'=================Printing the export cmd result of exp {group}==================\n\n')
print(generate_export_report([output for output in group_outputs if output.cmd == 'export']))
print('=================Printing done==================\n\n')
if __name__ == '__main__':
generate_reports()
# CUDA_VISIBLE_DEVICES=0 nohup python scripts/benchmark/test_memory_time/run_loop.py &> 0.out &
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import subprocess
from typing import List
from swift.utils import read_from_jsonl, write_to_jsonl
def test_memory_time_loop(train_kwargs_jsonl: str) -> None:
while True:
obj_list = read_from_jsonl(train_kwargs_jsonl)
if len(obj_list[0]) == 0:
break
obj: List[str] = obj_list.pop(0)
obj_list.append(obj)
write_to_jsonl(train_kwargs_jsonl, obj_list)
ret = subprocess.run(['python', 'scripts/benchmark/test_memory_time/run_single.py', *obj])
assert ret.returncode == 0
if __name__ == '__main__':
jsonl_path = os.path.join('scripts/benchmark/test_memory_time/run.jsonl')
test_memory_time_loop(jsonl_path)
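# Illustrative only: the shape of run.jsonl assumed above (the flags are hypothetical).
# Each line is a JSON list of CLI arguments handed to run_single.py; the loop rotates
# the head entry to the tail after each run and stops once the head entry is an empty list:
#   ["--model_type", "qwen-7b-chat", "--max_length", "2048"]
#   ["--model_type", "qwen-7b-chat", "--use_flash_attn", "false"]
#   []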
import time
from dataclasses import dataclass, field
from typing import *
import numpy as np
import torch
from swift.llm import sft_main
from swift.llm.utils import *
from swift.utils import *
@dataclass
class TrainArguments(SftArguments):
run_time: int = 1
global_seed: int = 42
def __post_init__(self):
if self.model_type is None:
self.model_type = 'qwen-7b-chat'
if self.use_flash_attn is None:
self.use_flash_attn = True
return
def get_non_default_args(train_args) -> Dict[str, Any]:
train_args_default = train_args.__class__()
res = {}
for k, v in train_args.__dict__.items():
v_default = getattr(train_args_default, k)
if v != v_default or k in {'use_flash_attn', 'model_type'}:
res[k] = v
return res
def test_memory_time(train_args: TrainArguments) -> Dict[str, Dict[str, Any]]:
random_state = np.random.RandomState(train_args.global_seed)
args_kwargs = get_non_default_args(train_args)
print(f'args_kwargs: {args_kwargs}')
train_dataset_sample = 1000 # save time
if args_kwargs.get('max_length', 2048) <= 2048:
train_dataset_sample = -1
for i in range(train_args.run_time):
sft_args = SftArguments(
dataset_test_ratio=0,
dataset=DatasetName.cls_fudan_news_zh,
train_dataset_sample=train_dataset_sample,
save_strategy='no',
check_dataset_strategy='warning',
seed=get_seed(random_state),
**args_kwargs)
output = sft_main(sft_args)
torch.cuda.empty_cache()
res = {
'samples/s': f"{output['train_time']['train_samples_per_second']:.2f}",
'memory': output['memory'],
'train_args': check_json_format(args_kwargs),
'model_info': output['model_info'],
'dataset_info': output['dataset_info']
}
append_to_jsonl('scripts/benchmark/test_memory_time/result.jsonl', res)
print(res)
return res
test_memory_time_main = get_main(TrainArguments, test_memory_time)
if __name__ == '__main__':
test_memory_time_main()
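# Hypothetical invocation (flags follow the TrainArguments fields above); the script
# path mirrors the one referenced in run_loop.py and is an assumption here:
#   python scripts/benchmark/test_memory_time/run_single.py --model_type qwen-7b-chat --max_length 4096 --run_time 2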
import os
import re
import torch
from modelscope import snapshot_download
from swift.llm import MODEL_MAPPING
def test_readme():
for model_type in MODEL_MAPPING.keys():
model_id = MODEL_MAPPING[model_type]['model_id_or_path']
model_dir = snapshot_download(model_id, revision='master')
readme_path = os.path.join(model_dir, 'README.md')
assert os.path.exists(readme_path)
with open(readme_path, 'r') as f:
text = f.read()
code_list = re.findall(r'```python\n(.+?)\n```', text, re.M | re.S)
print(f'model_type: {model_type}')
for code in code_list:
if 'import' not in code or 'modelscope' not in code:
continue
try:
exec(code)
except Exception:
print(code)
input('[ENTER]')
torch.cuda.empty_cache()
if __name__ == '__main__':
test_readme()
import os
import subprocess
from swift.llm import ModelType
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
if __name__ == '__main__':
model_name_list = ModelType.get_model_name_list()
success_model_list = []
fpath = os.path.join(os.path.dirname(__file__), 'utils.py')
for model_name in model_name_list:
code = subprocess.run(['python', fpath, '--model_type', model_name])
if code.returncode == 0:
success_model_list.append(model_name)
else:
print(f'model_name: {model_name} does not support vllm.')
print(success_model_list)
from dataclasses import dataclass
from swift.llm import get_default_template_type, get_template, get_vllm_engine, inference_vllm
from swift.utils import get_main
@dataclass
class VLLMTestArgs:
model_type: str
def test_vllm(args: VLLMTestArgs) -> None:
model_type = args.model_type
llm_engine = get_vllm_engine(model_type)
template_type = get_default_template_type(model_type)
template = get_template(template_type, llm_engine.hf_tokenizer)
llm_engine.generation_config.max_new_tokens = 256
request_list = [{'query': '你好!'}, {'query': '浙江的省会在哪?'}]
resp_list = inference_vllm(llm_engine, template, request_list)
for request, resp in zip(request_list, resp_list):
print(f"query: {request['query']}")
print(f"response: {resp['response']}")
test_vllm_main = get_main(VLLMTestArgs, test_vllm)
if __name__ == '__main__':
test_vllm_main()
import os
from swift.utils import plot_images
ckpt_dir = 'output/xxx/vx-xxx'
if __name__ == '__main__':
images_dir = os.path.join(ckpt_dir, 'images')
tb_dir = os.path.join(ckpt_dir, 'runs')
plot_images(images_dir, tb_dir, ['train/loss'], 0.9)
import os
from datasets import concatenate_datasets
from swift.llm import (DATASET_MAPPING, DatasetName, ModelType, dataset_map, get_dataset, get_default_template_type,
get_model_tokenizer, get_template)
from swift.utils import stat_array
def write_dataset_info() -> None:
fpaths = ['docs/source/LLM/支持的模型和数据集.md', 'docs/source_en/LLM/Supported-models-datasets.md']
pre_texts = []
for fpath in fpaths:
if os.path.exists(fpath):
with open(fpath, 'r', encoding='utf-8') as f:
text = f.read()
idx = text.find('| Dataset Name |')
pre_texts.append(text[:idx])
text = text[idx:]
text_list = [t for t in text.split('\n') if len(t.strip()) > 0]
else:
text_list = []
pre_texts.append('')
res_text_list = []
res_text_list.append(
'| Dataset Name | Dataset ID | Subsets | Dataset Size | Statistic (token) | Tags | HF Dataset ID |')
res_text_list.append(
'| ------------ | ---------- | ------- |------------- | ----------------- | ---- | ------------- |')
if len(text_list) >= 2:
text_list = text_list[2:]
else:
text_list = []
ignore_dataset = {text.split('|', 2)[1].lstrip('🔥 '): text for text in text_list}
dataset_name_list = DatasetName.get_dataset_name_list()
mapping = {}
_iter = zip(
['llm', 'vision', 'audio'],
[ModelType.qwen_7b_chat, ModelType.qwen_vl_chat, ModelType.qwen_audio_chat],
)
try:
for task_type, model_type in _iter:
_, tokenizer = get_model_tokenizer(model_type, load_model=False)
template_type = get_default_template_type(model_type)
template = get_template(template_type, tokenizer)
mapping[task_type] = template
for dataset_name in dataset_name_list:
dataset_info = DATASET_MAPPING[dataset_name]
tags = dataset_info.get('tags', [])
subsets = dataset_info.get('subsets', [])
subsets = '<br>'.join(subsets)
if 'audio' in tags:
template = mapping['audio']
elif 'vision' in tags:
template = mapping['vision']
else:
template = mapping['llm']
if dataset_name in ignore_dataset:
dataset_size, stat_str = ignore_dataset[dataset_name].split('|')[4:6]
else:
train_dataset, val_dataset = get_dataset([dataset_name],
model_name=['小黄', 'Xiao Huang'],
model_author=['魔搭', 'ModelScope'])
dataset_size = len(train_dataset)
assert val_dataset is None
raw_dataset = train_dataset
if val_dataset is not None:
raw_dataset = concatenate_datasets([raw_dataset, val_dataset])
if len(raw_dataset) < 5000:
num_proc = 1
else:
num_proc = 4
dataset = dataset_map(raw_dataset, template.encode, num_proc=num_proc)
_token_len = []
input_ids = dataset['input_ids']
for i in range(len(dataset)):
_token_len.append(len(input_ids[i]))
stat = stat_array(_token_len)[0]
stat_str = f"{stat['mean']:.1f}±{stat['std']:.1f}, min={stat['min']}, max={stat['max']}"
ms_url = f"https://modelscope.cn/datasets/{dataset_info['dataset_id_or_path']}/summary"
if '🔥' in tags:
tags.remove('🔥')
dataset_name = '🔥' + dataset_name
tags_str = ', '.join(tags)
if len(tags_str) == 0:
tags_str = '-'
hf_dataset_id = dataset_info.get('hf_dataset_id')
if hf_dataset_id is None:
hf_dataset_id = '-'
hf_dataset_id_str = '-'
else:
hf_url = f'https://huggingface.co/datasets/{hf_dataset_id}'
hf_dataset_id_str = f'[{hf_dataset_id}]({hf_url})'
res_text_list.append(f"|{dataset_name}|[{dataset_info['dataset_id_or_path']}]({ms_url})|{subsets}|"
f'{dataset_size}|{stat_str}|{tags_str}|{hf_dataset_id_str}|')
finally:
print(f'Total number of datasets: {len(dataset_name_list)}')
for idx in range(len(fpaths)):
text = '\n'.join(res_text_list)
text = pre_texts[idx] + text + '\n'
with open(fpaths[idx], 'w', encoding='utf-8') as f:
f.write(text)
if __name__ == '__main__':
write_dataset_info()
from typing import List
from swift.llm import MODEL_MAPPING, ModelType
def get_model_info_table() -> List[str]:
fpaths = ['docs/source/LLM/支持的模型和数据集.md', 'docs/source_en/LLM/Supported-models-datasets.md']
end_words = [['### 多模态大模型', '## 数据集'], ['### MLLM', '## Datasets']]
model_name_list = ModelType.get_model_name_list()
result = [
'| Model Type | Model ID | Default Lora Target Modules | Default Template |'
' Support Flash Attn | Support VLLM | Requires | Tags | HF Model ID |\n'
'| --------- | -------- | --------------------------- | ---------------- |'
' ------------------ | ------------ | -------- | ---- | ----------- |\n'
] * 2
res_llm: List[str] = []
res_mllm: List[str] = []
bool_mapping = {True: '&#x2714;', False: '&#x2718;'}
for model_name in model_name_list:
model_info = MODEL_MAPPING[model_name]
model_id = model_info['model_id_or_path']
lora_target_modules = ', '.join(model_info['lora_target_modules'])
template = model_info['template']
support_flash_attn = model_info.get('support_flash_attn', False)
support_flash_attn = bool_mapping[support_flash_attn]
support_vllm = model_info.get('support_vllm', False)
support_vllm = bool_mapping[support_vllm]
requires = ', '.join(model_info['requires'])
tags = model_info.get('tags', [])
if 'multi-modal' in tags:
tags.remove('multi-modal')
is_multi_modal = True
else:
is_multi_modal = False
tags_str = ', '.join(tags)
if len(tags_str) == 0:
tags_str = '-'
hf_model_id = model_info.get('hf_model_id')
if hf_model_id is None:
hf_model_id = '-'
r = [
model_name, model_id, lora_target_modules, template, support_flash_attn, support_vllm, requires, tags_str,
hf_model_id
]
if is_multi_modal:
res_mllm.append(r)
else:
res_llm.append(r)
print(f'Total LLMs: {len(res_llm)}, total MLLMs: {len(res_mllm)}')
text = ['', ''] # llm, mllm
for i, res in enumerate([res_llm, res_mllm]):
for r in res:
ms_url = f'https://modelscope.cn/models/{r[1]}/summary'
if r[8] != '-':
hf_url = f'https://huggingface.co/{r[8]}'
hf_model_id_str = f'[{r[8]}]({hf_url})'
else:
hf_model_id_str = '-'
text[i] += f'|{r[0]}|[{r[1]}]({ms_url})|{r[2]}|{r[3]}|{r[4]}|{r[5]}|{r[6]}|{r[7]}|{hf_model_id_str}|\n'
result[i] += text[i]
for i, fpath in enumerate(fpaths):
with open(fpath, 'r') as f:
text = f.read()
llm_start_idx = text.find('| Model Type |')
mllm_start_idx = text[llm_start_idx + 1:].find('| Model Type |') + llm_start_idx + 1
llm_end_idx = text.find(end_words[i][0])
mllm_end_idx = text.find(end_words[i][1])
output = text[:llm_start_idx] + result[0] + '\n\n' + text[llm_end_idx:mllm_start_idx] + result[
1] + '\n\n' + text[mllm_end_idx:]
with open(fpath, 'w') as f:
f.write(output)
return result
if __name__ == '__main__':
get_model_info_table()
from swift.llm import TemplateType
if __name__ == '__main__':
template_name_list = TemplateType.get_template_name_list()
tn_gen = ', '.join([tn for tn in template_name_list if 'generation' in tn])
tn_chat = ', '.join([tn for tn in template_name_list if 'generation' not in tn])
print(f'Text Generation: {tn_gen}')
print(f'Chat: {tn_chat}')
[isort]
line_length = 120
multi_line_output = 0
known_standard_library = setuptools
known_first_party = swift
known_third_party = json,yaml
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
[yapf]
BASED_ON_STYLE = pep8
COLUMN_LIMIT = 120
BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true
SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
SPLIT_BEFORE_ARITHMETIC_OPERATOR = true
[codespell]
skip = *.ipynb
quiet-level = 3
ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids
[flake8]
max-line-length = 120
select = B,C,E,F,P,T4,W,B9
ignore = F401,F403,F405,F821,W503,E251,W504,E126
exclude = docs/src,*.pyi,.git,peft.py
[darglint]
ignore=DAR101
[easy_install]
index-url=https://pypi.tuna.tsinghua.edu.cn/simple
# Copyright (c) Alibaba, Inc. and its affiliates.
# !/usr/bin/env python
import os
import shutil
from setuptools import find_packages, setup
def readme():
with open('README.md', encoding='utf-8') as f:
content = f.read()
return content
version_file = 'swift/version.py'
def get_version():
with open(version_file, 'r', encoding='utf-8') as f:
exec(compile(f.read(), version_file, 'exec'))
return locals()['__version__']
def parse_requirements(fname='requirements.txt', with_version=True):
"""
Parse the package dependencies listed in a requirements file but strip
specific versioning information.
Args:
fname (str): path to requirements file
with_version (bool, default=True): if True include version specs
Returns:
List[str]: list of requirements items
CommandLine:
python -c "import setup; print(setup.parse_requirements())"
"""
import re
import sys
from os.path import exists
require_fpath = fname
def parse_line(line):
"""
Parse information from a line in a requirements text file
"""
if line.startswith('-r '):
# Allow specifying requirements in other files
target = line.split(' ')[1]
relative_base = os.path.dirname(fname)
absolute_target = os.path.join(relative_base, target)
for info in parse_require_file(absolute_target):
yield info
else:
info = {'line': line}
if line.startswith('-e '):
info['package'] = line.split('#egg=')[1]
else:
# Remove versioning from the package
pat = '(' + '|'.join(['>=', '==', '>']) + ')'
parts = re.split(pat, line, maxsplit=1)
parts = [p.strip() for p in parts]
info['package'] = parts[0]
if len(parts) > 1:
op, rest = parts[1:]
if ';' in rest:
# Handle platform specific dependencies
# http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
version, platform_deps = map(str.strip, rest.split(';'))
info['platform_deps'] = platform_deps
else:
version = rest # NOQA
info['version'] = (op, version)
yield info
def parse_require_file(fpath):
with open(fpath, 'r', encoding='utf-8') as f:
for line in f.readlines():
line = line.strip()
if line.startswith('http'):
print('skip http requirements %s' % line)
continue
if line and not line.startswith('#') and not line.startswith('--'):
for info in parse_line(line):
yield info
elif line and line.startswith('--find-links'):
eles = line.split()
for e in eles:
e = e.strip()
if 'http' in e:
info = dict(dependency_links=e)
yield info
def gen_packages_items():
items = []
deps_link = []
if exists(require_fpath):
for info in parse_require_file(require_fpath):
if 'dependency_links' not in info:
parts = [info['package']]
if with_version and 'version' in info:
parts.extend(info['version'])
if not sys.version.startswith('3.4'):
# apparently package_deps are broken in 3.4
platform_deps = info.get('platform_deps')
if platform_deps is not None:
parts.append(';' + platform_deps)
item = ''.join(parts)
items.append(item)
else:
deps_link.append(info['dependency_links'])
return items, deps_link
return gen_packages_items()
if __name__ == '__main__':
install_requires, deps_link = parse_requirements('requirements.txt')
extra_requires = {}
all_requires = []
extra_requires['llm'], _ = parse_requirements('requirements/llm.txt')
extra_requires['aigc'], _ = parse_requirements('requirements/aigc.txt')
extra_requires['eval'], _ = parse_requirements('requirements/eval.txt')
extra_requires['seq_parallel'], _ = parse_requirements('requirements/seq_parallel.txt')
all_requires.extend(install_requires)
all_requires.extend(extra_requires['llm'])
all_requires.extend(extra_requires['aigc'])
all_requires.extend(extra_requires['eval'])
all_requires.extend(extra_requires['seq_parallel'])
extra_requires['seq_parallel'].extend(extra_requires['llm'])
extra_requires['all'] = all_requires
setup(
name='ms-swift',
version=get_version(),
description='Swift: Scalable lightWeight Infrastructure for Fine-Tuning',
long_description=readme(),
long_description_content_type='text/markdown',
author='DAMO ModelScope teams',
author_email='contact@modelscope.cn',
keywords='python, petl, efficient tuners',
url='https://github.com/modelscope/swift',
packages=find_packages(exclude=('configs', 'demo')),
include_package_data=True,
package_data={
'': ['*.h', '*.cpp', '*.cu'],
},
classifiers=[
'Development Status :: 4 - Beta',
'License :: OSI Approved :: Apache Software License',
'Operating System :: OS Independent',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
],
license='Apache License 2.0',
tests_require=parse_requirements('requirements/tests.txt'),
install_requires=install_requires,
extras_require=extra_requires,
entry_points={'console_scripts': ['swift=swift.cli.main:cli_main']},
dependency_links=deps_link,
zip_safe=False)
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from .utils.import_utils import _LazyModule
if TYPE_CHECKING:
from .version import __version__, __release_datetime__
from .tuners import (Adapter, AdapterConfig, AdapterModule, SwiftModel, LoRA, LoRAConfig, SWIFT_MAPPING,
AdaLoraConfig, IA3Config, LoftQConfig, LoHaConfig, LoKrConfig, LoraConfig, OFTConfig,
PeftConfig, PeftModel, PeftModelForCausalLM, ResTuningConfig, SideConfig,
PeftModelForSeq2SeqLM, PeftModelForSequenceClassification, PeftModelForTokenClassification,
PrefixTuningConfig, PromptEncoderConfig, PromptLearningConfig, PromptTuningConfig,
get_peft_config, get_peft_model, get_peft_model_state_dict, Prompt, PromptConfig, PromptModule,
SwiftConfig, SwiftOutput, Swift, SwiftTuners, LongLoRAConfig, LongLoRA, LongLoRAModelType,
SCETuning, SCETuningConfig)
from .hub import snapshot_download, push_to_hub, push_to_hub_async, push_to_hub_in_queue
from .trainers import (EvaluationStrategy, FSDPOption, HPSearchBackend, HubStrategy, IntervalStrategy,
SchedulerType, ShardedDDPOption, TrainingArguments, Seq2SeqTrainingArguments, Trainer,
Seq2SeqTrainer)
from .utils import get_logger
else:
_import_structure = {
'version': ['__release_datetime__', '__version__'],
'hub': ['snapshot_download', 'push_to_hub', 'push_to_hub_async', 'push_to_hub_in_queue'],
'tuners': [
'Adapter', 'AdapterConfig', 'AdapterModule', 'SwiftModel', 'LoRA', 'LoRAConfig', 'SWIFT_MAPPING',
'LoraConfig', 'AdaLoraConfig', 'IA3Config', 'LoftQConfig', 'LoHaConfig', 'LoKrConfig', 'OFTConfig',
'PeftConfig', 'ResTuningConfig', 'SideConfig', 'PeftModel', 'PeftModelForCausalLM', 'PeftModelForSeq2SeqLM',
'PeftModelForSequenceClassification', 'PeftModelForTokenClassification', 'PrefixTuningConfig',
'PromptEncoderConfig', 'PromptLearningConfig', 'PromptTuningConfig', 'get_peft_config', 'get_peft_model',
'get_peft_model_state_dict', 'Prompt', 'PromptConfig', 'PromptModule', 'SwiftConfig', 'SwiftOutput',
'Swift', 'SwiftTuners', 'LongLoRAConfig', 'LongLoRA', 'LongLoRAModelType', 'SCETuning', 'SCETuningConfig'
],
'trainers': [
'EvaluationStrategy', 'FSDPOption', 'HPSearchBackend', 'HubStrategy', 'IntervalStrategy', 'SchedulerType',
'ShardedDDPOption', 'TrainingArguments', 'Seq2SeqTrainingArguments', 'Trainer', 'Seq2SeqTrainer'
],
'utils': ['get_logger']
}
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from swift.utils.import_utils import _LazyModule
if TYPE_CHECKING:
# Recommend using `xxx_main`
from .animatediff import animatediff_sft, animatediff_main
from .animatediff_infer import animatediff_infer, animatediff_infer_main
from .diffusers import train_text_to_image, train_text_to_image_lora, train_text_to_image_lora_sdxl, \
train_text_to_image_sdxl, infer_text_to_image, infer_text_to_image_lora, infer_text_to_image_sdxl, \
infer_text_to_image_lora_sdxl, train_controlnet, train_controlnet_sdxl, train_dreambooth, \
train_dreambooth_lora, train_dreambooth_lora_sdxl, infer_controlnet, infer_controlnet_sdxl, \
infer_dreambooth, infer_dreambooth_lora, infer_dreambooth_lora_sdxl
from .utils import AnimateDiffArguments, AnimateDiffInferArguments
else:
_import_structure = {
'animatediff': ['animatediff_sft', 'animatediff_main'],
'animatediff_infer': ['animatediff_infer', 'animatediff_infer_main'],
'diffusers': [
'train_text_to_image', 'train_text_to_image_lora', 'train_text_to_image_lora_sdxl',
'train_text_to_image_sdxl', 'infer_text_to_image', 'infer_text_to_image_lora', 'infer_text_to_image_sdxl',
'infer_text_to_image_lora_sdxl', 'train_controlnet', 'train_controlnet_sdxl', 'train_dreambooth',
'train_dreambooth_lora', 'train_dreambooth_lora_sdxl', 'infer_controlnet', 'infer_controlnet_sdxl',
'infer_dreambooth', 'infer_dreambooth_lora', 'infer_dreambooth_lora_sdxl'
],
'utils': ['AnimateDiffArguments', 'AnimateDiffInferArguments'],
}
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)
import csv
import datetime
import inspect
import logging
import os
import random
import re
from copy import deepcopy
from types import MethodType
from typing import Dict
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from decord import VideoReader
from diffusers import AutoencoderKL, DDIMScheduler, MotionAdapter, UNet2DConditionModel, UNetMotionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines import AnimateDiffPipeline
from diffusers.utils import export_to_gif
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange
from modelscope import snapshot_download
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import RandomSampler
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from swift import LoRAConfig, Swift, get_logger, push_to_hub
from swift.aigc.utils import AnimateDiffArguments
from swift.utils import get_dist_setting, get_main, is_dist
logger = get_logger()
class AnimateDiffDataset(Dataset):
VIDEO_ID = 'videoid'
NAME = 'name'
CONTENT_URL = 'contentUrl'
def __init__(
self,
csv_path,
video_folder,
sample_size=256,
sample_stride=4,
sample_n_frames=16,
dataset_sample_size=10000,
):
print(f'loading annotations from {csv_path} ...')
with open(csv_path, 'r') as csvfile:
self.dataset = list(csv.DictReader(csvfile))
dataset = []
for d in tqdm(self.dataset):
content_url = d[self.CONTENT_URL]
file_name = content_url.split('/')[-1]
if os.path.isfile(os.path.join(video_folder, file_name)):
dataset.append(d)
if dataset_sample_size is not None and len(dataset) > dataset_sample_size:
break
self.dataset = dataset
self.length = len(self.dataset)
print(f'data scale: {self.length}')
self.video_folder = video_folder
self.sample_stride = sample_stride
self.sample_n_frames = sample_n_frames
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
self.pixel_transforms = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.Resize(sample_size[0]),
transforms.CenterCrop(sample_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
def get_batch(self, idx):
video_dict: Dict[str, str] = self.dataset[idx]
name = video_dict[self.NAME]
content_url = video_dict[self.CONTENT_URL]
file_name = content_url.split('/')[-1]
video_dir = os.path.join(self.video_folder, file_name)
video_reader = VideoReader(video_dir)
video_length = len(video_reader)
clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
start_idx = random.randint(0, video_length - clip_length)
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
pixel_values = pixel_values / 255.
del video_reader
return pixel_values, name
def __len__(self):
return self.length
def __getitem__(self, idx):
while True:
try:
pixel_values, name = self.get_batch(idx)
break
except Exception as e:
logger.error(f'Error loading dataset batch: {e}')
idx = random.randint(0, self.length - 1)
pixel_values = self.pixel_transforms(pixel_values)
sample = dict(pixel_values=pixel_values, text=name)
return sample
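# Illustrative only: the annotation CSV consumed by AnimateDiffDataset is assumed to
# carry at least the three columns referenced above (VIDEO_ID, NAME, CONTENT_URL);
# the rows below are hypothetical placeholders.
#   videoid,name,contentUrl
#   1001,"a dog running on the beach",https://example.com/videos/1001.mp4
#   1002,"timelapse of clouds over a city",https://example.com/videos/1002.mp4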
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, duration=4):
import imageio
videos = rearrange(videos, 'b c t h w -> t b c h w')
outputs = []
for x in videos:
x = torchvision.utils.make_grid(x, nrow=n_rows)
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
if rescale:
x = (x + 1.0) / 2.0 # -1,1 -> 0,1
x = (x * 255).numpy().astype(np.uint8)
outputs.append(x)
os.makedirs(os.path.dirname(path), exist_ok=True)
imageio.mimsave(path, outputs, duration=duration)
def animatediff_sft(args: AnimateDiffArguments) -> None:
# Initialize distributed training
if is_dist():
_, local_rank, num_processes, _ = get_dist_setting()
global_rank = dist.get_rank()
else:
local_rank = 0
global_rank = 0
num_processes = 1
is_main_process = global_rank == 0
global_seed = args.seed + global_rank
torch.manual_seed(global_seed)
# Logging folder
folder_name = datetime.datetime.now().strftime('ad-%Y-%m-%dT%H-%M-%S')
output_dir = os.path.join(args.output_dir, folder_name)
*_, config = inspect.getargvalues(inspect.currentframe())
if is_main_process and args.use_wandb:
import wandb
wandb.init(project='animatediff', name=folder_name, config=config)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO,
)
# Handle the output folder creation
if is_main_process:
os.makedirs(output_dir, exist_ok=True)
os.makedirs(f'{output_dir}/samples', exist_ok=True)
os.makedirs(f'{output_dir}/sanity_check', exist_ok=True)
os.makedirs(f'{output_dir}/checkpoints', exist_ok=True)
with open(args.validation_prompts_path, 'r') as f:
validation_data = f.readlines()
# Load scheduler, tokenizer and models.
noise_scheduler = DDIMScheduler(
num_train_timesteps=args.num_train_timesteps,
beta_start=args.beta_start,
beta_end=args.beta_end,
beta_schedule=args.beta_schedule,
steps_offset=args.steps_offset,
clip_sample=args.clip_sample,
)
if not os.path.exists(args.model_id_or_path):
pretrained_model_path = snapshot_download(args.model_id_or_path, revision=args.model_revision)
vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder='vae')
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder='tokenizer')
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder='text_encoder')
motion_adapter = None
if args.motion_adapter_id_or_path is not None:
if not os.path.exists(args.motion_adapter_id_or_path):
args.motion_adapter_id_or_path = snapshot_download(
args.motion_adapter_id_or_path, revision=args.motion_adapter_revision)
motion_adapter = MotionAdapter.from_pretrained(args.motion_adapter_id_or_path)
unet: UNetMotionModel = UNetMotionModel.from_unet2d(
UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder='unet'),
motion_adapter=motion_adapter,
load_weights=True,
)
# Freeze vae and text_encoder
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
# Set unet trainable parameters
unet.requires_grad_(False)
for name, param in unet.named_parameters():
if re.fullmatch(args.trainable_modules, name):
param.requires_grad = True
# Preparing LoRA
if args.sft_type == 'lora':
if args.motion_adapter_id_or_path is None:
raise ValueError('No AnimateDiff motion adapter weights found; LoRA tuning requires a pretrained motion adapter.')
lora_config = LoRAConfig(
r=args.lora_rank,
target_modules=args.trainable_modules,
lora_alpha=args.lora_alpha,
lora_dtype=args.lora_dtype,
lora_dropout=args.lora_dropout_p)
unet = Swift.prepare_model(unet, lora_config)
logger.info(f'lora_config: {lora_config}')
trainable_params = list(filter(lambda p: p.requires_grad, unet.parameters()))
optimizer = torch.optim.AdamW(
trainable_params,
lr=args.learning_rate,
weight_decay=args.weight_decay,
)
if is_main_process:
print(f'trainable params number: {len(trainable_params)}')
print(f'trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M')
# Enable xformers
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError('xformers is not available. Make sure it is installed correctly')
# Enable gradient checkpointing
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
# Move models to GPU
vae.to(local_rank)
text_encoder.to(local_rank)
# Get the training dataset
train_dataset = AnimateDiffDataset(
csv_path=args.csv_path,
video_folder=args.video_folder,
sample_size=args.sample_size,
sample_stride=args.sample_stride,
sample_n_frames=args.sample_n_frames,
dataset_sample_size=args.dataset_sample_size,
)
if not is_dist():
sampler = RandomSampler(train_dataset)
else:
sampler = DistributedSampler(
train_dataset, num_replicas=num_processes, rank=global_rank, shuffle=True, seed=global_seed)
# DataLoaders creation:
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=False,
sampler=sampler,
num_workers=args.dataloader_num_workers,
pin_memory=True,
drop_last=True,
)
# Get the training iteration
max_train_steps = args.num_train_epochs * len(train_dataloader)
print(f'max_train_steps: {max_train_steps}')
# Scheduler
lr_scheduler = get_scheduler(
args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=int(args.warmup_ratio * max_train_steps) // args.gradient_accumulation_steps,
num_training_steps=max_train_steps // args.gradient_accumulation_steps,
)
unet.to(local_rank)
if is_dist():
unet = DDP(unet, device_ids=[local_rank], output_device=local_rank)
num_train_epochs = args.num_train_epochs
# Train!
total_batch_size = args.batch_size * num_processes * args.gradient_accumulation_steps
if is_main_process:
logging.info('***** Running training *****')
logging.info(f' Num examples = {len(train_dataset)}')
logging.info(f' Num Epochs = {num_train_epochs}')
logging.info(f' Instantaneous batch size per device = {args.batch_size}')
logging.info(f' Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}')
logging.info(f' Gradient Accumulation steps = {args.gradient_accumulation_steps}')
logging.info(f' Total optimization steps = {max_train_steps}')
global_step = 0
first_epoch = 0
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, max_train_steps), disable=not is_main_process)
progress_bar.set_description('Steps')
# Support mixed-precision training
scaler = torch.cuda.amp.GradScaler() if args.mixed_precision else None
for epoch in range(first_epoch, num_train_epochs):
if is_dist():
train_dataloader.sampler.set_epoch(epoch)
unet.train()
for step, batch in enumerate(train_dataloader):
if args.text_dropout_rate > 0:
batch['text'] = [name if random.random() > args.text_dropout_rate else '' for name in batch['text']]
# Data batch sanity check
if epoch == first_epoch and step == 0:
pixel_values, texts = batch['pixel_values'].cpu(), batch['text']
pixel_values = rearrange(pixel_values, 'b f c h w -> b c f h w')
for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
pixel_value = pixel_value[None, ...]
file_name = '-'.join(text.replace('/',
'').split()[:10]) if not text == '' else f'{global_rank}-{idx}'
save_videos_grid(pixel_value, f'{output_dir}/sanity_check/{file_name}.gif', rescale=True)
# Convert videos to latent space
pixel_values = batch['pixel_values'].to(local_rank)
video_length = pixel_values.shape[1]
with torch.no_grad():
pixel_values = rearrange(pixel_values, 'b f c h w -> (b f) c h w')
latents = vae.encode(pixel_values).latent_dist
latents = latents.sample()
latents = rearrange(latents, '(b f) c h w -> b c f h w', f=video_length)
latents = latents * 0.18215
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each video
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz, ), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
with torch.no_grad():
prompt_ids = tokenizer(
batch['text'],
max_length=tokenizer.model_max_length,
padding='max_length',
truncation=True,
return_tensors='pt').input_ids.to(latents.device)
encoder_hidden_states = text_encoder(prompt_ids)[0]
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == 'epsilon':
target = noise
elif noise_scheduler.config.prediction_type == 'v_prediction':
raise NotImplementedError
else:
raise ValueError(f'Unknown prediction type {noise_scheduler.config.prediction_type}')
# Predict the noise residual and compute loss
# Mixed-precision training
with torch.cuda.amp.autocast(enabled=args.mixed_precision):
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(model_pred.float(), target.float(), reduction='mean')
# Backpropagate
if args.mixed_precision:
scaler.scale(loss).backward()
else:
loss.backward()
if step % args.gradient_accumulation_steps == 0:
# Optimizer step (only on gradient-accumulation boundaries)
if args.mixed_precision:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
scaler.step(optimizer)
scaler.update()
else:
torch.nn.utils.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()
progress_bar.update(1)
global_step += 1
# Wandb logging
if is_main_process and args.use_wandb:
wandb.log({'train_loss': loss.item()}, step=global_step)
# Save checkpoint
if is_main_process and (global_step % args.save_steps == 0 or step == len(train_dataloader) - 1):
save_path = os.path.join(output_dir, 'checkpoints')
if step == len(train_dataloader) - 1:
if isinstance(unet, DDP):
unet.module.save_pretrained(os.path.join(save_path, 'iter-last'))
else:
unet.save_pretrained(os.path.join(save_path, 'iter-last'))
if args.push_to_hub:
push_to_hub(
repo_name=args.hub_model_id,
output_dir=os.path.join(save_path, 'iter-last'),
token=args.hub_token,
private=True,
)
logging.info(f'Saved state to {os.path.join(save_path, "iter-last")} on the last step')
else:
iter_save_path = os.path.join(save_path, f'iter-{global_step}')
if isinstance(unet, DDP):
unet.module.save_pretrained(iter_save_path)
else:
unet.save_pretrained(iter_save_path)
if args.push_to_hub and args.push_hub_strategy == 'all_checkpoints':
push_to_hub(
repo_name=args.hub_model_id,
output_dir=os.path.join(save_path, f'iter-{global_step}'),
token=args.hub_token,
private=True,
)
logging.info(
f'Saved state to {os.path.join(save_path, f"iter-{global_step}")} (global_step: {global_step})')
# Run validation periodically
if is_main_process and global_step % args.eval_steps == 0:
generator = torch.Generator(device=latents.device)
generator.manual_seed(global_seed)
Swift.merge(unet)
height = args.sample_size
width = args.sample_size
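# Replacement state_dict for the copied motion modules below: it drops all
# LoRA keys and strips the 'base_layer.' prefix so the exported MotionAdapter
# holds only the plain motion-module weights.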
def state_dict(self,
*args,
destination=None,
prefix='',
keep_vars=False,
adapter_name: str = None,
**kwargs):
state_dict = self.state_dict_origin()
return {
key.replace('base_layer.', ''): value
for key, value in state_dict.items() if 'lora' not in key
}
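# Assemble a standalone MotionAdapter for validation by deep-copying the
# trained (LoRA-merged) motion modules out of the UNet and binding the
# patched state_dict to each copy.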
motion_adapter = MotionAdapter(
motion_num_attention_heads=args.motion_num_attention_heads,
motion_max_seq_length=args.motion_max_seq_length)
module = unet if not isinstance(unet, DDP) else unet.module
motion_adapter.mid_block.motion_modules = deepcopy(module.mid_block.motion_modules)
motion_adapter.mid_block.motion_modules.state_dict_origin = \
motion_adapter.mid_block.motion_modules.state_dict
motion_adapter.mid_block.motion_modules.state_dict = MethodType(state_dict,
motion_adapter.mid_block.motion_modules)
for db1, db2 in zip(motion_adapter.down_blocks, module.down_blocks):
db1.motion_modules = deepcopy(db2.motion_modules)
db1.motion_modules.state_dict_origin = db1.motion_modules.state_dict
db1.motion_modules.state_dict = MethodType(state_dict, db1.motion_modules)
for db1, db2 in zip(motion_adapter.up_blocks, module.up_blocks):
db1.motion_modules = deepcopy(db2.motion_modules)
db1.motion_modules.state_dict_origin = db1.motion_modules.state_dict
db1.motion_modules.state_dict = MethodType(state_dict, db1.motion_modules)
Swift.unmerge(unet)
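# Build the validation pipeline from a fresh 2D UNet plus the extracted
# MotionAdapter, reusing the training VAE, tokenizer and text encoder.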
validation_pipeline = AnimateDiffPipeline(
unet=UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder='unet'),
vae=vae,
tokenizer=tokenizer,
motion_adapter=motion_adapter,
text_encoder=text_encoder,
scheduler=noise_scheduler,
).to('cuda')
validation_pipeline.enable_vae_slicing()
validation_pipeline.enable_model_cpu_offload()
for idx, prompt in enumerate(validation_data):
output = validation_pipeline(
prompt=prompt,
negative_prompt='bad quality, worse quality',
num_frames=args.sample_n_frames,
height=height,
width=width,
guidance_scale=args.guidance_scale,
num_inference_steps=args.num_inference_steps,
generator=torch.Generator('cpu').manual_seed(global_seed),
)
frames = output.frames[0]
export_to_gif(frames, f'{output_dir}/samples/sample-{global_step}-{idx}.gif')
unet.train()
logs = {'step_loss': loss.detach().item(), 'lr': lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= max_train_steps:
break
if is_dist():
dist.destroy_process_group()
animatediff_main = get_main(AnimateDiffArguments, animatediff_sft)
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import torch
from diffusers import DDIMScheduler, MotionAdapter
from diffusers.pipelines import AnimateDiffPipeline
from diffusers.utils import export_to_gif
from swift import Swift, snapshot_download
from swift.aigc.utils import AnimateDiffInferArguments
from swift.utils import get_logger, get_main
logger = get_logger()
def animatediff_infer(args: AnimateDiffInferArguments) -> None:
generator = torch.Generator(device='cpu')
generator.manual_seed(args.seed)
# Load scheduler, tokenizer and models.
noise_scheduler = DDIMScheduler(
num_train_timesteps=args.num_train_timesteps,
beta_start=args.beta_start,
beta_end=args.beta_end,
beta_schedule=args.beta_schedule,
steps_offset=args.steps_offset,
clip_sample=args.clip_sample,
)
if not os.path.exists(args.model_id_or_path):
pretrained_model_path = snapshot_download(args.model_id_or_path, revision=args.model_revision)
else:
pretrained_model_path = args.model_id_or_path
motion_adapter = None
if args.motion_adapter_id_or_path is not None:
if not os.path.exists(args.motion_adapter_id_or_path):
args.motion_adapter_id_or_path = snapshot_download(
args.motion_adapter_id_or_path, revision=args.motion_adapter_revision)
motion_adapter = MotionAdapter.from_pretrained(args.motion_adapter_id_or_path)
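# For full fine-tuning (sft_type='full'), load the trained MotionAdapter from
# ckpt_dir, falling back to the base model's 'motion_adapter' subfolder.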
if args.sft_type == 'full':
motion_adapter_dir = args.ckpt_dir if args.ckpt_dir is not None else os.path.join(
pretrained_model_path, 'motion_adapter')
motion_adapter = MotionAdapter.from_pretrained(motion_adapter_dir)
validation_pipeline = AnimateDiffPipeline.from_pretrained(
pretrained_model_path,
motion_adapter=motion_adapter,
).to('cuda')
validation_pipeline.scheduler = noise_scheduler
if args.sft_type != 'full':
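# Load the tuner (e.g. LoRA) weights from ckpt_dir onto the pipeline's UNet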
model = Swift.from_pretrained(validation_pipeline.unet, args.ckpt_dir)
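# Optionally merge the LoRA weights into the base UNet and save the merged
# pipeline to a sibling '<ckpt_name>-merged' directory, so later runs can
# load it with sft_type='full'.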
if args.merge_lora:
ckpt_dir, ckpt_name = os.path.split(args.ckpt_dir)
merged_lora_path = os.path.join(ckpt_dir, f'{ckpt_name}-merged')
logger.info(f'merged_lora_path: `{merged_lora_path}`')
logger.info("Setting args.sft_type: 'full'")
logger.info(f'Setting args.ckpt_dir: {merged_lora_path}')
args.sft_type = 'full'
args.ckpt_dir = merged_lora_path
if os.path.exists(args.ckpt_dir) and not args.replace_if_exists:
logger.warn(f'The weight directory for the merged LoRA already exists in {args.ckpt_dir}, '
'skipping the saving process. '
'You can pass `replace_if_exists=True` to overwrite it.')
return
Swift.merge_and_unload(model)
validation_pipeline.unet = model.model
validation_pipeline.save_pretrained(args.ckpt_dir)
validation_pipeline.enable_vae_slicing()
validation_pipeline.enable_model_cpu_offload()
if args.eval_human:
idx = 0
while True:
prompt = input('<<< ')
sample = validation_pipeline(
prompt,
negative_prompt='bad quality, worse quality',
generator=generator,
num_frames=args.sample_n_frames,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
).frames[0]
os.makedirs(args.output_path, exist_ok=True)
logger.info(f'Output saved to: {f"{args.output_path}/output-{idx}.gif"}')
export_to_gif(sample, f'{args.output_path}/output-{idx}.gif')
idx += 1
else:
with open(args.validation_prompts_path, 'r') as f:
validation_data = f.readlines()
for idx, prompt in enumerate(validation_data):
sample = validation_pipeline(
prompt,
negative_prompt='bad quality, worse quality',
generator=generator,
num_frames=args.sample_n_frames,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
).frames[0]
os.makedirs(args.output_path, exist_ok=True)
logger.info(f'Output saved to: {f"{args.output_path}/output-{idx}.gif"}')
export_to_gif(sample, f'{args.output_path}/output-{idx}.gif')
animatediff_infer_main = get_main(AnimateDiffInferArguments, animatediff_infer)
from .infer_controlnet import main as infer_controlnet
from .infer_controlnet_sdxl import main as infer_controlnet_sdxl
from .infer_dreambooth import main as infer_dreambooth
from .infer_dreambooth_lora import main as infer_dreambooth_lora
from .infer_dreambooth_lora_sdxl import main as infer_dreambooth_lora_sdxl
from .infer_text_to_image import main as infer_text_to_image
from .infer_text_to_image_lora import main as infer_text_to_image_lora
from .infer_text_to_image_lora_sdxl import main as infer_text_to_image_lora_sdxl
from .infer_text_to_image_sdxl import main as infer_text_to_image_sdxl
from .train_controlnet import main as train_controlnet
from .train_controlnet_sdxl import main as train_controlnet_sdxl
from .train_dreambooth import main as train_dreambooth
from .train_dreambooth_lora import main as train_dreambooth_lora
from .train_dreambooth_lora_sdxl import main as train_dreambooth_lora_sdxl
from .train_text_to_image import main as train_text_to_image
from .train_text_to_image_lora import main as train_text_to_image_lora
from .train_text_to_image_lora_sdxl import main as train_text_to_image_lora_sdxl
from .train_text_to_image_sdxl import main as train_text_to_image_sdxl
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import os
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler
from diffusers.utils import load_image
from modelscope import snapshot_download
def parse_args():
parser = argparse.ArgumentParser(description='Simple example of a ControlNet inference.')
parser.add_argument(
'--base_model_path',
type=str,
default='AI-ModelScope/stable-diffusion-v1-5',
required=True,
help='Path to pretrained model or model identifier from modelscope.cn/models.',
)
parser.add_argument(
'--revision',
type=str,
default=None,
required=False,
help='Revision of pretrained model identifier from modelscope.cn/models.',
)
parser.add_argument(
'--controlnet_path',
type=str,
default=None,
required=False,
help='The path to the trained ControlNet model.',
)
parser.add_argument(
'--prompt',
type=str,
default=None,
required=True,
help='The prompt or prompts to guide image generation.',
)
parser.add_argument(
'--control_image_path',
type=str,
default=None,
required=True,
help='The path to the conditioning image.',
)
parser.add_argument(
'--image_save_path',
type=str,
default=None,
required=True,
help='The path to save the generated image.',
)
parser.add_argument(
'--torch_dtype',
type=str,
default=None,
choices=['no', 'fp16', 'bf16'],
help=('Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU. '
'Defaults to the value of `mixed_precision` passed with the `accelerate.launch` command in the training script.'),
)
parser.add_argument('--seed', type=int, default=None, help='A seed for inference.')
parser.add_argument(
'--num_inference_steps',
type=int,
default=20,
help=('The number of denoising steps. More denoising steps usually lead to a higher quality image at the '
'expense of slower inference.'),
)
parser.add_argument(
'--guidance_scale',
type=float,
default=7.5,
help=('A higher guidance scale value encourages the model to generate images closely linked to the text '
'`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.'),
)
args = parser.parse_args()
return args
def main():
args = parse_args()
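# Use a local path if it exists; otherwise download the base model from ModelScope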
if os.path.exists(args.base_model_path):
base_model_path = args.base_model_path
else:
base_model_path = snapshot_download(args.base_model_path, revision=args.revision)
if args.torch_dtype == 'fp16':
torch_dtype = torch.float16
elif args.torch_dtype == 'bf16':
torch_dtype = torch.bfloat16
else:
torch_dtype = torch.float32
controlnet = ControlNetModel.from_pretrained(args.controlnet_path, torch_dtype=torch_dtype)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
base_model_path, controlnet=controlnet, torch_dtype=torch_dtype)
# speed up diffusion process with faster scheduler and memory optimization
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# memory optimization.
pipe.enable_model_cpu_offload()
control_image = load_image(args.control_image_path)
# generate image
generator = torch.manual_seed(args.seed)
image = pipe(
args.prompt, num_inference_steps=args.num_inference_steps, generator=generator, image=control_image).images[0]
image.save(args.image_save_path)
import argparse
import os
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, UniPCMultistepScheduler
from diffusers.utils import load_image
from modelscope import snapshot_download
def parse_args():
parser = argparse.ArgumentParser(description='Simple example of a ControlNet inference.')
parser.add_argument(
'--base_model_path',
type=str,
default='AI-ModelScope/stable-diffusion-xl-base-1.0',
required=True,
help='Path to pretrained model or model identifier from modelscope.cn/models.',
)
parser.add_argument(
'--revision',
type=str,
default=None,
required=False,
help='Revision of pretrained model identifier from modelscope.cn/models.',
)
parser.add_argument(
'--controlnet_path',
type=str,
default=None,
required=False,
help='The path to the trained ControlNet model.',
)
parser.add_argument(
'--prompt',
type=str,
default=None,
required=True,
help='The prompt or prompts to guide image generation.',
)
parser.add_argument(
'--control_image_path',
type=str,
default=None,
required=True,
help='The path to the conditioning image.',
)
parser.add_argument(
'--image_save_path',
type=str,
default=None,
required=True,
help='The path to save the generated image.',
)
parser.add_argument(
'--torch_dtype',
type=str,
default=None,
choices=['no', 'fp16', 'bf16'],
help=('Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU. '
'Defaults to the value of `mixed_precision` passed with the `accelerate.launch` command in the training script.'),
)
parser.add_argument('--seed', type=int, default=None, help='A seed for inference.')
parser.add_argument(
'--num_inference_steps',
type=int,
default=20,
help=('The number of denoising steps. More denoising steps usually lead to a higher quality image at the '
'expense of slower inference.'),
)
parser.add_argument(
'--guidance_scale',
type=float,
default=7.5,
help=('A higher guidance scale value encourages the model to generate images closely linked to the text '
'`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.'),
)
args = parser.parse_args()
return args
def main():
args = parse_args()
if os.path.exists(args.base_model_path):
base_model_path = args.base_model_path
else:
base_model_path = snapshot_download(args.base_model_path, revision=args.revision)
if args.torch_dtype == 'fp16':
torch_dtype = torch.float16
elif args.torch_dtype == 'bf16':
torch_dtype = torch.bfloat16
else:
torch_dtype = torch.float32
controlnet = ControlNetModel.from_pretrained(args.controlnet_path, torch_dtype=torch_dtype)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
base_model_path, controlnet=controlnet, torch_dtype=torch_dtype)
# speed up diffusion process with faster scheduler and memory optimization
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# memory optimization.
pipe.enable_model_cpu_offload()
control_image = load_image(args.control_image_path)
# generate image
generator = torch.manual_seed(args.seed)
image = pipe(
args.prompt, num_inference_steps=args.num_inference_steps, generator=generator, image=control_image).images[0]
image.save(args.image_save_path)