Unverified Commit 03a25ba8 authored by mcarilli's avatar mcarilli Committed by GitHub
Browse files

Casting logic should reflatten RNN parameters

parent 1b8303d8
......@@ -65,6 +65,8 @@ def convert_network(network, dtype):
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
continue
convert_module(module, dtype)
if isinstance(module, torch.nn.RNNBase) or isinstance(module, torch.nn.modules.rnn.RNNBase):
module.flatten_parameters()
return network
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment