Unverified Commit 946dbd62 authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[hotfix]fix bugs caused by refactored pipeline (#1133)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4500af6a9220ef7fe4d3c7b1daebd4c.

* [hotfix]fix bugs caused by refactored pipeline
parent 789cad30
...@@ -67,8 +67,8 @@ class NonPipelineSchedule(BaseSchedule): ...@@ -67,8 +67,8 @@ class NonPipelineSchedule(BaseSchedule):
"The argument 'return_loss' has to be True when 'forward_only' is False, but got False." "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
batch_data = self.load_batch(data_iter) batch_data = self.load_batch(data_iter)
if self.batch_data_process_func: if self.data_process_func:
data, label = self.batch_data_process_func(batch_data) data, label = self.data_process_func(batch_data)
else: else:
# if no batch data process func is given, # if no batch data process func is given,
# then we regard the batch data as a simple tuple of (data, label) # then we regard the batch data as a simple tuple of (data, label)
......
...@@ -141,6 +141,8 @@ class PipelineSchedule(BaseSchedule): ...@@ -141,6 +141,8 @@ class PipelineSchedule(BaseSchedule):
for element in data: for element in data:
if isinstance(element, dict): if isinstance(element, dict):
data_dict.update({k: v[offset:offset + self.microbatch_size] for k, v in element.items()}) data_dict.update({k: v[offset:offset + self.microbatch_size] for k, v in element.items()})
elif data_dict:
data_dict['label'] = element[offset:offset + self.microbatch_size]
if data_dict: if data_dict:
return data_dict return data_dict
return [val[offset:offset + self.microbatch_size] for val in data] return [val[offset:offset + self.microbatch_size] for val in data]
...@@ -175,7 +177,10 @@ class PipelineSchedule(BaseSchedule): ...@@ -175,7 +177,10 @@ class PipelineSchedule(BaseSchedule):
elif isinstance(data, (list, tuple)): elif isinstance(data, (list, tuple)):
return model(*data) return model(*data)
elif isinstance(data, dict): elif isinstance(data, dict):
return model(**data) stage_output = None
if 'stage_output' in data:
stage_output = data.pop('stage_output')
return model(stage_output, **data)
else: else:
raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
...@@ -204,41 +209,14 @@ class PipelineSchedule(BaseSchedule): ...@@ -204,41 +209,14 @@ class PipelineSchedule(BaseSchedule):
data = stage_output data = stage_output
_, label = micro_batch_data _, label = micro_batch_data
elif isinstance(micro_batch_data, dict): elif isinstance(micro_batch_data, dict):
args = []
data = {} data = {}
label = {} data['stage_output'] = stage_output
if 'label' in micro_batch_data:
# we feed the stage output to args first label = micro_batch_data.pop('label')
# then map each arg in args to its param name else:
if stage_output is not None: label = None
if isinstance(stage_output, torch.Tensor): load_data = micro_batch_data
args.append(stage_output) data.update(load_data)
elif isinstance(stage_output, (list, tuple)):
args.extend(stage_output)
else:
raise TypeError(
f"Expected the values passed from previous pipeline stage to be torch.Tensor, list or tuple, but got {type(input_obj)}"
)
# get all parameter names for the forward function of the model
fwd_sig = self._get_actual_forward_func(model)
fwd_sig_param_name = [p.name for p in fwd_sig.parameters.values()]
# build the kwargs for the forward function
for idx, param_name in enumerate(fwd_sig_param_name):
if idx < len(args):
data[param_name] = args[idx]
else:
if param_name in micro_batch_data:
data[param_name] = micro_batch_data[param_name]
# get the tensors for loss
loss_sig = inspect.signature(criterion)
loss_sig_param_name = [p.name for p in loss_sig.parameters.values()]
for param_name in loss_sig_param_name:
if param_name in micro_batch_data:
label[param_name] = micro_batch_data[param_name]
return data, label return data, label
def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None): def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
......
...@@ -66,8 +66,11 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses): ...@@ -66,8 +66,11 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
modified_args = [] modified_args = []
for arg in args: for arg in args:
if isinstance(arg, torch.nn.Module): if isinstance(arg, torch.nn.Module):
# (lyl)TODO: if nn.Module is an argument of the root module, then we should just record the module instance itself. # if nn.Module is an argument of a non-root module, then we should convert it to a layer spec, which makes sure the correct init method is used in the real build.
arg = self._layer_spec_dict[id(arg)] # if nn.Module is an argument of the root module, then we should just record the module instance itself, because those instances have been built outside of the context.
if id(arg) in self._layer_spec_dict:
arg = self._layer_spec_dict[id(arg)]
modified_args.append(arg) modified_args.append(arg)
# do the same for the keyword arguments # do the same for the keyword arguments
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment