Unverified commit 3490b9e1, authored by Masaki Kozuki and committed by GitHub

[transformer] Allow for different backend for Pipeline Parallel ProcessGroups (#1380)



* NcclDistributedTestBase

* fix stupid mistake

* add UCC test

* add UCC backend

* torch ucc tests

* allows for UCC backend

* Set `UCX_TLS` to `tcp,cuda_copy` & Use DDP iff it makes sense

* Apply 4 suggestions to 1 file

* mix&match NCCL & UCC

* use both ucc&nccl in gpt

* UCC for Pipeline Parallel, NCCL for the others

* conditionally use ucc

* make ucc guards more friendly

* test raises when torch_ucc isn't available

* Change from class variable to member variable
Co-authored-by: Aidyn Aitzhan <31858918+Aidyn-A@users.noreply.github.com>

* pass async_comm to train; I mistakenly dropped it during the rebase

* fix typo: functionality

* Enable tensor parallel only when device count > 4

I want the pipeline model parallel world size to be >= 4 because
I previously saw GPT/BERT failing when only UCC is used.
So I'm speculating that there's some gotcha around a pipeline size of 4.

* Add NVIDIA driver version guard
Co-authored-by: Aidyn Aitzhan <31858918+Aidyn-A@users.noreply.github.com>

* move world_size as it was not correctly reflected

* keep an eye on the NVML API thing

* import unittest
Co-authored-by: Aidyn Aitzhan <31858918+Aidyn-A@users.noreply.github.com>
parent d36397d2
@@ -16,6 +16,7 @@
 # TODO (mkozuki): Sort the functions in the same order of megatron/mpu/initialize.py
 """Model and data parallel groups."""
 from typing import Tuple, Optional
+import warnings
 import torch
@@ -75,6 +76,9 @@ def initialize_model_parallel(
     pipeline_model_parallel_size_: int = 1,
     virtual_pipeline_model_parallel_size_: Optional[int] = None,
     pipeline_model_parallel_split_rank_: Optional[int] = None,
+    *,
+    default_backend: Optional[str] = None,
+    p2p_backend: Optional[str] = None,
 ) -> None:
     """
     Initialize model data parallel groups.
@@ -84,6 +88,15 @@ def initialize_model_parallel(
         pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline.
         virtual_pipeline_model_parallel_size: number of virtual stages (interleaved pipeline).
         pipeline_model_parallel_split_rank: for models with both encoder and decoder, rank in pipeline with split point.
+    Keyword Arguments:
+        default_backend: Backend of process groups except for pipeline parallel ones.
+            If :obj:`None`, the backend specified in `torch.distributed.init_process_group` will be used.
+        p2p_backend: Backend of process groups for pipeline model parallel.
+            If :obj:`None`, the backend specified in `torch.distributed.init_process_group` will be used.
+    .. note::
+        `torch_ucc <https://github.com/facebookresearch/torch_ucc>`_ is
+        necessary for "ucc" backend.
     Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
     use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
@@ -103,6 +116,14 @@ def initialize_model_parallel(
     """
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
+    assert default_backend is None or default_backend in ("nccl", "ucc")
+    assert p2p_backend is None or p2p_backend in ("nccl", "ucc")
+    if "ucc" in (default_backend, p2p_backend):
+        check_torch_ucc_availability()
+        warnings.warn("`ucc` backend support is experimental", ExperimentalWarning)
+    if default_backend == "ucc":
+        warnings.warn("The UCC's functionality as `default_backend` is not well verified", ExperimentalWarning)
     world_size: int = torch.distributed.get_world_size()
     tensor_model_parallel_size: int = min(tensor_model_parallel_size_, world_size)
     pipeline_model_parallel_size: int = min(pipeline_model_parallel_size_, world_size)
@@ -160,7 +181,7 @@ def initialize_model_parallel(
         for j in range(tensor_model_parallel_size):
             ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
             all_data_parallel_group_ranks.append(list(ranks))
-            group = torch.distributed.new_group(ranks)
+            group = torch.distributed.new_group(ranks, backend=default_backend)
             if rank in ranks:
                 _DATA_PARALLEL_GROUP = group
@@ -172,7 +193,7 @@ def initialize_model_parallel(
             data_parallel_group_ranks[i]
             for data_parallel_group_ranks in all_data_parallel_group_ranks
         ]
-        group = torch.distributed.new_group(ranks)
+        group = torch.distributed.new_group(ranks, backend=default_backend)
         if rank in ranks:
             _MODEL_PARALLEL_GROUP = group
@@ -185,7 +206,7 @@ def initialize_model_parallel(
         ranks = list(
             range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
         )
-        group = torch.distributed.new_group(ranks)
+        group = torch.distributed.new_group(ranks, backend=default_backend)
        if rank in ranks:
             _TENSOR_MODEL_PARALLEL_GROUP = group
@@ -206,7 +227,7 @@ def initialize_model_parallel(
     ), "position embedding group is already initialized"
     for i in range(num_pipeline_model_parallel_groups):
         ranks = range(i, world_size, num_pipeline_model_parallel_groups)
-        group = torch.distributed.new_group(ranks)
+        group = torch.distributed.new_group(ranks, backend=p2p_backend)
         if rank in ranks:
             _PIPELINE_MODEL_PARALLEL_GROUP = group
             _PIPELINE_GLOBAL_RANKS = ranks
@@ -234,13 +255,13 @@ def initialize_model_parallel(
             embedding_ranks = ranks
             position_embedding_ranks = ranks
-        group = torch.distributed.new_group(embedding_ranks)
+        group = torch.distributed.new_group(embedding_ranks, backend=default_backend)
         if rank in embedding_ranks:
             _EMBEDDING_GROUP = group
         if rank in ranks:
             _EMBEDDING_GLOBAL_RANKS = embedding_ranks
-        group = torch.distributed.new_group(position_embedding_ranks)
+        group = torch.distributed.new_group(position_embedding_ranks, backend=default_backend)
         if rank in position_embedding_ranks:
             _POSITION_EMBEDDING_GROUP = group
         if rank in ranks:
@@ -578,3 +599,16 @@ def destroy_model_parallel():
     _MPU_TENSOR_MODEL_PARALLEL_RANK = None
     global _MPU_PIPELINE_MODEL_PARALLEL_RANK
     _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
+
+
+# Used to warn when the UCC is specified.
+class ExperimentalWarning(Warning): pass
+
+
+def check_torch_ucc_availability() -> None:
+    try:
+        import torch_ucc  # NOQA
+    except ImportError:
+        raise ImportError(
+            "UCC backend requires [torch_ucc](https://github.com/facebookresearch/torch_ucc) but not found"
+        )
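
For reference, a minimal usage sketch of the new keyword-only arguments. The two-backend split, the torchrun-style launch, and the surrounding script are illustrative assumptions, not part of this diff:

# Hypothetical caller: NCCL for most process groups, UCC for pipeline-parallel P2P.
# Assumes the script is launched by torchrun and that torch_ucc is installed,
# which registers the "ucc" backend with torch.distributed on import.
import torch
import torch_ucc  # NOQA
from apex.transformer import parallel_state

torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size_=1,
    pipeline_model_parallel_size_=torch.distributed.get_world_size(),
    default_backend="nccl",  # data/tensor/embedding groups stay on NCCL
    p2p_backend="ucc",       # pipeline-parallel groups use UCC
)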
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import datetime
 import os
 import random
 from typing import Optional, Union, List
@@ -117,6 +118,10 @@ def initialize_distributed(backend="nccl"):
     # parser.add_argument('--local_rank', type=int, default=None,
     #                     help='local rank passed from distributed launcher')
     # args = parser.parse_args()
+    if backend not in ("nccl", "ucc"):
+        raise RuntimeError(f"Currently only nccl & ucc are supported but {backend}")
+    if backend == "ucc":
+        import torch_ucc  # NOQA
     args = global_vars.get_args()
     local_rank = args.local_rank
@@ -141,7 +146,8 @@ def initialize_distributed(backend="nccl"):
     master_port = os.getenv("MASTER_PORT", "6000")
     init_method += master_ip + ":" + master_port
     torch.distributed.init_process_group(
-        backend=backend, world_size=world_size, rank=rank, init_method=init_method
+        backend=backend, world_size=world_size, rank=rank, init_method=init_method,
+        timeout=datetime.timedelta(seconds=60),
     )
...
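
As context for the "ucc" path above, here is a hedged sketch of the launch-side setup. The `UCX_TLS=tcp,cuda_copy` value mirrors the commit message and the `UccDistributedTestBase` added later in this commit; whether it is needed on a given system is an assumption, and the env-var defaults below are illustrative:

# Hypothetical environment setup before initializing a UCC process group.
import os

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "6000")
os.environ.setdefault("UCX_TLS", "tcp,cuda_copy")  # restrict UCX transports

import torch
import torch_ucc  # NOQA: importing registers the "ucc" backend, as checked above

torch.distributed.init_process_group(
    backend="ucc",
    world_size=int(os.getenv("WORLD_SIZE", "1")),
    rank=int(os.getenv("RANK", "0")),
)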
+import os
 import sys
+import unittest
 import torch
 from torch import distributed as dist
+from torch.utils import collect_env
 from torch.testing._internal import common_utils
 from torch.testing._internal import common_distributed
+
+HAS_TORCH_UCC = None
+try:
+    import torch_ucc
+    HAS_TORCH_UCC = True
+except ImportError:
+    HAS_TORCH_UCC = False
+
+# NOTE(mkozuki): Version guard for ucc. ref: https://github.com/openucx/ucc/issues/496
+_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION = "470.42.01"
+_driver_version = None
+if torch.cuda.is_available():
+    _driver_version = collect_env.get_nvidia_driver_version(collect_env.run)
+HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER = _driver_version is not None and _driver_version >= _TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION
+
 class DistributedTestBase(common_distributed.MultiProcessTestCase):
-    BACKEND_NCCL = "nccl"

     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)

     def setUp(self) -> None:
         super().setUp()
+        self._setup_pre_spawn()
         self._spawn_processes()

     def tearDown(self) -> None:
@@ -32,6 +48,7 @@ class DistributedTestBase(common_distributed.MultiProcessTestCase):
     def _run(cls, rank, test_name, file_name, pipe):
         self = cls(test_name)
         self.assertTrue(torch.cuda.is_available())
+        self.assertTrue(hasattr(self, "DISTRIBUTED_BACKEND"))
         self.rank = rank
         self.file_name = file_name
@@ -40,13 +57,13 @@ class DistributedTestBase(common_distributed.MultiProcessTestCase):
         try:
             dist.init_process_group(
                 init_method=self.init_method,
-                backend=DistributedTestBase.BACKEND_NCCL,
+                backend=self.DISTRIBUTED_BACKEND,
                 world_size=int(self.world_size),
                 rank=self.rank,
             )
         except RuntimeError as e:
             if "recompile" in e.args[0]:
-                print(f"Backend of {DistributedTestBase.BACKEND_NCCL} not available")
+                print(f"Backend of {self.DISTRIBUTED_BACKEND} not available")
                 sys.exit(0)
             raise
@@ -58,3 +75,55 @@ class DistributedTestBase(common_distributed.MultiProcessTestCase):
         dist.destroy_process_group()
         sys.exit(0)
+
+    def _setup_pre_spawn(self):
+        pass
+
+
+class NcclDistributedTestBase(DistributedTestBase):
+    DISTRIBUTED_BACKEND = "nccl"
+
+
+@unittest.skipUnless(
+    HAS_TORCH_UCC,
+    "Requires [`torch_ucc`](https://github.com/facebookresearch/torch_ucc)",
+)
+@unittest.skipUnless(
+    HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER,
+    f"`torch_ucc` requires NVIDIA driver >= {_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION} but {_driver_version} found. "
+    "See https://github.com/openucx/ucc/issues/496",
+)
+class UccDistributedTestBase(DistributedTestBase):
+    DISTRIBUTED_BACKEND = "ucc"
+
+    def _setup_pre_spawn(self) -> None:
+        self.master_addr = "localhost"
+        os.environ["MASTER_ADDR"] = "localhost"
+        self._has_master_port = "MASTER_PORT" in os.environ
+        if self._has_master_port:
+            self.master_port = os.environ["MASTER_PORT"]
+        else:
+            try:
+                from caffe2.torch.fb.common.utils import get_free_port
+                self.master_port = str(get_free_port())
+            except ImportError:
+                self.master_port = "12375"
+            os.environ["MASTER_PORT"] = self.master_port
+
+        self._has_ucx_tls = "UCX_TLS" in os.environ
+        if not self._has_ucx_tls:
+            os.environ["UCX_TLS"] = "tcp,cuda_copy"
+        print('os.environ["UCX_TLS"] = {}'.format(os.environ["UCX_TLS"]))
+
+    def tearDown(self) -> None:
+        super().tearDown()
+        if not self._has_master_port:
+            del os.environ["MASTER_PORT"]
+        if not self._has_ucx_tls:
+            del os.environ["UCX_TLS"]
+
+    @property
+    def init_method(self):
+        return "tcp://localhost:" + os.environ["MASTER_PORT"]
 import random
 import torch
+try:
+    import torch_ucc
+except ImportError:
+    HAS_TORCH_UCC = False
+else:
+    HAS_TORCH_UCC = True
+    print("Use UCC as backend of Pipeline Parallel ProcessGroups")
 from apex.transformer import tensor_parallel
 from apex.transformer import parallel_state
+from apex.transformer.log_util import set_logging_level
 from apex.transformer.tensor_parallel import vocab_parallel_cross_entropy
 from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
 from apex.transformer.pipeline_parallel.utils import (
@@ -27,6 +35,7 @@ class DebugWarning(Warning):
     pass
+set_logging_level("WARNING")
 mode = None
 MANUAL_SEED = 42
 inds = None
@@ -154,7 +163,7 @@ if __name__ == "__main__":
         effective_length = fancy_data.size(0) // global_vars.get_args().seq_length
         effective_length = fancy_data.size(0) - global_vars.get_args().seq_length
-    initialize_distributed()
+    initialize_distributed("nccl")
     world_size = torch.distributed.get_world_size()
     failure = None
     init = True
@@ -182,6 +191,8 @@ if __name__ == "__main__":
             args.tensor_model_parallel_size,
             args.pipeline_model_parallel_size,
             virtual_pipeline_model_parallel_size,
+            default_backend="nccl",
+            p2p_backend="ucc" if HAS_TORCH_UCC else "nccl",
         )
         pipeline_model_parallel_size = (
             parallel_state.get_pipeline_model_parallel_world_size()
...
@@ -3,6 +3,13 @@ from typing import List
 import time
 import torch
+try:
+    import torch_ucc
+except ImportError:
+    HAS_TORCH_UCC = False
+else:
+    HAS_TORCH_UCC = True
+    print("Use UCC as backend of Pipeline Parallel ProcessGroups")
 from apex.transformer import parallel_state
 from apex.transformer.tensor_parallel import model_parallel_cuda_manual_seed
@@ -92,6 +99,7 @@ def loss_func(loss_mask, output_tensor):
     # Reduce loss for logging.
     averaged_loss = average_losses_across_data_parallel_group([loss])
     return loss, {"lm loss": averaged_loss[0]}
@@ -144,12 +152,13 @@ if __name__ == "__main__":
     if init:
         init = False
         global_vars.set_global_variables()
+        args = global_vars.get_args()
         fancy_data = download_fancy_data()
-        args = global_vars.get_args()
         effective_length = fancy_data.size(0) // args.seq_length
         effective_length = fancy_data.size(0) - args.seq_length
-    initialize_distributed()
+    initialize_distributed("nccl")
     world_size = torch.distributed.get_world_size()
     failure = None
@@ -170,6 +179,8 @@ if __name__ == "__main__":
         parallel_state.initialize_model_parallel(
             tensor_model_parallel_size_=args.tensor_model_parallel_size,
             pipeline_model_parallel_size_=args.pipeline_model_parallel_size,
+            default_backend="nccl",
+            p2p_backend="ucc" if HAS_TORCH_UCC else "nccl",
         )
         pipeline_model_parallel_size = (
...
@@ -11,7 +11,8 @@ from apex.transformer import parallel_state
 from apex.transformer import tensor_parallel
 from apex.transformer.tensor_parallel import cross_entropy
 from apex.transformer.testing.commons import set_random_seed, IdentityLayer
-from apex.transformer.testing.distributed_test_base import DistributedTestBase
+from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
+from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
 logging.getLogger("apex").setLevel(logging.WARNING)
@@ -54,7 +55,7 @@ def tensor_sharded_cross_entropy(
     return loss, identity.weight.grad
-class VocabParallelCrossEntropy(DistributedTestBase):
+class VocabParallelCrossEntropyTestBase:
     def test_cross_entropy(self):
         batch_size, sequence_length, vocab_size_per_partition = 13, 17, 11
         logits_scale = 1000.0
@@ -85,5 +86,9 @@ class VocabParallelCrossEntropy(DistributedTestBase):
         parallel_state.destroy_model_parallel()
+class NcclVocabParallelCrossEntropyTest(VocabParallelCrossEntropyTestBase, NcclDistributedTestBase): pass
+class UccVocabParallelCrossEntropyTest(VocabParallelCrossEntropyTestBase, UccDistributedTestBase): pass
 if __name__ == "__main__":
     common_utils.run_tests()
@@ -7,12 +7,13 @@ logging.getLogger("torch").setLevel(logging.WARNING)
 from apex.transformer import parallel_state
 from apex.transformer.tensor_parallel import data as data_utils
-from apex.transformer.testing.distributed_test_base import DistributedTestBase
+from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
+from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
 logging.getLogger("torch").setLevel(logging.WARNING)
-class BroadcastDataTest(DistributedTestBase):
+class BroadcastDataTestBase:
     def test_broadcast_data(self):
         tensor_model_parallel_world_size: int = self.world_size // (
             1 + self.world_size > 1
@@ -55,5 +56,9 @@ class BroadcastDataTest(DistributedTestBase):
         parallel_state.destroy_model_parallel()
+class NcclBroadcastDataTest(BroadcastDataTestBase, NcclDistributedTestBase): pass
+class UccBroadcastDataTest(BroadcastDataTestBase, UccDistributedTestBase): pass
 if __name__ == "__main__":
     common_utils.run_tests()
@@ -9,7 +9,8 @@ logging.getLogger("torch").setLevel(logging.WARNING)
 from apex.transformer import parallel_state
 from apex.transformer.tensor_parallel import layers
 from apex.transformer.testing.commons import set_random_seed
-from apex.transformer.testing.distributed_test_base import DistributedTestBase
+from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
+from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
 logging.getLogger("apex").setLevel(logging.WARNING)
@@ -20,7 +21,7 @@ logging.getLogger("apex").setLevel(logging.WARNING)
 torch.backends.cuda.matmul.allow_tf32 = False
-class TensorParallelLayerTest(DistributedTestBase):
+class TensorParallelLayerTestBase:
     BATCH_SIZE: int = 17
     SEQUENCE_LENGTH: int = 23
@@ -40,29 +41,29 @@ class TensorParallelLayerTest(DistributedTestBase):
             parallel_state.initialize_model_parallel(
                 tensor_model_parallel_size_=tensor_model_parallel_world_size,
             )
-            set_random_seed(TensorParallelLayerTest.SEED + 1)
+            set_random_seed(self.SEED + 1)
             input_tensor = torch.randint(
                 0,
-                TensorParallelLayerTest.VOCAB_SIZE,
+                self.VOCAB_SIZE,
                 (
-                    TensorParallelLayerTest.BATCH_SIZE,
-                    TensorParallelLayerTest.SEQUENCE_LENGTH,
+                    self.BATCH_SIZE,
+                    self.SEQUENCE_LENGTH,
                 ),
                 device="cuda",
             )
             loss_weight = torch.randn(
                 (
-                    TensorParallelLayerTest.BATCH_SIZE,
-                    TensorParallelLayerTest.SEQUENCE_LENGTH,
-                    TensorParallelLayerTest.HIDDEN_SIZE,
+                    self.BATCH_SIZE,
+                    self.SEQUENCE_LENGTH,
+                    self.HIDDEN_SIZE,
                 ),
                 device="cuda",
             )
-            set_random_seed(TensorParallelLayerTest.SEED)
+            set_random_seed(self.SEED)
             embedding_torch = nn.Embedding(
-                TensorParallelLayerTest.VOCAB_SIZE,
-                TensorParallelLayerTest.HIDDEN_SIZE,
+                self.VOCAB_SIZE,
+                self.HIDDEN_SIZE,
             ).cuda()
             output_torch = embedding_torch(input_tensor)
             loss_torch = torch.mul(output_torch, loss_weight).sum()
@@ -71,10 +72,10 @@ class TensorParallelLayerTest(DistributedTestBase):
             # N.B. (mkozuki): With affine weight initialization on GPU,
             # it's super difficult to keep the consistency with nn.Embedding.
             # Thus, turning on `use_cpu_initialization`.
-            set_random_seed(TensorParallelLayerTest.SEED)
+            set_random_seed(self.SEED)
             embedding_vocab_parallel = layers.VocabParallelEmbedding(
-                TensorParallelLayerTest.VOCAB_SIZE,
-                TensorParallelLayerTest.HIDDEN_SIZE,
+                self.VOCAB_SIZE,
+                self.HIDDEN_SIZE,
                 init_method=nn.init.normal_,
                 use_cpu_initialization=True,
             ).cuda()
@@ -89,7 +90,7 @@ class TensorParallelLayerTest(DistributedTestBase):
             splitted_weight_torch = torch.split(
                 embedding_torch.weight.grad,
-                TensorParallelLayerTest.VOCAB_SIZE
+                self.VOCAB_SIZE
                 // tensor_model_parallel_world_size,
                 0,
             )[parallel_state.get_tensor_model_parallel_rank()]
@@ -112,21 +113,21 @@ class TensorParallelLayerTest(DistributedTestBase):
             parallel_state.initialize_model_parallel(
                 tensor_model_parallel_size_=tensor_model_parallel_world_size
             )
-            input_size: int = TensorParallelLayerTest.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
-            output_size: int = TensorParallelLayerTest.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size
+            input_size: int = self.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
+            output_size: int = self.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size
             weight_shape = (
-                (TensorParallelLayerTest.OUTPUT_SIZE_COEFF, input_size)
+                (self.OUTPUT_SIZE_COEFF, input_size)
                 if is_column_parallel
-                else (output_size, TensorParallelLayerTest.INPUT_SIZE_COEFF)
+                else (output_size, self.INPUT_SIZE_COEFF)
             )
             weight = torch.empty(weight_shape)
-            set_random_seed(TensorParallelLayerTest.SEED)
+            set_random_seed(self.SEED)
             sharding_dim_size = (
-                TensorParallelLayerTest.OUTPUT_SIZE_COEFF
+                self.OUTPUT_SIZE_COEFF
                 if is_column_parallel
-                else TensorParallelLayerTest.INPUT_SIZE_COEFF
+                else self.INPUT_SIZE_COEFF
             )
             if init_device == "cpu":
@@ -144,7 +145,7 @@ class TensorParallelLayerTest(DistributedTestBase):
                     weight, torch.nn.init.normal_, dim
                 )
             # Target
-            set_random_seed(TensorParallelLayerTest.SEED)
+            set_random_seed(self.SEED)
             if init_device == "cpu":
                 main_weight = torch.empty(output_size, input_size)
                 nn.init.normal_(main_weight)
@@ -180,10 +181,10 @@ class TensorParallelLayerTest(DistributedTestBase):
                 tensor_model_parallel_size_=tensor_model_parallel_world_size
             )
-            input_size: int = TensorParallelLayerTest.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
-            output_size: int = TensorParallelLayerTest.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size
+            input_size: int = self.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
+            output_size: int = self.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size
-            set_random_seed(TensorParallelLayerTest.SEED)
+            set_random_seed(self.SEED)
             linear_layer = layers.RowParallelLinear(
                 input_size,
                 output_size,
@@ -192,12 +193,12 @@ class TensorParallelLayerTest(DistributedTestBase):
                 use_cpu_initialization=True,
             ).cuda()
             loss_weight = torch.randn(
-                (TensorParallelLayerTest.BATCH_SIZE, output_size)
+                (self.BATCH_SIZE, output_size)
             ).cuda()
             # Forward and backward
             input_tensor = torch.randn(
-                TensorParallelLayerTest.BATCH_SIZE, input_size, requires_grad=True
+                self.BATCH_SIZE, input_size, requires_grad=True
             ).cuda()
             input_tensor.retain_grad()
             output, _ = linear_layer(input_tensor)
@@ -211,13 +212,13 @@ class TensorParallelLayerTest(DistributedTestBase):
             a = linear_layer.master_weight.cuda()
             dlda = torch.matmul(dldy.t(), x)
             dldb = torch.matmul(
-                torch.ones(TensorParallelLayerTest.BATCH_SIZE, 1).cuda().t(), dldy
+                torch.ones(self.BATCH_SIZE, 1).cuda().t(), dldy
             ).view(-1)
             dldx = torch.matmul(dldy, a)
             with torch.no_grad():
                 curr_dlda = torch.split(
-                    dlda, TensorParallelLayerTest.INPUT_SIZE_COEFF, dim=1
+                    dlda, self.INPUT_SIZE_COEFF, dim=1
                 )[parallel_state.get_tensor_model_parallel_rank()].clone()
             self.assertEqual(linear_layer.weight.grad, curr_dlda)
             self.assertEqual(input_tensor.grad, dldx)
@@ -240,9 +241,6 @@ class TensorParallelLayerTest(DistributedTestBase):
         gradient_accumulation_fusion: bool,
     ):
         for tensor_model_parallel_world_size in range(1, self.world_size + 1):
-            print(
-                f"tensor_model_parallel_world_size={tensor_model_parallel_world_size}"
-            )
             with self.subTest(
                 tensor_model_parallel_world_size=tensor_model_parallel_world_size
             ):
@@ -252,13 +250,13 @@ class TensorParallelLayerTest(DistributedTestBase):
                     tensor_model_parallel_size_=tensor_model_parallel_world_size,
                 )
-                feature_size_coeff = TensorParallelLayerTest.INPUT_SIZE_COEFF
+                feature_size_coeff = self.INPUT_SIZE_COEFF
                 feature_size = feature_size_coeff * tensor_model_parallel_world_size
                 hidden_size = feature_size
-                set_random_seed(TensorParallelLayerTest.SEED)
+                set_random_seed(self.SEED)
                 input_tensor = torch.randn(
-                    TensorParallelLayerTest.BATCH_SIZE,
+                    self.BATCH_SIZE,
                     hidden_size,
                     feature_size,
                     device="cuda",
@@ -266,7 +264,7 @@ class TensorParallelLayerTest(DistributedTestBase):
                 )
                 input_tensor.retain_grad()
                 loss_weight = torch.randn(
-                    (TensorParallelLayerTest.BATCH_SIZE, hidden_size, feature_size,),
+                    (self.BATCH_SIZE, hidden_size, feature_size,),
                     device="cuda",
                 )
                 linear = layers.ColumnParallelLinear(
@@ -285,7 +283,7 @@ class TensorParallelLayerTest(DistributedTestBase):
                 output, _ = linear(input_tensor)
                 self.assertEqual(
                     output.shape,
-                    (TensorParallelLayerTest.BATCH_SIZE, hidden_size, feature_size,),
+                    (self.BATCH_SIZE, hidden_size, feature_size,),
                 )
                 loss = torch.mul(output, loss_weight).sum()
                 loss.backward()
@@ -296,7 +294,7 @@ class TensorParallelLayerTest(DistributedTestBase):
                 a = linear.master_weight.cuda().clone()
                 dldx = torch.matmul(dldy, a)
                 self.assertEqual(input_tensor.grad, dldx)
-                # TODO (mkozuki): Cover the other cases.
+                # TODO(mkozuki): Cover the other cases.
                 if (
                     tensor_model_parallel_world_size == 1
                     and not gradient_accumulation_fusion
@@ -310,5 +308,13 @@ class TensorParallelLayerTest(DistributedTestBase):
             parallel_state.destroy_model_parallel()
+
+
+class NcclTensorParallelLayerTest(TensorParallelLayerTestBase, NcclDistributedTestBase):
+    pass
+
+
+class UccTensorParallelLayerTest(TensorParallelLayerTestBase, UccDistributedTestBase):
+    pass
+
 if __name__ == "__main__":
     common_utils.run_tests()
@@ -7,12 +7,13 @@ logging.getLogger("torch").setLevel(logging.WARNING)
 from apex.transformer import parallel_state
 from apex.transformer.tensor_parallel import mappings
-from apex.transformer.testing.distributed_test_base import DistributedTestBase
+from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
+from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
 logging.getLogger("apex").setLevel(logging.WARNING)
-class MappingTest(DistributedTestBase):
+class MappingTestBase:
     def test_reduce(self):
         for tensor_model_paralell_world_size in range(1, self.world_size + 1):
             if self.world_size % tensor_model_paralell_world_size > 0:
@@ -80,5 +81,9 @@ class MappingTest(DistributedTestBase):
         parallel_state.destroy_model_parallel()
+class NcclMappingTest(MappingTestBase, NcclDistributedTestBase): pass
+class UccMappingTest(MappingTestBase, UccDistributedTestBase): pass
 if __name__ == "__main__":
     common_utils.run_tests()
@@ -13,12 +13,13 @@ from apex.transformer.pipeline_parallel.utils import (
     get_current_global_batch_size,
     update_num_microbatches,
 )
-from apex.transformer.testing.distributed_test_base import DistributedTestBase
+from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
+from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
 logging.getLogger("apex").setLevel(logging.WARNING)
-class MicrobatchCalculatorTest(DistributedTestBase):
+class MicrobatchCalculatorTestBase:
     GLOBAL_BATCH_SIZE: int = 1024
     MICRO_BATCH_SIZE: int = 1
@@ -26,8 +27,8 @@ class MicrobatchCalculatorTest(DistributedTestBase):
     def _test(self, rampup_batch_size: Optional[List[int]]) -> None:
         for data_parallel_size in range(1, self.world_size + 1):
-            expected_global_batch_size = MicrobatchCalculatorTest.GLOBAL_BATCH_SIZE
-            expected_micro_batch_size = MicrobatchCalculatorTest.MICRO_BATCH_SIZE
+            expected_global_batch_size = self.GLOBAL_BATCH_SIZE
+            expected_micro_batch_size = self.MICRO_BATCH_SIZE
             if rampup_batch_size:
                 expected_global_batch_size = rampup_batch_size[0]
                 num_consumed_samples = 0
@@ -48,8 +49,8 @@ class MicrobatchCalculatorTest(DistributedTestBase):
             _reconfigure_microbatch_calculator(
                 self.rank,
                 rampup_batch_size,
-                MicrobatchCalculatorTest.GLOBAL_BATCH_SIZE,
-                MicrobatchCalculatorTest.MICRO_BATCH_SIZE,
+                self.GLOBAL_BATCH_SIZE,
+                self.MICRO_BATCH_SIZE,
                 data_parallel_size,
             )
@@ -66,7 +67,7 @@ class MicrobatchCalculatorTest(DistributedTestBase):
             current_global_batch_size = get_current_global_batch_size()
             update_num_microbatches(current_global_batch_size)
             current_global_batch_size = get_current_global_batch_size()
-            self.assertEqual(get_current_global_batch_size(), MicrobatchCalculatorTest.GLOBAL_BATCH_SIZE)
+            self.assertEqual(get_current_global_batch_size(), self.GLOBAL_BATCH_SIZE)
             parallel_state.destroy_model_parallel()
     def test_constant_microbatch_calculator(self):
@@ -76,5 +77,9 @@ class MicrobatchCalculatorTest(DistributedTestBase):
         self._test(rampup_batch_size=[256, 128, 500])
+class NcclMicrobatchCalculatorTest(MicrobatchCalculatorTestBase, NcclDistributedTestBase): pass
+class UccMicrobatchCalculatorTest(MicrobatchCalculatorTestBase, UccDistributedTestBase): pass
 if __name__ == "__main__":
     common_utils.run_tests()
+import logging
+import unittest
+
+import torch
+from torch.testing._internal import common_utils
+
+logging.getLogger("torch").setLevel(logging.WARNING)
+
+from apex.transformer import parallel_state
+from apex.transformer.pipeline_parallel import p2p_communication
+from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
+from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
+
+logging.getLogger("apex").setLevel(logging.DEBUG)
+
+
+# [P2P Ops Involved in Pipeline Model Parallel forward/backward]
+# **forward_backward_pipelining_without_interleaving**
+# - send_forward / recv_forward
+# - send_backward / recv_backward
+# - send_forward_recv_backward
+# - send_backward_recv_forward
+# **forward_backward_pipelining_with_interleaving**
+# - send_backward_recv_backward
+# - recv_backward
+# - recv_forward
+# - send_forward_backward_recv_forward_backward
+# - send_forward_recv_forward
+class P2PCommTestBase:
+
+    numel = 4
+    shape = (2, 2)
+    dtype = torch.float32
+
+    @property
+    def world_size(self):
+        return min(2, torch.cuda.device_count())
+
+    def _init_model_parallel(self):
+        parallel_state.initialize_model_parallel(
+            tensor_model_parallel_size_=1,
+            pipeline_model_parallel_size_=self.world_size,
+            virtual_pipeline_model_parallel_size_=None,
+        )
+
+    def create_tensor(self, value: int = None):
+        return torch.tensor(
+            [value] * self.numel).view(self.shape).to(device="cuda", dtype=self.dtype)
+
+    # Brief: Simulate warm-up.
+    # Brief: test `recv_forward` & `send_forward`.
+    def test_no_interleaving_warmup(self):
+        self.assertEqual(self.world_size, 2)
+        self._init_model_parallel()
+        input_tensor = None
+        if parallel_state.is_pipeline_first_stage():
+            tensor = self.create_tensor(self.rank)
+            print(tensor)
+            p2p_communication.send_forward(output_tensor=tensor, tensor_shape=self.shape, dtype=self.dtype)
+        else:
+            input_tensor = p2p_communication.recv_forward(tensor_shape=self.shape, dtype=self.dtype)
+
+        if parallel_state.is_pipeline_first_stage():
+            self.assertIsNone(input_tensor)
+        else:
+            expected_input_tensor = self.create_tensor(self.rank - 1)
+            self.assertEqual(input_tensor, expected_input_tensor)
+
+    # Brief: test `send_forward`, `send_forward_recv_forward`, and `recv_forward`.
+    def test_send_forward_recv_forward(self):
+        self._init_model_parallel()
+        prev_tensor = None
+        tensor = self.create_tensor(self.rank)
+        if parallel_state.is_pipeline_first_stage():
+            p2p_communication.send_forward(output_tensor=tensor, tensor_shape=self.shape, dtype=self.dtype)
+        elif parallel_state.is_pipeline_last_stage():
+            prev_tensor = p2p_communication.recv_forward(tensor_shape=self.shape, dtype=self.dtype)
+        else:
+            prev_tensor = p2p_communication.send_forward_recv_forward(
+                output_tensor=tensor,
+                recv_prev=True,
+                tensor_shape=self.shape,
+                dtype=self.dtype,
+            )
+
+        if parallel_state.is_pipeline_first_stage():
+            self.assertIsNone(prev_tensor)
+        else:
+            expected_prev_tensor = self.create_tensor(self.rank - 1)
+            self.assertEqual(prev_tensor, expected_prev_tensor)
+
+    # Brief: test `send_backward`, `send_backward_recv_backward`, and `recv_backward`.
+    def test_send_backward_recv_backward(self):
+        self._init_model_parallel()
+        tensor = self.create_tensor(self.rank)
+        next_tensor = None
+        if parallel_state.is_pipeline_first_stage():
+            next_tensor = p2p_communication.recv_backward(tensor_shape=self.shape, dtype=self.dtype)
+        elif parallel_state.is_pipeline_last_stage():
+            p2p_communication.send_backward(input_tensor_grad=tensor, tensor_shape=self.shape, dtype=self.dtype)
+        else:
+            next_tensor = p2p_communication.send_backward_recv_backward(
+                input_tensor_grad=tensor,
+                recv_next=True,
+                tensor_shape=self.shape,
+                dtype=self.dtype,
+            )
+
+        if parallel_state.is_pipeline_last_stage():
+            self.assertIsNone(next_tensor)
+        else:
+            expected_next_tensor = self.create_tensor(self.rank + 1)
+            self.assertEqual(next_tensor, expected_next_tensor)
+
+
+# n.b.(mkozuki): Intentionally skip NCCL backend tests as I trust pytorch/pytorch repo.
+class UccP2PCommTest(P2PCommTestBase, UccDistributedTestBase): pass
+
+
+if __name__ == "__main__":
+    common_utils.run_tests()
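
For readers who want to exercise the same P2P calls outside the test harness, a hedged standalone sketch follows. It assumes a 2-process job launched with torchrun and an installed torch_ucc; only the calls and keyword arguments shown in the diff above are used:

# Hypothetical 2-rank pipeline warm-up using the p2p_communication API.
import torch
import torch_ucc  # NOQA: registers the "ucc" backend
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import p2p_communication

torch.distributed.init_process_group(backend="nccl")
rank = torch.distributed.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size_=1,
    pipeline_model_parallel_size_=torch.distributed.get_world_size(),
    p2p_backend="ucc",
)

shape, dtype = (2, 2), torch.float32
if parallel_state.is_pipeline_first_stage():
    payload = torch.full(shape, float(rank), device="cuda", dtype=dtype)
    p2p_communication.send_forward(output_tensor=payload, tensor_shape=shape, dtype=dtype)
else:
    received = p2p_communication.recv_forward(tensor_shape=shape, dtype=dtype)
    print(f"rank {rank} received {received}")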
@@ -6,7 +6,8 @@ from torch.testing._internal import common_utils
 logging.getLogger("torch").setLevel(logging.WARNING)
 from apex.transformer import parallel_state
-from apex.transformer.testing.distributed_test_base import DistributedTestBase
+from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
+from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
 logging.getLogger("apex").setLevel(logging.WARNING)
@@ -21,7 +22,7 @@ def calc_expected_tensor_model_paralell_rank(
     return rank % tensor_model_parallel_world_size
-class ParallelStateTest(DistributedTestBase):
+class ParallelStateTestBase:
     def test_initialize_model_parallel(self) -> None:
         self.assertFalse(parallel_state.model_parallel_is_initialized())
@@ -122,5 +123,9 @@ class ParallelStateTest(DistributedTestBase):
         parallel_state.destroy_model_parallel()
+class NcclParallelStateTest(ParallelStateTestBase, NcclDistributedTestBase): pass
+class UccParallelStateTest(ParallelStateTestBase, UccDistributedTestBase): pass
 if __name__ == "__main__":
     common_utils.run_tests()
 import logging
 import itertools
+import re
 from typing import Optional
+import unittest
 import torch
 from torch.testing._internal import common_utils
@@ -24,7 +26,10 @@ from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interl
 from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
     forward_backward_pipelining_without_interleaving,
 )
-from apex.transformer.testing.distributed_test_base import DistributedTestBase
+from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
+from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
+from apex.transformer.testing.distributed_test_base import HAS_TORCH_UCC
+from apex.transformer.testing.distributed_test_base import HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER
 from apex.transformer.testing import commons as testing_utils
 logging.getLogger("apex").setLevel(logging.WARNING)
@@ -54,15 +59,16 @@ def get_target_loss(hidden_size: int, microbatch_size: int, parallel_model_world
     return hidden_size * hidden_size * torch.sum(data).item() * microbatch_size / layers_per_rank
-class PipelineParallelForwardBackwardTest(DistributedTestBase):
+class PipelineParallelForwardBackwardTestBase:
     GLOBAL_BATCH_SIZE = 16
     MICRO_BATCH_SIZE = 2
     HIDDEN_SIZE = 32
-    @property
-    def world_size(self) -> int:
-        return min(torch.cuda.device_count(), 8)
+    deallocate_options = (True, False)
+    # If :obj:`None`, (torch.float32, torch.float16, torch.bfloat16) are dtype options on Ampere.
+    # You can limit the options by overriding the following `dtypes`.
+    dtypes = None
     def _forward_backward_test_impl(
         self,
@@ -71,9 +77,13 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
         pipeline_model_parallel_world_size: Optional[int],
         virtual_pipeline_model_parallel_size: Optional[int],
         async_comm: bool = False,
+        *,
+        default_backend: Optional[str] = None,
+        p2p_backend: Optional[str] = None,
     ) -> None:
+        dtype_options = self.dtypes or [torch.float32] + _get_autocast_dtypes()
        for dtype, deallocate_pipeline_outputs in itertools.product(
-            [torch.float32] + _get_autocast_dtypes(), (True, False),
+            dtype_options, self.deallocate_options,
        ):
             grad_scaler = (
                 torch.cuda.amp.GradScaler(init_scale=4.0)
@@ -92,29 +102,32 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
                 tensor_model_parallel_size_=tensor_model_parallel_world_size,
                 pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
                 virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size,
+                default_backend=default_backend,
+                p2p_backend=p2p_backend,
             )
             pp_utils._reconfigure_microbatch_calculator(
                 rank=parallel_state.get_tensor_model_parallel_rank(),
                 rampup_batch_size=None,
-                global_batch_size=PipelineParallelForwardBackwardTest.GLOBAL_BATCH_SIZE,
-                micro_batch_size=PipelineParallelForwardBackwardTest.MICRO_BATCH_SIZE,
+                global_batch_size=self.GLOBAL_BATCH_SIZE,
+                micro_batch_size=self.MICRO_BATCH_SIZE,
                 data_parallel_size=parallel_state.get_data_parallel_world_size(),
             )
             global_batch_shape = (
-                PipelineParallelForwardBackwardTest.GLOBAL_BATCH_SIZE
+                self.GLOBAL_BATCH_SIZE
                 // parallel_state.get_data_parallel_world_size(),
-                PipelineParallelForwardBackwardTest.HIDDEN_SIZE,
-                PipelineParallelForwardBackwardTest.HIDDEN_SIZE,
+                self.HIDDEN_SIZE,
+                self.HIDDEN_SIZE,
             )
             batch = (((self.rank + 1) * torch.ones(global_batch_shape)).cuda(), )
             model = build_model(
                 testing_utils.model_provider_func,
-                wrap_with_ddp=True,
+                # Use DDP only when it's better to have
+                wrap_with_ddp=data_parallel_size > 1,
                 virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
-                hidden_size=PipelineParallelForwardBackwardTest.HIDDEN_SIZE,
+                hidden_size=self.HIDDEN_SIZE,
             )
             for model_module in model:
@@ -132,9 +145,9 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
                 forward_only=forward_only,
                 # `tensor_shape` is the shape of micro batch.
                 tensor_shape=(
-                    PipelineParallelForwardBackwardTest.MICRO_BATCH_SIZE,
-                    PipelineParallelForwardBackwardTest.HIDDEN_SIZE,
-                    PipelineParallelForwardBackwardTest.HIDDEN_SIZE,
+                    self.MICRO_BATCH_SIZE,
+                    self.HIDDEN_SIZE,
+                    self.HIDDEN_SIZE,
                 ),
                 dtype=dtype,
                 async_comm=async_comm,
@@ -143,8 +156,8 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
             )
             if dtype == torch.float32:
-                hidden_size = PipelineParallelForwardBackwardTest.HIDDEN_SIZE
-                microbatch_size = PipelineParallelForwardBackwardTest.MICRO_BATCH_SIZE
+                hidden_size = self.HIDDEN_SIZE
+                microbatch_size = self.MICRO_BATCH_SIZE
                 target_loss = get_target_loss(hidden_size, microbatch_size, pipeline_model_parallel_world_size, self.world_size)
                 for loss_item in loss:
@@ -166,7 +179,7 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
     def test_no_pipelining_inference(self):
         self._forward_backward_test_impl(True, forward_backward_no_pipelining, 1, None)
-    def test_pipelining_default(self):
+    def test_pipelining_without_interleaving(self):
         self._forward_backward_test_impl(
             False, forward_backward_pipelining_without_interleaving, None, None
         )
@@ -176,7 +189,7 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
             False, forward_backward_pipelining_without_interleaving, None, None, async_comm=True
         )
-    def test_pipelining_inference(self):
+    def test_pipelining_without_interleaving_inference(self):
         self._forward_backward_test_impl(
             True, forward_backward_pipelining_without_interleaving, None, None
         )
@@ -197,5 +210,46 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
         )
+
+
+class NcclPipelineParallelForwardBackwardTest(NcclDistributedTestBase, PipelineParallelForwardBackwardTestBase):
+
+    @property
+    def world_size(self) -> int:
+        return min(torch.cuda.device_count(), 8)
+
+    def _run_hybrid_distributed_backend(self, forward_only: bool) -> None:
+        self._forward_backward_test_impl(
+            forward_only, forward_backward_pipelining_without_interleaving, None, None,
+            default_backend="nccl", p2p_backend="ucc",
+        )
+
+    @unittest.skipUnless(HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER, "Needs driver >= 470.42.01")
+    def _test_hybrid_backends(self, forward_only: bool) -> None:
+        if HAS_TORCH_UCC:
+            self._run_hybrid_distributed_backend(forward_only)
+        else:
+            with self.assertRaisesRegex(
+                ImportError,
+                re.escape("UCC backend requires [torch_ucc](https://github.com/facebookresearch/torch_ucc) but not found"),
+            ):
+                self._run_hybrid_distributed_backend(forward_only)
+
+    def test_pipelining_without_interleaving_ucc_for_p2p(self):
+        self._test_hybrid_backends(False)
+
+    def test_pipelining_without_interleaving_inference_ucc_for_p2p(self):
+        self._test_hybrid_backends(True)
+
+
+# n.b.(mkozuki): pipeline parallel w/o interleaving with UCX_TLS=tcp,sm fails.
+class UccPipelineParallelForwardBackwardTest(UccDistributedTestBase, PipelineParallelForwardBackwardTestBase):
+
+    @property
+    def world_size(self) -> int:
+        return min(torch.cuda.device_count(), 8)
+
+    deallocate_options = (False,)
+    dtypes = (torch.float32,)
+
 if __name__ == "__main__":
     common_utils.run_tests()
@@ -7,12 +7,13 @@ logging.getLogger("torch").setLevel(logging.WARNING)
 from apex.transformer import parallel_state
 from apex.transformer import tensor_parallel
-from apex.transformer.testing.distributed_test_base import DistributedTestBase
+from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
+from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
 logging.getLogger("apex").setLevel(logging.WARNING)
-class TransformerRandomTest(DistributedTestBase):
+class TransformerRandomTestBase:
     def test_set_cuda_rng_state(self):
         for tensor_model_parallel_world_size in range(1, self.world_size + 1):
             if self.world_size % tensor_model_parallel_world_size:
@@ -111,5 +112,9 @@ class TransformerRandomTest(DistributedTestBase):
             parallel_state.destroy_model_parallel()
+class NcclTransformerRandomTest(TransformerRandomTestBase, NcclDistributedTestBase): pass
+class UccTransformerRandomTest(TransformerRandomTestBase, UccDistributedTestBase): pass
 if __name__ == "__main__":
     common_utils.run_tests()
@@ -63,7 +63,9 @@ def run_transformer_tests():
             import torch
             num_devices = torch.cuda.device_count()
-            test_run_cmd += f" --pipeline-model-parallel-size {num_devices}"
+            tensor_model_parallel_size = 1 + (1 - (num_devices % 2 and num_devices > 4))
+            pipeline_model_parallel_size = num_devices // tensor_model_parallel_size
+            test_run_cmd += f" --pipeline-model-parallel-size {pipeline_model_parallel_size} --tensor-model-parallel-size {tensor_model_parallel_size}"
        else:
             test_run_cmd += f" --use-cpu-initialization"
         print(f"### {i} / {len(files)}: cmd: {test_run_cmd}")
...
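
Because the parallel-size expression above is terse, here is a small worked example of the values it produces (the helper name is ours; the arithmetic is copied from the diff):

# Reproduces the tensor/pipeline split chosen by run_transformer_tests() above.
def split_sizes(num_devices: int):
    tensor_model_parallel_size = 1 + (1 - (num_devices % 2 and num_devices > 4))
    pipeline_model_parallel_size = num_devices // tensor_model_parallel_size
    return tensor_model_parallel_size, pipeline_model_parallel_size

# e.g. 8 GPUs -> (2, 4), 6 GPUs -> (2, 3), 5 GPUs -> (1, 5), 4 GPUs -> (2, 2), 2 GPUs -> (2, 1)
for n in (2, 4, 5, 6, 8):
    print(n, split_sizes(n))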
@@ -7,12 +7,12 @@ logging.getLogger("torch").setLevel(logging.WARNING)
 from apex.transformer import parallel_state
 from apex.transformer.tensor_parallel import utils
-from apex.transformer.testing.distributed_test_base import DistributedTestBase
+from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
 logging.getLogger("apex").setLevel(logging.WARNING)
-class TransformerUtilsTest(DistributedTestBase):
+class TransformerUtilsTest(NcclDistributedTestBase):
     def test_split_tensor_along_last_dim(self):
         for tensor_model_paralell_world_size in range(1, self.world_size + 1):
             if self.world_size % tensor_model_paralell_world_size > 0:
...