Unverified Commit 5e522efc authored by Shaden Smith's avatar Shaden Smith Committed by GitHub
Browse files

set_batch_fn and remove old sanity check (#712)

parent 2e2dd861
...@@ -478,17 +478,6 @@ class PipelineEngine(DeepSpeedEngine): ...@@ -478,17 +478,6 @@ class PipelineEngine(DeepSpeedEngine):
if self.batch_fn: if self.batch_fn:
batch = self.batch_fn(batch) batch = self.batch_fn(batch)
# Sanity check dimensions.
# XXX: the last minibatch with size < micro_batch_size kills us
if torch.is_tensor(batch[0]):
if batch[0].size(0) != self.micro_batch_size:
print(f'size mismatch: {batch[0].size(0)} mb: {self.micro_batch_size}')
return self._next_batch()
else:
assert torch.is_tensor(batch[0][0])
if batch[0][0].size(0) != self.micro_batch_size:
return self._next_batch()
return batch return batch
def _exec_forward_pass(self, buffer_id): def _exec_forward_pass(self, buffer_id):
...@@ -1170,3 +1159,11 @@ class PipelineEngine(DeepSpeedEngine): ...@@ -1170,3 +1159,11 @@ class PipelineEngine(DeepSpeedEngine):
# Equivalent to: self._exec_forward_pass(buffer_id=0) # Equivalent to: self._exec_forward_pass(buffer_id=0)
self._exec_instr = MethodType(self._INSTRUCTION_MAP[type(cmd)], self) self._exec_instr = MethodType(self._INSTRUCTION_MAP[type(cmd)], self)
self._exec_instr(**cmd.kwargs) self._exec_instr(**cmd.kwargs)
def set_batch_fn(self, fn):
"""Execute a post-processing function on input data.
Args:
fn (function): The function to run.
"""
self.batch_fn = fn
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment