Unverified Commit b35a3d3f authored by msbaines's avatar msbaines Committed by GitHub
Browse files

[test] using PyTorch v1.6 for Lint checks (#36)

parent 2f638e5a
...@@ -115,7 +115,7 @@ jobs: ...@@ -115,7 +115,7 @@ jobs:
keys: keys:
- cache-key-cpu-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}} - cache-key-cpu-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_dep_15 - <<: *install_dep_16
- save_cache: - save_cache:
paths: paths:
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
from typing import List from typing import List
import numpy as np # type: ignore
import torch import torch
from .utils import ensure_divisibility from .utils import ensure_divisibility
...@@ -68,11 +67,9 @@ def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int = ...@@ -68,11 +67,9 @@ def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int =
data_parallel_size = int(world_size / (model_parallel_size * pipeline_length)) data_parallel_size = int(world_size / (model_parallel_size * pipeline_length))
groups = ( groups = torch.LongTensor(range(world_size)).reshape(data_parallel_size, pipeline_length, model_parallel_size)
torch.LongTensor(range(world_size)).reshape(data_parallel_size, pipeline_length, model_parallel_size).numpy()
)
found = np.where(groups == rank) found = torch.where(groups == rank)
assert all(len(x) == 1 for x in found) assert all(len(x) == 1 for x in found)
found = [x[0] for x in found] found = [x[0] for x in found]
......
...@@ -47,7 +47,7 @@ def _initialize_affine_weight( ...@@ -47,7 +47,7 @@ def _initialize_affine_weight(
in_features: int, in_features: int,
per_partition_size: int, per_partition_size: int,
partition_dim: int, partition_dim: int,
init_method: Callable[[torch.Tensor], None], init_method: Callable[[torch.Tensor], torch.Tensor],
stride: int = 1, stride: int = 1,
return_master_weight: bool = False, return_master_weight: bool = False,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
...@@ -101,7 +101,7 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -101,7 +101,7 @@ class VocabParallelEmbedding(torch.nn.Module):
norm_type: float = 2.0, norm_type: float = 2.0,
scale_grad_by_freq: bool = False, scale_grad_by_freq: bool = False,
sparse: bool = False, sparse: bool = False,
init_method: Callable[[torch.Tensor], None] = init.xavier_normal_, init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
) -> None: ) -> None:
super(VocabParallelEmbedding, self).__init__() super(VocabParallelEmbedding, self).__init__()
# Keep the input dimensions. # Keep the input dimensions.
...@@ -169,7 +169,7 @@ class ParallelEmbedding(torch.nn.Module): ...@@ -169,7 +169,7 @@ class ParallelEmbedding(torch.nn.Module):
norm_type: float = 2.0, norm_type: float = 2.0,
scale_grad_by_freq: bool = False, scale_grad_by_freq: bool = False,
sparse: bool = False, sparse: bool = False,
init_method: Callable[[torch.Tensor], None] = init.xavier_normal_, init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
keep_master_weight_for_test: bool = False, keep_master_weight_for_test: bool = False,
) -> None: ) -> None:
super(ParallelEmbedding, self).__init__() super(ParallelEmbedding, self).__init__()
...@@ -242,7 +242,7 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -242,7 +242,7 @@ class ColumnParallelLinear(torch.nn.Module):
out_features: int, out_features: int,
bias: bool = True, bias: bool = True,
gather_output: bool = True, gather_output: bool = True,
init_method: Callable[[torch.Tensor], None] = init.xavier_normal_, init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
stride: int = 1, stride: int = 1,
keep_master_weight_for_test: bool = False, keep_master_weight_for_test: bool = False,
) -> None: ) -> None:
...@@ -326,7 +326,7 @@ class RowParallelLinear(torch.nn.Module): ...@@ -326,7 +326,7 @@ class RowParallelLinear(torch.nn.Module):
out_features: int, out_features: int,
bias: bool = True, bias: bool = True,
input_is_parallel: bool = False, input_is_parallel: bool = False,
init_method: Callable[[torch.Tensor], None] = init.xavier_normal_, init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
stride: int = 1, stride: int = 1,
keep_master_weight_for_test: bool = False, keep_master_weight_for_test: bool = False,
): ):
......
...@@ -30,6 +30,7 @@ from . import nn as nn ...@@ -30,6 +30,7 @@ from . import nn as nn
#MODIFIED BY TORCHGPIPE #MODIFIED BY TORCHGPIPE
from . import backends from . import backends
from . import distributed
from . import version from . import version
#END #END
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment