Unverified Commit bb0a668f authored by Wenhao Chen's avatar Wenhao Chen Committed by GitHub
Browse files

[hotfix] set return_outputs=False in examples and polish code (#5404)

* fix: simplify merge_batch

* fix: use return_outputs=False to eliminate extra memory consumption

* feat: add return_outputs warning

* style: remove `return_outputs=False` as it is the default value
parent 5fcd7795
...@@ -75,7 +75,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf ...@@ -75,7 +75,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
model.train() model.train()
if booster.plugin.stage_manager is not None: if booster.plugin.stage_manager is not None:
booster.execute_pipeline( booster.execute_pipeline(
_preprocess_data(data), model, _criterion, optimizer, return_loss=True, return_outputs=False _preprocess_data(data), model, _criterion, optimizer, return_loss=True
) )
else: else:
output = model(**_preprocess_data(data)) output = model(**_preprocess_data(data))
...@@ -109,7 +109,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf ...@@ -109,7 +109,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
data_for_origin = data_gen_fn() data_for_origin = data_gen_fn()
if booster.plugin.stage_manager is not None: if booster.plugin.stage_manager is not None:
booster.execute_pipeline( booster.execute_pipeline(
_preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True, return_outputs=False _preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True
) )
booster.execute_pipeline( booster.execute_pipeline(
_preprocess_data(data_for_origin), _preprocess_data(data_for_origin),
...@@ -117,7 +117,6 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf ...@@ -117,7 +117,6 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
_criterion, _criterion,
new_optimizer, new_optimizer,
return_loss=True, return_loss=True,
return_outputs=False,
) )
else: else:
old_model_loss = criterion(model(**_preprocess_data(data_for_shard))) old_model_loss = criterion(model(**_preprocess_data(data_for_shard)))
......
...@@ -49,7 +49,6 @@ def run_fwd_bwd( ...@@ -49,7 +49,6 @@ def run_fwd_bwd(
lambda x, y: x.loss, lambda x, y: x.loss,
optimizer, optimizer,
return_loss=True, return_loss=True,
return_outputs=True,
) )
# Backward and optimize # Backward and optimize
if is_pp_last_stage: if is_pp_last_stage:
......
...@@ -104,7 +104,7 @@ def run_pp( ...@@ -104,7 +104,7 @@ def run_pp(
torch_loss.backward() torch_loss.backward()
pp_ret = schedule.forward_backward_step( pp_ret = schedule.forward_backward_step(
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
) )
# check loss # check loss
...@@ -134,7 +134,7 @@ def run_pp( ...@@ -134,7 +134,7 @@ def run_pp(
torch_loss = criterion(torch_output) torch_loss = criterion(torch_output)
pp_ret = schedule.forward_backward_step( pp_ret = schedule.forward_backward_step(
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
) )
if stage_manager.is_last_stage(ignore_chunk=True): if stage_manager.is_last_stage(ignore_chunk=True):
assert torch.allclose(torch_loss, pp_ret["loss"]) assert torch.allclose(torch_loss, pp_ret["loss"])
......
...@@ -100,7 +100,7 @@ def examine_pp(num_microbatch: int, batch_size: int): ...@@ -100,7 +100,7 @@ def examine_pp(num_microbatch: int, batch_size: int):
torch_loss = criterion(torch_output) torch_loss = criterion(torch_output)
torch_loss.backward() torch_loss.backward()
pp_ret = schedule.forward_backward_step( pp_ret = schedule.forward_backward_step(
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
) )
# check loss # check loss
...@@ -130,7 +130,7 @@ def examine_pp(num_microbatch: int, batch_size: int): ...@@ -130,7 +130,7 @@ def examine_pp(num_microbatch: int, batch_size: int):
torch_loss = criterion(torch_output) torch_loss = criterion(torch_output)
pp_ret = schedule.forward_backward_step( pp_ret = schedule.forward_backward_step(
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
) )
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():
assert torch.allclose(torch_loss, pp_ret["loss"]) assert torch.allclose(torch_loss, pp_ret["loss"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment