Commit 14b50ed0 authored by mashun1's avatar mashun1
Browse files

Update train.py

parent d6cb6066
......@@ -126,7 +126,7 @@ def train(opt):
# enable gradient checkpointing to save GPU memory, but this action would slowdown the training speed 20-30%.
# in addition, gradient_checkpointing can not be enabled when using deepspeed ZERO-3
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
dataset = SFTDataset(opt.sft_data_dir, tokenizer, opt.block_size, opt.mode)
if accelerator.is_main_process:
sanity_check(dataset[0]["input_ids"], dataset[0]["labels"], tokenizer)
......@@ -207,4 +207,4 @@ def train(opt):
if __name__ == "__main__":
opt = parse_option()
train(opt)
\ No newline at end of file
train(opt)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment