Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
deepspeed
Commits
5e522efc
Unverified
Commit
5e522efc
authored
Jan 29, 2021
by
Shaden Smith
Committed by
GitHub
Jan 29, 2021
Browse files
set_batch_fn and remove old sanity check (#712)
parent
2e2dd861
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
11 deletions
+8
-11
deepspeed/runtime/pipe/engine.py
deepspeed/runtime/pipe/engine.py
+8
-11
No files found.
deepspeed/runtime/pipe/engine.py
View file @
5e522efc
...
...
@@ -478,17 +478,6 @@ class PipelineEngine(DeepSpeedEngine):
if
self
.
batch_fn
:
batch
=
self
.
batch_fn
(
batch
)
# Sanity check dimensions.
# XXX: the last minibatch with size < micro_batch_size kills us
if
torch
.
is_tensor
(
batch
[
0
]):
if
batch
[
0
].
size
(
0
)
!=
self
.
micro_batch_size
:
print
(
f
'size mismatch:
{
batch
[
0
].
size
(
0
)
}
mb:
{
self
.
micro_batch_size
}
'
)
return
self
.
_next_batch
()
else
:
assert
torch
.
is_tensor
(
batch
[
0
][
0
])
if
batch
[
0
][
0
].
size
(
0
)
!=
self
.
micro_batch_size
:
return
self
.
_next_batch
()
return
batch
def
_exec_forward_pass
(
self
,
buffer_id
):
...
...
@@ -1170,3 +1159,11 @@ class PipelineEngine(DeepSpeedEngine):
# Equivalent to: self._exec_forward_pass(buffer_id=0)
self
.
_exec_instr
=
MethodType
(
self
.
_INSTRUCTION_MAP
[
type
(
cmd
)],
self
)
self
.
_exec_instr
(
**
cmd
.
kwargs
)
def
set_batch_fn
(
self
,
fn
):
"""Execute a post-processing function on input data.
Args:
fn (function): The function to run.
"""
self
.
batch_fn
=
fn
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment