Unverified Commit f1f51990 authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[hotfix]fix some bugs caused by refactored schedule. (#1148)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4500af6a9220ef7fe4d3c7b1daebd4c.

* [hotfix]fix some bugs caused by refactored schedule.
parent 8cdce039
...@@ -36,7 +36,13 @@ class BaseSchedule(ABC): ...@@ -36,7 +36,13 @@ class BaseSchedule(ABC):
if isinstance(data, torch.Tensor): if isinstance(data, torch.Tensor):
data = data.to(get_current_device()) data = data.to(get_current_device())
elif isinstance(data, (list, tuple)): elif isinstance(data, (list, tuple)):
data = [self._move_tensor(v) for v in data] data_to_return = []
for element in data:
if isinstance(element, dict):
data_to_return.append({k: self._move_tensor(v) for k, v in element.items()})
else:
data_to_return.append(self._move_tensor(element))
data = data_to_return
elif isinstance(data, dict): elif isinstance(data, dict):
data = {k: self._move_tensor(v) for k, v in data.items()} data = {k: self._move_tensor(v) for k, v in data.items()}
else: else:
......
...@@ -66,7 +66,6 @@ class NonPipelineSchedule(BaseSchedule): ...@@ -66,7 +66,6 @@ class NonPipelineSchedule(BaseSchedule):
assert forward_only or return_loss, \ assert forward_only or return_loss, \
"The argument 'return_loss' has to be True when 'forward_only' is False, but got False." "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
batch_data = self.load_batch(data_iter) batch_data = self.load_batch(data_iter)
if self.data_process_func: if self.data_process_func:
data, label = self.data_process_func(batch_data) data, label = self.data_process_func(batch_data)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment