Unverified commit 8c405c51, authored by Sean Naren and committed by GitHub
Browse files

Fixed RNN support for containers (#494)



* Fix packed sequence apply

* Update fairscale/utils/containers.py
Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>
parent 2e9a14e7
...@@ -7,12 +7,13 @@ from collections import OrderedDict ...@@ -7,12 +7,13 @@ from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import torch import torch
from torch.nn.utils.rnn import PackedSequence
"""Useful functions to deal with tensor types with other python container types.""" """Useful functions to deal with tensor types with other python container types."""
def apply_to_tensors(fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any: def apply_to_tensors(fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
"""Recursively apply to all tensor in 4 kinds of container types.""" """Recursively apply to all tensor in different kinds of container types."""
def _apply(x: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any: def _apply(x: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
if torch.is_tensor(x): if torch.is_tensor(x):
...@@ -22,6 +23,9 @@ def apply_to_tensors(fn: Callable, container: Union[torch.Tensor, Dict, List, Tu ...@@ -22,6 +23,9 @@ def apply_to_tensors(fn: Callable, container: Union[torch.Tensor, Dict, List, Tu
for key, value in x.items(): for key, value in x.items():
od[key] = _apply(value) od[key] = _apply(value)
return od return od
elif isinstance(x, PackedSequence):
_apply(x.data)
return x
elif isinstance(x, dict): elif isinstance(x, dict):
return {key: _apply(value) for key, value in x.items()} return {key: _apply(value) for key, value in x.items()}
elif isinstance(x, list): elif isinstance(x, list):
......
...@@ -14,6 +14,7 @@ import random ...@@ -14,6 +14,7 @@ import random
import pytest import pytest
import torch import torch
import torch.nn as nn
from fairscale.utils.containers import ( from fairscale.utils.containers import (
apply_to_tensors, apply_to_tensors,
...@@ -129,3 +130,20 @@ def test_split_unpack(): ...@@ -129,3 +130,20 @@ def test_split_unpack():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
# assert the content of the second arg should be sane. # assert the content of the second arg should be sane.
recon = unpack_non_tensors(tensors, {"is_tensor": [], "objects": []}) recon = unpack_non_tensors(tensors, {"is_tensor": [], "objects": []})
def test_packed_sequence():
    """Check that apply_to_tensors reaches the data of an RNN PackedSequence.

    A random padded batch is packed, run through a small RNN, and the packed
    output is passed to ``apply_to_tensors`` with an in-place zero-fill
    function. After unpacking, every element must be zero.
    """
    rnn = nn.RNN(5, 5)
    padded_input = torch.rand((5, 1, 5), dtype=torch.float)
    lengths = torch.tensor([4], dtype=torch.int)

    def fill_fn(t):
        # Mutates the tensor in place and returns None.
        t.fill_(0)

    packed_input = nn.utils.rnn.pack_padded_sequence(padded_input, lengths)
    packed_output, _hidden = rnn(packed_input)

    zeroed = apply_to_tensors(fill_fn, packed_output)

    # pad_packed_sequence returns (padded tensor, lengths); only the tensor
    # is needed for the check.
    unpacked, _lengths = nn.utils.rnn.pad_packed_sequence(zeroed)
    assert torch.sum(unpacked) == 0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment