Unverified Commit eb4f2d90 authored by Hongxin Liu, committed by GitHub

[llama] polish training script and fix optim ckpt (#5368)

parent a5756a87
@@ -23,7 +23,7 @@ from colossal_llama2.utils.froze import freeze_non_embeds_parameters
 from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
-from transformers import LlamaForCausalLM, LlamaTokenizer
+from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

 import colossalai
 from colossalai.accelerator import get_accelerator
@@ -232,7 +232,7 @@ def main() -> None:
         else nullcontext()
     )
     with init_ctx:
-        model = LlamaForCausalLM.from_pretrained(args.pretrained)
+        model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
     # Freeze part of parameters.
     if args.freeze_non_embeds_params:
         freeze_non_embeds_parameters(model=model)
@@ -277,6 +277,8 @@ def main() -> None:
         lr_scheduler=lr_scheduler,
         dataloader=dataloader,
     )
+    if args.load_checkpoint is None:
+        booster.load_model(model, args.pretrained)

     torch.set_default_dtype(torch.float)
@@ -329,7 +331,12 @@ def main() -> None:
     for epoch in range(start_epoch, args.num_epochs):
         dataloader.sampler.set_epoch(epoch=epoch)
-        pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch, initial=start_step // args.accumulation_steps)
+        pbar = tqdm(
+            desc=f"Epoch {epoch}",
+            disable=not coordinator.is_master(),
+            total=num_steps_per_epoch,
+            initial=start_step // args.accumulation_steps,
+        )
         total_loss = torch.tensor(0.0, device=get_current_device())
         for step, batch in enumerate(dataloader, start=start_step):
             batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
@@ -369,6 +376,7 @@ def main() -> None:
                 coordinator.print_on_master("Deactivate NEFTune before saving model.")
                 deactivate_neftune(model, handle)

+                accelerator.empty_cache()
                 save_checkpoint(
                     save_dir=args.save_dir,
                     booster=booster,
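The substance of the training-script change above is a two-step initialization: the model is built from its config inside the (possibly lazy) init context, and the pretrained weights are loaded afterwards through the booster, and only when no training checkpoint is being resumed. Below is a minimal, self-contained sketch of that pattern, not the actual script: the checkpoint path, the plugin choice, and the load_checkpoint stand-in are assumptions for illustration.

# Minimal sketch: build the Llama model from config, then let the booster load
# pretrained weights only when training does not resume from a saved checkpoint.
# Assumes a torchrun-launched process with colossalai, torch, and transformers installed;
# "path/to/pretrained" is a hypothetical Hugging Face checkpoint directory.
from contextlib import nullcontext

import torch
from transformers import LlamaConfig, LlamaForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext

colossalai.launch_from_torch(config={})  # older releases take a config dict; newer ones take no argument

plugin = GeminiPlugin()
booster = Booster(plugin=plugin)

# Instantiating from config avoids materializing a full pretrained weight copy per rank;
# LazyInitContext defers parameter allocation until the booster shards the model.
init_ctx = LazyInitContext() if isinstance(plugin, GeminiPlugin) else nullcontext()
with init_ctx:
    model = LlamaForCausalLM(LlamaConfig.from_pretrained("path/to/pretrained"))

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
model, optimizer, *_ = booster.boost(model, optimizer)

load_checkpoint = None  # hypothetical stand-in for args.load_checkpoint
if load_checkpoint is None:
    # booster.load_model shards and places the pretrained weights according to the plugin;
    # when resuming, the resumed checkpoint already carries the further-trained weights.
    booster.load_model(model, "path/to/pretrained")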
@@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.utils import get_current_device

 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
@@ -721,7 +722,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 tp_group=self.tp_group,
                 use_zero=self.use_zero,
                 inplace=False,
-                device=torch.device("cuda"),
+                device=get_current_device(),
             )
             if self.pp_size == 1:
@@ -854,7 +855,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             if isinstance(v, torch.Tensor) and k != "step":
                 # First gather Zero shards.
                 if use_zero:
-                    v = v.cuda()
+                    v = v.to(get_current_device())
                     gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
                     dist.all_gather(gather_tensor, v, group=dp_group)
                     v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
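The checkpoint-IO changes replace hard-coded CUDA placement with get_current_device(), so gathering ZeRO-sharded optimizer state also works when the active accelerator is not CUDA (for example an NPU). The sketch below mirrors the gather pattern from the diff in a hypothetical standalone helper; the function name and signature are illustrative, not the library's API.

# Illustrative helper (hypothetical name): gather a ZeRO-sharded optimizer-state tensor
# across the data-parallel group on the active accelerator, then reshape it to the
# parameter's shape, mirroring the pattern shown in the diff above.
import torch
import torch.distributed as dist

from colossalai.utils import get_current_device


def gather_zero_optim_state(v: torch.Tensor, param: torch.Tensor, dp_group) -> torch.Tensor:
    dp_size = dist.get_world_size(group=dp_group)
    # Move the local shard to the current accelerator instead of calling v.cuda(),
    # which would fail on non-CUDA backends.
    v = v.to(get_current_device())
    gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
    dist.all_gather(gather_tensor, v, group=dp_group)
    # Concatenate shards, drop padding beyond the parameter's element count,
    # and restore the parameter's original shape.
    return torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)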