Unverified Commit 8c405c51 authored by Sean Naren's avatar Sean Naren Committed by GitHub
Browse files

Fixed RNN support for containers (#494)



* Fix packed sequence apply

* Update fairscale/utils/containers.py
Co-authored-by: default avatarMin Xu <24926999+min-xu-ai@users.noreply.github.com>
parent 2e9a14e7
......@@ -7,12 +7,13 @@ from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import torch
from torch.nn.utils.rnn import PackedSequence
"""Useful functions to deal with tensor types with other python container types."""
def apply_to_tensors(fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
"""Recursively apply to all tensor in 4 kinds of container types."""
"""Recursively apply to all tensor in different kinds of container types."""
def _apply(x: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
if torch.is_tensor(x):
......@@ -22,6 +23,9 @@ def apply_to_tensors(fn: Callable, container: Union[torch.Tensor, Dict, List, Tu
for key, value in x.items():
od[key] = _apply(value)
return od
elif isinstance(x, PackedSequence):
_apply(x.data)
return x
elif isinstance(x, dict):
return {key: _apply(value) for key, value in x.items()}
elif isinstance(x, list):
......
......@@ -14,6 +14,7 @@ import random
import pytest
import torch
import torch.nn as nn
from fairscale.utils.containers import (
apply_to_tensors,
......@@ -129,3 +130,20 @@ def test_split_unpack():
with pytest.raises(AssertionError):
# assert the content of the second arg should be sane.
recon = unpack_non_tensors(tensors, {"is_tensor": [], "objects": []})
def test_packed_sequence():
"""Test to ensure RNN packed sequences are modified correctly."""
rnn = nn.RNN(5, 5)
x = torch.rand((5, 1, 5), dtype=torch.float)
seq_length = torch.tensor([4], dtype=torch.int)
def fill_fn(x):
x.fill_(0)
x = nn.utils.rnn.pack_padded_sequence(x, seq_length)
x, h = rnn(x)
x = apply_to_tensors(fill_fn, x)
x, _ = nn.utils.rnn.pad_packed_sequence(x)
assert torch.sum(x) == 0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment