Commit cb8160d3 authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot
Browse files

Register RNN-T pipeline global stats constants as buffers (#2175)

Summary:
Currently, `mean` and `invstddev` exist as vanilla object attributes in the global stats normalization module that uses them. This, however, would preclude them from being moved to the same device that the module is moved to. To resolve this, this PR registers them as buffers.

Pull Request resolved: https://github.com/pytorch/audio/pull/2175

Reviewed By: nateanl

Differential Revision: D33794239

Pulled By: hwangjeff

fbshipit-source-id: 78eb699ab5e0844f9436afc529b851e651f4f451
parent 691317a9
......@@ -52,8 +52,8 @@ class _GlobalStatsNormalization(torch.nn.Module):
with open(global_stats_path) as f:
blob = json.loads(f.read())
self.mean = torch.tensor(blob["mean"])
self.invstddev = torch.tensor(blob["invstddev"])
self.register_buffer("mean", torch.tensor(blob["mean"]))
self.register_buffer("invstddev", torch.tensor(blob["invstddev"]))
def forward(self, input):
return (input - self.mean) * self.invstddev
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment