"src/vscode:/vscode.git/clone" did not exist on "adf4b173b30f463d56d111c42116e1d20e194cf4"
Commit 14b50ed0 authored by mashun1's avatar mashun1
Browse files

Update train.py

parent d6cb6066
...@@ -126,7 +126,7 @@ def train(opt): ...@@ -126,7 +126,7 @@ def train(opt):
# Enable gradient checkpointing to save GPU memory; note that this slows training down by 20-30%. # Enable gradient checkpointing to save GPU memory; note that this slows training down by 20-30%.
# In addition, gradient_checkpointing cannot be enabled when using DeepSpeed ZeRO-3. # In addition, gradient_checkpointing cannot be enabled when using DeepSpeed ZeRO-3.
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable()
model.enable_input_require_grads()
dataset = SFTDataset(opt.sft_data_dir, tokenizer, opt.block_size, opt.mode) dataset = SFTDataset(opt.sft_data_dir, tokenizer, opt.block_size, opt.mode)
if accelerator.is_main_process: if accelerator.is_main_process:
sanity_check(dataset[0]["input_ids"], dataset[0]["labels"], tokenizer) sanity_check(dataset[0]["input_ids"], dataset[0]["labels"], tokenizer)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment