Unverified Commit 946dbd62 authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[hotfix]fix bugs caused by refactored pipeline (#1133)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4500af6a9220ef7fe4d3c7b1daebd4c.

* [hotfix]fix bugs caused by refactored pipeline
parent 789cad30
......@@ -67,8 +67,8 @@ class NonPipelineSchedule(BaseSchedule):
"The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
batch_data = self.load_batch(data_iter)
if self.batch_data_process_func:
data, label = self.batch_data_process_func(batch_data)
if self.data_process_func:
data, label = self.data_process_func(batch_data)
else:
            # if no batch data process func is given,
# then we regard the batch data as a simple tuple of (data, label)
......
......@@ -141,6 +141,8 @@ class PipelineSchedule(BaseSchedule):
for element in data:
if isinstance(element, dict):
data_dict.update({k: v[offset:offset + self.microbatch_size] for k, v in element.items()})
elif data_dict:
data_dict['label'] = element[offset:offset + self.microbatch_size]
if data_dict:
return data_dict
return [val[offset:offset + self.microbatch_size] for val in data]
......@@ -175,7 +177,10 @@ class PipelineSchedule(BaseSchedule):
elif isinstance(data, (list, tuple)):
return model(*data)
elif isinstance(data, dict):
return model(**data)
stage_output = None
if 'stage_output' in data:
stage_output = data.pop('stage_output')
return model(stage_output, **data)
else:
raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
......@@ -204,41 +209,14 @@ class PipelineSchedule(BaseSchedule):
data = stage_output
_, label = micro_batch_data
elif isinstance(micro_batch_data, dict):
args = []
data = {}
label = {}
# we feed the stage output to args first
# then map each arg in args to its param name
if stage_output is not None:
if isinstance(stage_output, torch.Tensor):
args.append(stage_output)
elif isinstance(stage_output, (list, tuple)):
args.extend(stage_output)
else:
raise TypeError(
f"Expected the values passed from previous pipeline stage to be torch.Tensor, list or tuple, but got {type(input_obj)}"
)
# get all parameter names for the forward function of the model
fwd_sig = self._get_actual_forward_func(model)
fwd_sig_param_name = [p.name for p in fwd_sig.parameters.values()]
# build the kwargs for the forward function
for idx, param_name in enumerate(fwd_sig_param_name):
if idx < len(args):
data[param_name] = args[idx]
data['stage_output'] = stage_output
if 'label' in micro_batch_data:
label = micro_batch_data.pop('label')
else:
if param_name in micro_batch_data:
data[param_name] = micro_batch_data[param_name]
# get the tensors for loss
loss_sig = inspect.signature(criterion)
loss_sig_param_name = [p.name for p in loss_sig.parameters.values()]
for param_name in loss_sig_param_name:
if param_name in micro_batch_data:
label[param_name] = micro_batch_data[param_name]
label = None
load_data = micro_batch_data
data.update(load_data)
return data, label
def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
......
......@@ -66,8 +66,11 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
modified_args = []
for arg in args:
if isinstance(arg, torch.nn.Module):
# (lyl)TODO: if nn.Module is an argument of the root module, then we should just record the module instance itself.
            # if nn.Module is an argument of a non-root module, then we should convert it to a layer spec, which makes sure the correct init method is used in the real build.
            # if nn.Module is an argument of the root module, then we should just record the module instance itself, because those instances have already been built outside of the context.
if id(arg) in self._layer_spec_dict:
arg = self._layer_spec_dict[id(arg)]
modified_args.append(arg)
        # do the same for the keyword arguments
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment