Unverified Commit 084c9124 authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[llama] fix memory issue (#5371)

* [llama] fix memory issue

* [llama] add comment
parent eb4f2d90
...@@ -23,7 +23,7 @@ from colossal_llama2.utils.froze import freeze_non_embeds_parameters ...@@ -23,7 +23,7 @@ from colossal_llama2.utils.froze import freeze_non_embeds_parameters
from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm from tqdm import tqdm
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer from transformers import LlamaForCausalLM, LlamaTokenizer
import colossalai import colossalai
from colossalai.accelerator import get_accelerator from colossalai.accelerator import get_accelerator
...@@ -232,10 +232,12 @@ def main() -> None: ...@@ -232,10 +232,12 @@ def main() -> None:
else nullcontext() else nullcontext()
) )
with init_ctx: with init_ctx:
model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained)) model = LlamaForCausalLM.from_pretrained(args.pretrained)
# Freeze part of parameters. # Freeze part of parameters.
if args.freeze_non_embeds_params: if args.freeze_non_embeds_params:
freeze_non_embeds_parameters(model=model) freeze_non_embeds_parameters(model=model)
# this is essential, otherwise the grad checkpoint will not work.
model.train()
if args.use_grad_checkpoint: if args.use_grad_checkpoint:
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable()
...@@ -277,8 +279,6 @@ def main() -> None: ...@@ -277,8 +279,6 @@ def main() -> None:
lr_scheduler=lr_scheduler, lr_scheduler=lr_scheduler,
dataloader=dataloader, dataloader=dataloader,
) )
if args.load_checkpoint is None:
booster.load_model(model, args.pretrained)
torch.set_default_dtype(torch.float) torch.set_default_dtype(torch.float)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment