Unverified Commit 2d1f7c01 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[engine] train should be able to get `mode` arg (#571)

parent 845921b3
...@@ -800,12 +800,12 @@ class DeepSpeedEngine(Module): ...@@ -800,12 +800,12 @@ class DeepSpeedEngine(Module):
data_parallel_world_size=data_parallel_world_size, data_parallel_world_size=data_parallel_world_size,
data_parallel_rank=data_parallel_rank) data_parallel_rank=data_parallel_rank)
def train(self): def train(self, mode=True):
r""" r"""
""" """
self.warn_unscaled_loss = True self.warn_unscaled_loss = True
self.module.train() self.module.train(mode)
def eval(self): def eval(self):
r""" r"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment