Unverified Commit bb0a668f authored by Wenhao Chen's avatar Wenhao Chen Committed by GitHub
Browse files

[hotfix] set return_outputs=False in examples and polish code (#5404)

* fix: simplify merge_batch

* fix: use return_outputs=False to eliminate extra memory consumption

* feat: add return_outputs warning

* style: remove `return_outputs=False` as it is the default value
parent 5fcd7795
......@@ -75,7 +75,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
model.train()
if booster.plugin.stage_manager is not None:
booster.execute_pipeline(
_preprocess_data(data), model, _criterion, optimizer, return_loss=True, return_outputs=False
_preprocess_data(data), model, _criterion, optimizer, return_loss=True
)
else:
output = model(**_preprocess_data(data))
......@@ -109,7 +109,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
data_for_origin = data_gen_fn()
if booster.plugin.stage_manager is not None:
booster.execute_pipeline(
_preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True, return_outputs=False
_preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True
)
booster.execute_pipeline(
_preprocess_data(data_for_origin),
......@@ -117,7 +117,6 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
_criterion,
new_optimizer,
return_loss=True,
return_outputs=False,
)
else:
old_model_loss = criterion(model(**_preprocess_data(data_for_shard)))
......
......@@ -49,7 +49,6 @@ def run_fwd_bwd(
lambda x, y: x.loss,
optimizer,
return_loss=True,
return_outputs=True,
)
# Backward and optimize
if is_pp_last_stage:
......
......@@ -104,7 +104,7 @@ def run_pp(
torch_loss.backward()
pp_ret = schedule.forward_backward_step(
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
)
# check loss
......@@ -134,7 +134,7 @@ def run_pp(
torch_loss = criterion(torch_output)
pp_ret = schedule.forward_backward_step(
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
)
if stage_manager.is_last_stage(ignore_chunk=True):
assert torch.allclose(torch_loss, pp_ret["loss"])
......
......@@ -100,7 +100,7 @@ def examine_pp(num_microbatch: int, batch_size: int):
torch_loss = criterion(torch_output)
torch_loss.backward()
pp_ret = schedule.forward_backward_step(
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
)
# check loss
......@@ -130,7 +130,7 @@ def examine_pp(num_microbatch: int, batch_size: int):
torch_loss = criterion(torch_output)
pp_ret = schedule.forward_backward_step(
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
)
if stage_manager.is_last_stage():
assert torch.allclose(torch_loss, pp_ret["loss"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment