Unverified commit 045a9743 authored by anj-s, committed by GitHub
Browse files

[refactor] Move experimental folder to the fairscale repo (#410)



* move experimental to the fairscale repo

* lint error fixes

* modify test imports

* lint error fixes

* lint errors
Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>
parent 8fd82858
...@@ -18,7 +18,7 @@ from torch.utils.data import DataLoader ...@@ -18,7 +18,7 @@ from torch.utils.data import DataLoader
import torchtext import torchtext
from torchtext.data.utils import get_tokenizer from torchtext.data.utils import get_tokenizer
from experimental.nn.ampnet_pipe import pipe from fairscale.experimental.nn.ampnet_pipe import pipe
from fairscale.nn.model_parallel import initialize_model_parallel from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
from fairscale.nn.pipe import LazyModule from fairscale.nn.pipe import LazyModule
......
...@@ -82,7 +82,7 @@ class AsyncAMPnetEventLoop: ...@@ -82,7 +82,7 @@ class AsyncAMPnetEventLoop:
self.checkpoint_stop = checkpoint_stop self.checkpoint_stop = checkpoint_stop
self.input_device = input_device self.input_device = input_device
def perform_optimizer_step(self, optimizer, num_gradients): def perform_optimizer_step(self, optimizer: Any, num_gradients: Any) -> Any:
return (optimizer is not None) and ((num_gradients % self.min_update_interval == 0) or self.weight_prediction) return (optimizer is not None) and ((num_gradients % self.min_update_interval == 0) or self.weight_prediction)
def async_send_inner(self, batch: Batch, index: int) -> Tuple[Batch, PipeMessage]: def async_send_inner(self, batch: Batch, index: int) -> Tuple[Batch, PipeMessage]:
......
...@@ -25,7 +25,7 @@ class AMPnetPipe(AsyncPipe): ...@@ -25,7 +25,7 @@ class AMPnetPipe(AsyncPipe):
The implementation closely follows the paper: https://arxiv.org/abs/1705.09786 The implementation closely follows the paper: https://arxiv.org/abs/1705.09786
""" """
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
def interleave( def interleave(
......
...@@ -22,7 +22,7 @@ from torch import nn ...@@ -22,7 +22,7 @@ from torch import nn
from torch.optim.optimizer import Optimizer from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
from experimental.nn.ampnet_pipe.pipe import AMPnetPipe from fairscale.experimental.nn.ampnet_pipe.pipe import AMPnetPipe
from fairscale.utils.testing import get_worker_map, torch_spawn from fairscale.utils.testing import get_worker_map, torch_spawn
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment