Unverified Commit 2d1f7c01 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[engine] train should be able to get `mode` arg (#571)

parent 845921b3
......@@ -800,12 +800,12 @@ class DeepSpeedEngine(Module):
data_parallel_world_size=data_parallel_world_size,
data_parallel_rank=data_parallel_rank)
def train(self):
def train(self, mode=True):
r"""
"""
self.warn_unscaled_loss = True
self.module.train()
self.module.train(mode)
def eval(self):
r"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment