Unverified Commit 44ca61a2 authored by Camille Zhong's avatar Camille Zhong Committed by GitHub
Browse files

[llama] fix neftune & pbar with start_step (#5364)

parent a4cec171
...@@ -17,7 +17,7 @@ import torch ...@@ -17,7 +17,7 @@ import torch
def unwrap(model): def unwrap(model):
if hasattr(model, "module"): if hasattr(model, "module"):
return unwrap_model(model.module) return model.unwrap()
else: else:
return model return model
......
...@@ -329,9 +329,9 @@ def main() -> None: ...@@ -329,9 +329,9 @@ def main() -> None:
for epoch in range(start_epoch, args.num_epochs): for epoch in range(start_epoch, args.num_epochs):
dataloader.sampler.set_epoch(epoch=epoch) dataloader.sampler.set_epoch(epoch=epoch)
pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch) pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch, initial=start_step // args.accumulation_steps)
total_loss = torch.tensor(0.0, device=get_current_device()) total_loss = torch.tensor(0.0, device=get_current_device())
for step, batch in enumerate(dataloader): for step, batch in enumerate(dataloader, start=start_step):
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)} batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
batch_output = model(**batch) batch_output = model(**batch)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment