Commit d6b2d860 authored by Xiaoliang Dai's avatar Xiaoliang Dai Committed by Facebook GitHub Bot
Browse files

bfloat16 training

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/649

support bfloat16 training

Reviewed By: chihyaoma, Sekunde

Differential Revision: D53029989

fbshipit-source-id: 2e1d8f2112d238441e3f6801db3092383147fdbd
parent 3c6f71b4
...@@ -168,7 +168,7 @@ class EMAUpdater(object): ...@@ -168,7 +168,7 @@ class EMAUpdater(object):
ema_val = self.state.state[name] ema_val = self.state.state[name]
if self.device: if self.device:
val = val.to(self.device) val = val.to(self.device)
if val.dtype in [torch.float32, torch.float16]: if val.dtype in [torch.float32, torch.float16, torch.bfloat16]:
ema_param_list.append(ema_val) ema_param_list.append(ema_val)
param_list.append(val) param_list.append(val)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment