Unverified Commit f7e3f82a authored by flybird11111's avatar flybird11111 Committed by GitHub
Browse files

fix llama pretrain (#5287)

parent 6a569678
@@ -273,11 +273,10 @@ def main():
         dataloader.sampler.set_start_index(sampler_start_idx)
         for epoch in range(start_epoch, args.num_epochs):
             dataloader.sampler.set_epoch(epoch)
-            step_nums = num_steps_per_epoch - start_step
             dataloader_iter = iter(dataloader)
             with tqdm(
-                range(step_nums),
+                range(start_step, num_steps_per_epoch),
                 desc=f"Epoch {epoch}",
                 disable=not print_flag,
                 total=num_steps_per_epoch,
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment