Commit cad1f505 authored by Fazzie

fix ckpt

parent 6e0faa70
@@ -53,27 +53,33 @@ You can also update an existing [latent diffusion](https://github.com/CompVis/la
```
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
pip install transformers==4.19.2 diffusers invisible-watermark
pip install -e .
```
#### Step 2: Install Lightning
Install a Lightning version released after 2022.01.04. We suggest installing Lightning from source.
##### From Source
```
git clone https://github.com/Lightning-AI/lightning.git
pip install -r requirements.txt
python setup.py install
```
##### From pip
```
pip install pytorch-lightning
```
#### Step 3: Install [Colossal-AI](https://colossalai.org/download/) From Our Official Website
##### From pip
For example, you can install v0.2.0 from our official website.
```
pip install colossalai==0.2.0+torch1.12cu11.3 -f https://release.colossalai.org
```
##### From source
@@ -133,10 +139,9 @@ It is important for you to configure your volume mapping in order to get the bes
3. **Optional**: if you encounter any problem stating that shared memory is insufficient inside the container, please add `-v /dev/shm:/dev/shm` to your `docker run` command, as shown in the sketch below.
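A minimal, purely illustrative `docker run` sketch; the image name `your-diffusion-image` and the host data path are placeholders, not names provided by this repository:
```
# illustrative only: substitute your own image name and host data path
docker run -it --gpus all \
    -v /path/to/your/data:/data \
    -v /dev/shm:/dev/shm \
    your-diffusion-image /bin/bash
```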
## Download the pretrained model checkpoint
### stable-diffusion-v2-base (Recommended)
```
wget https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512-base-ema.ckpt
@@ -144,8 +149,6 @@ wget https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512
### stable-diffusion-v1-4
Our default model config uses the weights from [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4?text=A+mecha+robot+in+a+favela+in+expressionist+style)
```
git lfs install
git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
@@ -153,8 +156,6 @@ git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
### stable-diffusion-v1-5 from RunwayML
If you want to use the latest [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) weights from RunwayML:
```
git lfs install
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
@@ -171,11 +172,16 @@ We provide the script `train_colossalai.sh` to run the training task with coloss
and you can also use `train_ddp.sh` to run the same training task with DDP for comparison.
In `train_colossalai.sh` the main command is:
```
python main.py --logdir /tmp/ --train --base configs/train_colossalai.yaml --ckpt 512-base-ema.ckpt
```
- You can change the `--logdir` to decide where to save the log information and the last checkpoint.
- You will find your checkpoint in `logdir/checkpoints` or `logdir/diff_tb/version_0/checkpoints` (see the illustrative layout after this list).
- You will find your training config yaml in `logdir/configs`.
- You can add the `--ckpt` flag if you want to load a pretrained model, for example `512-base-ema.ckpt`.
- You can change the `--base` flag to specify the path of the config yaml.
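As a concrete sketch of where the outputs land, assuming a run started from `configs/train_colossalai.yaml`; the run folder name is the start timestamp plus the config name, and the paths below are illustrative only:
```
# illustrative paths only
ls /tmp/2023-02-02T18-06-14_train_colossalai/checkpoints/last.ckpt
ls /tmp/2023-02-02T18-06-14_train_colossalai/configs/2023-02-02T18-06-14-project.yaml
```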
### Training config
@@ -186,7 +192,8 @@ You can change the training config in the yaml file
- precision: the precision type used in training, default 16 (fp16); you must use fp16 if you want to apply ColossalAI
- more information about the configuration of ColossalAIStrategy can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/advanced/model_parallel.html#colossal-ai)
## Finetune Example
### Training on Teyvat Datasets
We provide a finetuning example on the [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset, which is created with BLIP-generated captions.
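To fetch the dataset locally, a minimal sketch (assuming `git lfs` is installed, as in the checkpoint downloads above) is:
```
git lfs install
git clone https://huggingface.co/datasets/Fazzie/Teyvat
```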
@@ -201,8 +208,8 @@ you can get your training last.ckpt and train config.yaml in your `--logdir`, an
```
python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms \
--outdir ./output \
--ckpt path/to/logdir/checkpoints/last.ckpt \
--config /path/to/logdir/configs/project.yaml
```
```commandline
...
@@ -6,6 +6,7 @@ model:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
ckpt: None # use ckpt path
log_every_t: 200
timesteps: 1000
first_stage_key: image
@@ -16,7 +17,7 @@ model:
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
...
@@ -106,7 +106,20 @@ def get_parser(**parser_kwargs):
nargs="?",
help="disable test",
)
parser.add_argument(
"-p",
"--project",
help="name of new or path to existing project",
)
parser.add_argument(
"-c",
"--ckpt",
type=str,
const=True,
default="",
nargs="?",
help="load pretrained checkpoint from stable AI",
)
parser.add_argument(
"-d",
"--debug",
@@ -145,22 +158,7 @@ def get_parser(**parser_kwargs):
default=True,
help="scale base-lr by ngpu * batch_size * n_accumulate",
)
parser.add_argument(
"--use_fp16",
type=str2bool,
nargs="?",
const=True,
default=True,
help="whether to use fp16",
)
parser.add_argument(
"--flash",
type=str2bool,
const=True,
default=False,
nargs="?",
help="whether to use flash attention",
)
return parser
@@ -341,6 +339,12 @@ class SetupCallback(Callback):
except FileNotFoundError:
pass
# def on_fit_end(self, trainer, pl_module):
# if trainer.global_rank == 0:
# ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
# rank_zero_info(f"Saving final checkpoint in {ckpt_path}.")
# trainer.save_checkpoint(ckpt_path)
class ImageLogger(Callback):
@@ -536,6 +540,7 @@ if __name__ == "__main__":
"If you want to resume training in a new log folder, "
"use -n/--name in combination with --resume_from_checkpoint")
if opt.resume:
rank_zero_info("Resuming from {}".format(opt.resume))
if not os.path.exists(opt.resume):
raise ValueError("Cannot find {}".format(opt.resume))
if os.path.isfile(opt.resume):
@@ -543,13 +548,13 @@ if __name__ == "__main__":
# idx = len(paths)-paths[::-1].index("logs")+1
# logdir = "/".join(paths[:idx])
logdir = "/".join(paths[:-2])
rank_zero_info("logdir: {}".format(logdir))
ckpt = opt.resume
else:
assert os.path.isdir(opt.resume), opt.resume
logdir = opt.resume.rstrip("/")
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
opt.resume_from_checkpoint = ckpt
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
opt.base = base_configs + opt.base
_tmp = logdir.split("/")
@@ -558,6 +563,7 @@ if __name__ == "__main__":
if opt.name:
name = "_" + opt.name
elif opt.base:
rank_zero_info("Using base config {}".format(opt.base))
cfg_fname = os.path.split(opt.base[0])[-1]
cfg_name = os.path.splitext(cfg_fname)[0]
name = "_" + cfg_name
@@ -566,6 +572,9 @@ if __name__ == "__main__":
nowname = now + name + opt.postfix
logdir = os.path.join(opt.logdir, nowname)
if opt.ckpt:
ckpt = opt.ckpt
ckptdir = os.path.join(logdir, "checkpoints")
cfgdir = os.path.join(logdir, "configs")
seed_everything(opt.seed)
@@ -582,14 +591,11 @@ if __name__ == "__main__":
for k in nondefault_trainer_args(opt):
trainer_config[k] = getattr(opt, k)
print(trainer_config)
if not trainer_config["accelerator"] == "gpu":
del trainer_config["accelerator"]
cpu = True
print("Running on CPU")
else:
cpu = False
print("Running on GPU")
trainer_opt = argparse.Namespace(**trainer_config)
lightning_config.trainer = trainer_config
@@ -597,10 +603,12 @@ if __name__ == "__main__":
use_fp16 = trainer_config.get("precision", 32) == 16
if use_fp16:
config.model["params"].update({"use_fp16": True})
print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))
else:
config.model["params"].update({"use_fp16": False})
print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))
if ckpt is not None:
config.model["params"].update({"ckpt": ckpt})
rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"]))
model = instantiate_from_config(config.model)
# trainer and callbacks
@@ -639,7 +647,6 @@ if __name__ == "__main__":
# config the strategy, default is ddp
if "strategy" in trainer_config:
strategy_cfg = trainer_config["strategy"]
print("Using strategy: {}".format(strategy_cfg["target"]))
strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"] strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"]
else: else:
strategy_cfg = { strategy_cfg = {
...@@ -648,7 +655,6 @@ if __name__ == "__main__": ...@@ -648,7 +655,6 @@ if __name__ == "__main__":
"find_unused_parameters": False "find_unused_parameters": False
} }
} }
print("Using strategy: DDPStrategy")
trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg) trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)
...@@ -664,7 +670,6 @@ if __name__ == "__main__": ...@@ -664,7 +670,6 @@ if __name__ == "__main__":
} }
} }
if hasattr(model, "monitor"): if hasattr(model, "monitor"):
print(f"Monitoring {model.monitor} as checkpoint metric.")
default_modelckpt_cfg["params"]["monitor"] = model.monitor
default_modelckpt_cfg["params"]["save_top_k"] = 3
@@ -673,7 +678,6 @@ if __name__ == "__main__":
else:
modelckpt_cfg = OmegaConf.create()
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
if version.parse(pl.__version__) < version.parse('1.4.0'):
trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
@@ -710,8 +714,6 @@ if __name__ == "__main__":
"target": "main.CUDACallback"
},
}
if version.parse(pl.__version__) >= version.parse('1.4.0'):
default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})
if "callbacks" in lightning_config: if "callbacks" in lightning_config:
callbacks_cfg = lightning_config.callbacks callbacks_cfg = lightning_config.callbacks
...@@ -737,15 +739,11 @@ if __name__ == "__main__": ...@@ -737,15 +739,11 @@ if __name__ == "__main__":
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict) default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'):
callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint
elif 'ignore_keys_callback' in callbacks_cfg:
del callbacks_cfg['ignore_keys_callback']
trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
### trainer.logdir = logdir
# data
data = instantiate_from_config(config.data)
@@ -754,9 +752,9 @@ if __name__ == "__main__":
# lightning still takes care of proper multiprocessing though
data.prepare_data()
data.setup()
print("#### Data #####")
for k in data.datasets:
rank_zero_info(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
# configure learning rate
bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
@@ -768,17 +766,17 @@ if __name__ == "__main__":
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
else:
accumulate_grad_batches = 1
print(f"accumulate_grad_batches = {accumulate_grad_batches}") rank_zero_info(f"accumulate_grad_batches = {accumulate_grad_batches}")
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
if opt.scale_lr: if opt.scale_lr:
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
print( rank_zero_info(
"Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)" "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)"
.format(model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr)) .format(model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
else: else:
model.learning_rate = base_lr model.learning_rate = base_lr
print("++++ NOT USING LR SCALING ++++") rank_zero_info("++++ NOT USING LR SCALING ++++")
print(f"Setting learning rate to {model.learning_rate:.2e}") rank_zero_info(f"Setting learning rate to {model.learning_rate:.2e}")
# allow checkpointing via USR1 # allow checkpointing via USR1
def melk(*args, **kwargs): def melk(*args, **kwargs):
...
python scripts/txt2img.py --prompt "Teyvat, Medium Female, a woman in a blue outfit holding a sword" --plms \
--outdir ./output \
--ckpt checkpoints/last.ckpt \
--config configs/2023-02-02T18-06-14-project.yaml \
--n_samples 4
@@ -2,4 +2,4 @@ HF_DATASETS_OFFLINE=1
TRANSFORMERS_OFFLINE=1
DIFFUSERS_OFFLINE=1
python main.py --logdir /tmp --train --base configs/Teyvat/train_colossalai_teyvat.yaml --ckpt diffuser_root_dir/512-base-ema.ckpt