Unverified Commit 44ca61a2 authored by Camille Zhong's avatar Camille Zhong Committed by GitHub
Browse files

[llama] fix neftune & pbar with start_step (#5364)

parent a4cec171
......@@ -17,7 +17,7 @@ import torch
def unwrap(model):
    """Recursively unwrap a wrapped model and return the innermost model.

    Distributed/parallel wrappers typically expose the real model through a
    ``.module`` attribute (presumably DDP-style wrappers here — confirm against
    the caller); follow that chain until a bare model is reached.

    Args:
        model: A model object, possibly nested inside wrappers that expose
            the inner model via a ``.module`` attribute.

    Returns:
        The unwrapped (innermost) model; ``model`` itself if it is not wrapped.
    """
    if hasattr(model, "module"):
        # Recurse through the shared helper so nested wrappers are fully peeled.
        return unwrap_model(model.module)
    else:
        return model
......
......@@ -329,9 +329,9 @@ def main() -> None:
for epoch in range(start_epoch, args.num_epochs):
dataloader.sampler.set_epoch(epoch=epoch)
pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch)
pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch, initial=start_step // args.accumulation_steps)
total_loss = torch.tensor(0.0, device=get_current_device())
for step, batch in enumerate(dataloader):
for step, batch in enumerate(dataloader, start=start_step):
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
batch_output = model(**batch)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment