Unverified Commit 3490b9e1 authored by Masaki Kozuki, committed by GitHub

[transformer] Allow for different backend for Pipeline Parallel ProcessGroups (#1380)



* NcclDistributedTestBase

* fix stupid mistake

* add UCC test

* add UCC backend

* torch ucc tests

* allows for UCC backend

* Set `UCX_TLS` to `tcp,cuda_copy` & Use DDP iff it makes sense

* Apply 4 suggestion(s) to 1 file(s)

* mix&match NCCL & UCC

* use both ucc&nccl in gpt

* UCC for Pipeline Parallel, NCCL for the others

* conditionally use ucc

* make ucc guards more friendly

* test raises when torch_ucc isn't available

* Change to member variable from class variable
Co-authored-by: Aidyn Aitzhan <31858918+Aidyn-A@users.noreply.github.com>

* pass async_comm to train, I mistakenly dropped it during the rebase

* fix typo: functionality

* Enable tensor parallel only when device count > 4

I want the pipeline model parallel world size to be >= 4 because I previously
saw GPT/BERT fail when only UCC is used, so I suspect there is some gotcha
around a pipeline size of 4.

* Add nvidia driver version guard
Co-authored-by: Aidyn Aitzhan <31858918+Aidyn-A@users.noreply.github.com>

* move world_size as it was not correctly reflected

* keep an eye on the NVML API thing

* import unittest
Co-authored-by: Aidyn Aitzhan <31858918+Aidyn-A@users.noreply.github.com>
parent d36397d2
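As a quick orientation before the diff, here is a minimal usage sketch of the new keyword-only arguments this commit adds to `initialize_model_parallel`. The parallel sizes and the `env://` rendezvous are illustrative assumptions, not values taken from the diff.

```python
import torch
from apex.transformer import parallel_state

# Assumes the launcher (e.g. torchrun) set MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE.
torch.distributed.init_process_group(backend="nccl", init_method="env://")
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())

parallel_state.initialize_model_parallel(
    tensor_model_parallel_size_=2,    # illustrative sizes
    pipeline_model_parallel_size_=4,
    default_backend="nccl",           # data/tensor/embedding groups stay on NCCL
    p2p_backend="ucc",                # pipeline-parallel groups use UCC (requires torch_ucc)
)
```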
......@@ -16,6 +16,7 @@
# TODO (mkozuki): Sort the functions in the same order of megatron/mpu/initialize.py
"""Model and data parallel groups."""
from typing import Tuple, Optional
import warnings
import torch
......@@ -75,6 +76,9 @@ def initialize_model_parallel(
pipeline_model_parallel_size_: int = 1,
virtual_pipeline_model_parallel_size_: Optional[int] = None,
pipeline_model_parallel_split_rank_: Optional[int] = None,
*,
default_backend: Optional[str] = None,
p2p_backend: Optional[str] = None,
) -> None:
"""
Initialize model data parallel groups.
......@@ -84,6 +88,15 @@ def initialize_model_parallel(
pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline.
virtual_pipeline_model_parallel_size: number of virtual stages (interleaved pipeline).
pipeline_model_parallel_split_rank: for models with both encoder and decoder, rank in pipeline with split point.
Keyword Arguments:
default_backend: Backend of process groups except for pipeline parallel ones.
If :obj:`None`, the backend specified in `torch.distributed.init_process_group` will be used.
p2p_backend: Backend of process groups for pipeline model parallel.
If :obj:`None`, the backend specified in `torch.distributed.init_process_group` will be used.
.. note::
`torch_ucc <https://github.com/facebookresearch/torch_ucc>`_ is
necessary for "ucc" backend.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
......@@ -103,6 +116,14 @@ def initialize_model_parallel(
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
assert default_backend is None or default_backend in ("nccl", "ucc")
assert p2p_backend is None or p2p_backend in ("nccl", "ucc")
if "ucc" in (default_backend, p2p_backend):
check_torch_ucc_availability()
warnings.warn("`ucc` backend support is experimental", ExperimentalWarning)
if default_backend == "ucc":
warnings.warn("The UCC's functionality as `default_backend` is not well verified", ExperimentalWarning)
world_size: int = torch.distributed.get_world_size()
tensor_model_parallel_size: int = min(tensor_model_parallel_size_, world_size)
pipeline_model_parallel_size: int = min(pipeline_model_parallel_size_, world_size)
......@@ -160,7 +181,7 @@ def initialize_model_parallel(
for j in range(tensor_model_parallel_size):
ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
all_data_parallel_group_ranks.append(list(ranks))
group = torch.distributed.new_group(ranks)
group = torch.distributed.new_group(ranks, backend=default_backend)
if rank in ranks:
_DATA_PARALLEL_GROUP = group
......@@ -172,7 +193,7 @@ def initialize_model_parallel(
data_parallel_group_ranks[i]
for data_parallel_group_ranks in all_data_parallel_group_ranks
]
group = torch.distributed.new_group(ranks)
group = torch.distributed.new_group(ranks, backend=default_backend)
if rank in ranks:
_MODEL_PARALLEL_GROUP = group
......@@ -185,7 +206,7 @@ def initialize_model_parallel(
ranks = list(
range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
)
group = torch.distributed.new_group(ranks)
group = torch.distributed.new_group(ranks, backend=default_backend)
if rank in ranks:
_TENSOR_MODEL_PARALLEL_GROUP = group
......@@ -206,7 +227,7 @@ def initialize_model_parallel(
), "position embedding group is already initialized"
for i in range(num_pipeline_model_parallel_groups):
ranks = range(i, world_size, num_pipeline_model_parallel_groups)
group = torch.distributed.new_group(ranks)
group = torch.distributed.new_group(ranks, backend=p2p_backend)
if rank in ranks:
_PIPELINE_MODEL_PARALLEL_GROUP = group
_PIPELINE_GLOBAL_RANKS = ranks
......@@ -234,13 +255,13 @@ def initialize_model_parallel(
embedding_ranks = ranks
position_embedding_ranks = ranks
group = torch.distributed.new_group(embedding_ranks)
group = torch.distributed.new_group(embedding_ranks, backend=default_backend)
if rank in embedding_ranks:
_EMBEDDING_GROUP = group
if rank in ranks:
_EMBEDDING_GLOBAL_RANKS = embedding_ranks
group = torch.distributed.new_group(position_embedding_ranks)
group = torch.distributed.new_group(position_embedding_ranks, backend=default_backend)
if rank in position_embedding_ranks:
_POSITION_EMBEDDING_GROUP = group
if rank in ranks:
......@@ -578,3 +599,16 @@ def destroy_model_parallel():
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
# Used to warn when the UCC is specified.
class ExperimentalWarning(Warning): pass
def check_torch_ucc_availability() -> None:
try:
import torch_ucc # NOQA
except ImportError:
raise ImportError(
"UCC backend requires [torch_ucc](https://github.com/facebookresearch/torch_ucc) but not found"
)
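For context on the `backend=` plumbing above, the following is a standalone sketch (not part of this diff) of how `torch.distributed.new_group` lets the default group use one backend while a specific group uses another, once `torch_ucc` has registered the "ucc" backend. The single all-rank group is purely illustrative.

```python
import torch
import torch.distributed as dist


def init_mixed_backend_groups(p2p_backend: str = "nccl"):
    # Default process group (and any group created with backend=None) uses NCCL.
    dist.init_process_group(backend="nccl", init_method="env://")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    if p2p_backend == "ucc":
        import torch_ucc  # NOQA: registers the "ucc" backend with torch.distributed

    # Every rank must call new_group() the same number of times and in the same
    # order, even for groups it does not belong to.
    p2p_group = dist.new_group(
        ranks=list(range(dist.get_world_size())), backend=p2p_backend
    )
    return p2p_group
```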
......@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import os
import random
from typing import Optional, Union, List
......@@ -117,6 +118,10 @@ def initialize_distributed(backend="nccl"):
# parser.add_argument('--local_rank', type=int, default=None,
# help='local rank passed from distributed launcher')
# args = parser.parse_args()
if backend not in ("nccl", "ucc"):
raise RuntimeError(f"Currently only nccl & ucc are supported but {backend}")
if backend == "ucc":
import torch_ucc # NOQA
args = global_vars.get_args()
local_rank = args.local_rank
......@@ -141,7 +146,8 @@ def initialize_distributed(backend="nccl"):
master_port = os.getenv("MASTER_PORT", "6000")
init_method += master_ip + ":" + master_port
torch.distributed.init_process_group(
backend=backend, world_size=world_size, rank=rank, init_method=init_method
backend=backend, world_size=world_size, rank=rank, init_method=init_method,
timeout=datetime.timedelta(seconds=60),
)
......
import os
import sys
import unittest
import torch
from torch import distributed as dist
from torch.utils import collect_env
from torch.testing._internal import common_utils
from torch.testing._internal import common_distributed
HAS_TORCH_UCC = None
try:
import torch_ucc
HAS_TORCH_UCC = True
except ImportError:
HAS_TORCH_UCC = False
# NOTE(mkozuki): Version guard for ucc. ref: https://github.com/openucx/ucc/issues/496
_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION = "470.42.01"
_driver_version = None
if torch.cuda.is_available():
_driver_version = collect_env.get_nvidia_driver_version(collect_env.run)
HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER = _driver_version is not None and _driver_version >= _TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION
class DistributedTestBase(common_distributed.MultiProcessTestCase):
BACKEND_NCCL = "nccl"
class DistributedTestBase(common_distributed.MultiProcessTestCase):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
def setUp(self) -> None:
super().setUp()
self._setup_pre_spawn()
self._spawn_processes()
def tearDown(self) -> None:
......@@ -32,6 +48,7 @@ class DistributedTestBase(common_distributed.MultiProcessTestCase):
def _run(cls, rank, test_name, file_name, pipe):
self = cls(test_name)
self.assertTrue(torch.cuda.is_available())
self.assertTrue(hasattr(self, "DISTRIBUTED_BACKEND"))
self.rank = rank
self.file_name = file_name
......@@ -40,13 +57,13 @@ class DistributedTestBase(common_distributed.MultiProcessTestCase):
try:
dist.init_process_group(
init_method=self.init_method,
backend=DistributedTestBase.BACKEND_NCCL,
backend=self.DISTRIBUTED_BACKEND,
world_size=int(self.world_size),
rank=self.rank,
)
except RuntimeError as e:
if "recompile" in e.args[0]:
print(f"Backend of {DistributedTestBase.BACKEND_NCCL} not available")
print(f"Backend of {self.DISTRIBUTED_BACKEND} not available")
sys.exit(0)
raise
......@@ -58,3 +75,55 @@ class DistributedTestBase(common_distributed.MultiProcessTestCase):
dist.destroy_process_group()
sys.exit(0)
def _setup_pre_spawn(self):
pass
class NcclDistributedTestBase(DistributedTestBase):
DISTRIBUTED_BACKEND = "nccl"
@unittest.skipUnless(
HAS_TORCH_UCC,
"Requires [`torch_ucc`](https://github.com/facebookresearch/torch_ucc)",
)
@unittest.skipUnless(
HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER,
f"`torch_ucc` requires NVIDIA driver >= {_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION} but {_driver_version} found. "
"See https://github.com/openucx/ucc/issues/496",
)
class UccDistributedTestBase(DistributedTestBase):
DISTRIBUTED_BACKEND = "ucc"
def _setup_pre_spawn(self) -> None:
self.master_addr = "localhost"
os.environ["MASTER_ADDR"] = "localhost"
self._has_master_port = "MASTER_PORT" in os.environ
if self._has_master_port:
self.master_port = os.environ["MASTER_PORT"]
else:
try:
from caffe2.torch.fb.common.utils import get_free_port
self.master_port = str(get_free_port())
except ImportError:
self.master_port = "12375"
os.environ["MASTER_PORT"] = self.master_port
self._has_ucx_tls = "UCX_TLS" in os.environ
if not self._has_ucx_tls:
os.environ["UCX_TLS"] = "tcp,cuda_copy"
print('os.environ[\"UCX_TLS\"] = {}'.format(os.environ["UCX_TLS"]))
def tearDown(self) -> None:
super().tearDown()
if not self._has_master_port:
del os.environ["MASTER_PORT"]
if not self._has_ucx_tls:
del os.environ["UCX_TLS"]
@property
def init_method(self):
return "tcp://localhost:" + os.environ["MASTER_PORT"]
import random
import torch
try:
import torch_ucc
except ImportError:
HAS_TORCH_UCC = False
else:
HAS_TORCH_UCC = True
print("Use UCC as backend of Pipeline Parallel ProcessGroups")
from apex.transformer import tensor_parallel
from apex.transformer import parallel_state
from apex.transformer.log_util import set_logging_level
from apex.transformer.tensor_parallel import vocab_parallel_cross_entropy
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import (
......@@ -27,6 +35,7 @@ class DebugWarning(Warning):
pass
set_logging_level("WARNING")
mode = None
MANUAL_SEED = 42
inds = None
......@@ -154,7 +163,7 @@ if __name__ == "__main__":
effective_length = fancy_data.size(0) // global_vars.get_args().seq_length
effective_length = fancy_data.size(0) - global_vars.get_args().seq_length
initialize_distributed()
initialize_distributed("nccl")
world_size = torch.distributed.get_world_size()
failure = None
init = True
......@@ -182,6 +191,8 @@ if __name__ == "__main__":
args.tensor_model_parallel_size,
args.pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size,
default_backend="nccl",
p2p_backend="ucc" if HAS_TORCH_UCC else "nccl",
)
pipeline_model_parallel_size = (
parallel_state.get_pipeline_model_parallel_world_size()
......
from functools import partial
from typing import List
import time
import torch
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import model_parallel_cuda_manual_seed
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import (
average_losses_across_data_parallel_group,
)
from apex.transformer.pipeline_parallel.utils import get_ltor_masks_and_position_ids
from apex.transformer.pipeline_parallel.schedules.common import build_model
from apex.transformer.pipeline_parallel.schedules.common import (
_get_params_for_weight_decay_optimization,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
forward_backward_pipelining_without_interleaving,
)
from apex.transformer.testing.standalone_gpt import gpt_model_provider
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing.commons import initialize_distributed
MANUAL_SEED = 42
inds = None
data_idx = 0
N_VOCAB = 128
def download_fancy_data():
# import requests
# response = requests.get('https://internet.com/book.txt')
# text = ' '.join(response.text.split())
text = """
An original sentence not subject to any license restrictions, copyright, or royalty payments. Nothing to see here. Commercial or non-commercial use. Research or non-research purposes. The quick brown fox jumps over the lazy dog. Lorem ipsum.
"""
text = text * 1024
encoded = text.encode("ascii", "replace")
ints = [int(encoded[i]) for i in range(len(encoded))]
return torch.tensor(ints)
# build a batch given sequence_len and batch size
def generate_fancy_data_labels(sequence_len, batch_size):
global data_idx
global inds
global MANUAL_SEED
temps = list()
for i in range(batch_size):
if inds is None or data_idx >= len(inds):
# hack as use of RNG will fall out of sync due to pipelines being different
model_parallel_cuda_manual_seed(MANUAL_SEED)
inds = torch.randperm(effective_length, device="cuda")
MANUAL_SEED += 1
data_idx = 0
data_idx_ = data_idx
offset = inds[data_idx_]
data_idx += 1
curr = fancy_data[offset : offset + sequence_len + 1].clone().detach()
temps.append(curr)
temp = torch.stack(temps, dim=0).cuda()
return temp
easy_data = None
def get_batch(int_tensors: List[torch.Tensor]):
data = int_tensors[0]
# Unpack.
tokens_ = data.long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
# Get the masks and position ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
N_VOCAB, # tokenizer.eod,
False, # args.reset_position_ids,
False, # args.reset_attention_mask,
False, # args.eod_mask_loss,
)
return tokens, labels, loss_mask, attention_mask, position_ids
# Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L75
def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {"lm loss": averaged_loss[0]}
# Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L86
def fwd_step_func(batch, model):
"""Forward step."""
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(batch)
output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
return output_tensor, partial(loss_func, loss_mask)
def train(model, optim, pipeline_model_parallel_size, async_comm):
sequence_len = global_vars.get_args().seq_length
micro_batch_size = global_vars.get_args().micro_batch_size
hidden_size = global_vars.get_args().hidden_size
fwd_bwd_func = forward_backward_pipelining_without_interleaving
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
runtime = 0
# training loop
for i in range(3):
since = time.time()
if torch.distributed.get_rank() == 0:
print("begin iter", i)
batch = [
generate_fancy_data_labels(args.seq_length, args.global_batch_size)
for _ in range(pipeline_model_parallel_size)
]
if torch.distributed.get_rank() == 0:
print("finished making batch...")
optim.zero_grad()
fwd_bwd_func(
fwd_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape, async_comm=async_comm
)
if torch.distributed.get_rank() == 0:
print("finished forward step")
optim.step()
if torch.distributed.get_rank() == 0:
print("finished iter", i)
runtime += time.time() - since
return runtime / 3.0
if __name__ == "__main__":
init = True
for async_comm in (False, True):
global fancy_data
global effective_length
if init:
init = False
global_vars.set_global_variables()
args = global_vars.get_args()
fancy_data = download_fancy_data()
effective_length = fancy_data.size(0) // args.seq_length
effective_length = fancy_data.size(0) - args.seq_length
initialize_distributed()
world_size = torch.distributed.get_world_size()
failure = None
args.padded_vocab_size = 128
batch_size = args.global_batch_size
micro_batch_size = args.micro_batch_size
setup_microbatch_calculator(
args.rank,
args.rampup_batch_size,
args.global_batch_size,
args.micro_batch_size,
args.data_parallel_size, # args.data_parallel_size,
)
world_size = torch.distributed.get_world_size()
print(args.tensor_model_parallel_size, "MODEL PARALLEL SIZE")
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=args.tensor_model_parallel_size,
pipeline_model_parallel_size_=args.pipeline_model_parallel_size,
)
pipeline_model_parallel_size = (
parallel_state.get_pipeline_model_parallel_world_size()
)
model_parallel_cuda_manual_seed(0)
model = build_model(
gpt_model_provider,
wrap_with_ddp=True,
virtual_pipeline_model_parallel_size=None,
cpu_offload=args.cpu_offload,
)
assert isinstance(model, list), model
_param_groups = _get_params_for_weight_decay_optimization(model)
optim = torch.optim.Adam(_param_groups)
runtime = train(model, optim, args.pipeline_model_parallel_size, async_comm)
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
print("Average Iteration Time:", runtime)
from functools import partial
from typing import List
import time
import torch
try:
import torch_ucc
except ImportError:
HAS_TORCH_UCC = False
else:
HAS_TORCH_UCC = True
print("Use UCC as backend of Pipeline Parallel ProcessGroups")
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import model_parallel_cuda_manual_seed
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import (
average_losses_across_data_parallel_group,
)
from apex.transformer.pipeline_parallel.utils import get_ltor_masks_and_position_ids
from apex.transformer.pipeline_parallel.schedules.common import build_model
from apex.transformer.pipeline_parallel.schedules.common import (
_get_params_for_weight_decay_optimization,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
forward_backward_pipelining_without_interleaving,
)
from apex.transformer.testing.standalone_gpt import gpt_model_provider
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing.commons import initialize_distributed
MANUAL_SEED = 42
inds = None
data_idx = 0
N_VOCAB = 128
def download_fancy_data():
# import requests
# response = requests.get('https://internet.com/book.txt')
# text = ' '.join(response.text.split())
text = """
An original sentence not subject to any license restrictions, copyright, or royalty payments. Nothing to see here. Commercial or non-commercial use. Research or non-research purposes. The quick brown fox jumps over the lazy dog. Lorem ipsum.
"""
text = text * 1024
encoded = text.encode("ascii", "replace")
ints = [int(encoded[i]) for i in range(len(encoded))]
return torch.tensor(ints)
# build a batch given sequence_len and batch size
def generate_fancy_data_labels(sequence_len, batch_size):
global data_idx
global inds
global MANUAL_SEED
temps = list()
for i in range(batch_size):
if inds is None or data_idx >= len(inds):
# hack as use of RNG will fall out of sync due to pipelines being different
model_parallel_cuda_manual_seed(MANUAL_SEED)
inds = torch.randperm(effective_length, device="cuda")
MANUAL_SEED += 1
data_idx = 0
data_idx_ = data_idx
offset = inds[data_idx_]
data_idx += 1
curr = fancy_data[offset : offset + sequence_len + 1].clone().detach()
temps.append(curr)
temp = torch.stack(temps, dim=0).cuda()
return temp
easy_data = None
def get_batch(int_tensors: List[torch.Tensor]):
data = int_tensors[0]
# Unpack.
tokens_ = data.long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
# Get the masks and position ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
N_VOCAB, # tokenizer.eod,
False, # args.reset_position_ids,
False, # args.reset_attention_mask,
False, # args.eod_mask_loss,
)
return tokens, labels, loss_mask, attention_mask, position_ids
# Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L75
def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {"lm loss": averaged_loss[0]}
# Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L86
def fwd_step_func(batch, model):
"""Forward step."""
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(batch)
output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
return output_tensor, partial(loss_func, loss_mask)
def train(model, optim, pipeline_model_parallel_size, async_comm):
sequence_len = global_vars.get_args().seq_length
micro_batch_size = global_vars.get_args().micro_batch_size
hidden_size = global_vars.get_args().hidden_size
fwd_bwd_func = forward_backward_pipelining_without_interleaving
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
runtime = 0
# training loop
for i in range(3):
since = time.time()
if torch.distributed.get_rank() == 0:
print("begin iter", i)
batch = [
generate_fancy_data_labels(args.seq_length, args.global_batch_size)
for _ in range(pipeline_model_parallel_size)
]
if torch.distributed.get_rank() == 0:
print("finished making batch...")
optim.zero_grad()
fwd_bwd_func(
fwd_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape, async_comm=async_comm
)
if torch.distributed.get_rank() == 0:
print("finished forward step")
optim.step()
if torch.distributed.get_rank() == 0:
print("finished iter", i)
runtime += time.time() - since
return runtime / 3.0
if __name__ == "__main__":
init = True
for async_comm in (False, True):
global fancy_data
global effective_length
if init:
init = False
global_vars.set_global_variables()
fancy_data = download_fancy_data()
args = global_vars.get_args()
effective_length = fancy_data.size(0) // args.seq_length
effective_length = fancy_data.size(0) - args.seq_length
initialize_distributed("nccl")
world_size = torch.distributed.get_world_size()
failure = None
args.padded_vocab_size = 128
batch_size = args.global_batch_size
micro_batch_size = args.micro_batch_size
setup_microbatch_calculator(
args.rank,
args.rampup_batch_size,
args.global_batch_size,
args.micro_batch_size,
args.data_parallel_size, # args.data_parallel_size,
)
world_size = torch.distributed.get_world_size()
print(args.tensor_model_parallel_size, "MODEL PARALLEL SIZE")
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=args.tensor_model_parallel_size,
pipeline_model_parallel_size_=args.pipeline_model_parallel_size,
default_backend="nccl",
p2p_backend="ucc" if HAS_TORCH_UCC else "nccl",
)
pipeline_model_parallel_size = (
parallel_state.get_pipeline_model_parallel_world_size()
)
model_parallel_cuda_manual_seed(0)
model = build_model(
gpt_model_provider,
wrap_with_ddp=True,
virtual_pipeline_model_parallel_size=None,
cpu_offload=args.cpu_offload,
)
assert isinstance(model, list), model
_param_groups = _get_params_for_weight_decay_optimization(model)
optim = torch.optim.Adam(_param_groups)
runtime = train(model, optim, args.pipeline_model_parallel_size, async_comm)
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
print("Average Iteration Time:", runtime)
......@@ -11,7 +11,8 @@ from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from apex.transformer.tensor_parallel import cross_entropy
from apex.transformer.testing.commons import set_random_seed, IdentityLayer
from apex.transformer.testing.distributed_test_base import DistributedTestBase
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
logging.getLogger("apex").setLevel(logging.WARNING)
......@@ -54,7 +55,7 @@ def tensor_sharded_cross_entropy(
return loss, identity.weight.grad
class VocabParallelCrossEntropy(DistributedTestBase):
class VocabParallelCrossEntropyTestBase:
def test_cross_entropy(self):
batch_size, sequence_length, vocab_size_per_partition = 13, 17, 11
logits_scale = 1000.0
......@@ -85,5 +86,9 @@ class VocabParallelCrossEntropy(DistributedTestBase):
parallel_state.destroy_model_parallel()
class NcclVocabParallelCrossEntropyTest(VocabParallelCrossEntropyTestBase, NcclDistributedTestBase): pass
class UccVocabParallelCrossEntropyTest(VocabParallelCrossEntropyTestBase, UccDistributedTestBase): pass
if __name__ == "__main__":
common_utils.run_tests()
......@@ -7,12 +7,13 @@ logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import data as data_utils
from apex.transformer.testing.distributed_test_base import DistributedTestBase
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
logging.getLogger("torch").setLevel(logging.WARNING)
class BroadcastDataTest(DistributedTestBase):
class BroadcastDataTestBase:
def test_broadcast_data(self):
tensor_model_parallel_world_size: int = self.world_size // (
1 + self.world_size > 1
......@@ -55,5 +56,9 @@ class BroadcastDataTest(DistributedTestBase):
parallel_state.destroy_model_parallel()
class NcclBroadcastDataTest(BroadcastDataTestBase, NcclDistributedTestBase): pass
class UccBroadcastDataTest(BroadcastDataTestBase, UccDistributedTestBase): pass
if __name__ == "__main__":
common_utils.run_tests()
......@@ -9,7 +9,8 @@ logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import layers
from apex.transformer.testing.commons import set_random_seed
from apex.transformer.testing.distributed_test_base import DistributedTestBase
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
logging.getLogger("apex").setLevel(logging.WARNING)
......@@ -20,7 +21,7 @@ logging.getLogger("apex").setLevel(logging.WARNING)
torch.backends.cuda.matmul.allow_tf32 = False
class TensorParallelLayerTest(DistributedTestBase):
class TensorParallelLayerTestBase:
BATCH_SIZE: int = 17
SEQUENCE_LENGTH: int = 23
......@@ -40,29 +41,29 @@ class TensorParallelLayerTest(DistributedTestBase):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
)
set_random_seed(TensorParallelLayerTest.SEED + 1)
set_random_seed(self.SEED + 1)
input_tensor = torch.randint(
0,
TensorParallelLayerTest.VOCAB_SIZE,
self.VOCAB_SIZE,
(
TensorParallelLayerTest.BATCH_SIZE,
TensorParallelLayerTest.SEQUENCE_LENGTH,
self.BATCH_SIZE,
self.SEQUENCE_LENGTH,
),
device="cuda",
)
loss_weight = torch.randn(
(
TensorParallelLayerTest.BATCH_SIZE,
TensorParallelLayerTest.SEQUENCE_LENGTH,
TensorParallelLayerTest.HIDDEN_SIZE,
self.BATCH_SIZE,
self.SEQUENCE_LENGTH,
self.HIDDEN_SIZE,
),
device="cuda",
)
set_random_seed(TensorParallelLayerTest.SEED)
set_random_seed(self.SEED)
embedding_torch = nn.Embedding(
TensorParallelLayerTest.VOCAB_SIZE,
TensorParallelLayerTest.HIDDEN_SIZE,
self.VOCAB_SIZE,
self.HIDDEN_SIZE,
).cuda()
output_torch = embedding_torch(input_tensor)
loss_torch = torch.mul(output_torch, loss_weight).sum()
......@@ -71,10 +72,10 @@ class TensorParallelLayerTest(DistributedTestBase):
# N.B. (mkozuki): With affine weight initialization on GPU,
# it's super difficult to keep the consistency with nn.Embedding.
# Thus, turning on `use_cpu_initialization`.
set_random_seed(TensorParallelLayerTest.SEED)
set_random_seed(self.SEED)
embedding_vocab_parallel = layers.VocabParallelEmbedding(
TensorParallelLayerTest.VOCAB_SIZE,
TensorParallelLayerTest.HIDDEN_SIZE,
self.VOCAB_SIZE,
self.HIDDEN_SIZE,
init_method=nn.init.normal_,
use_cpu_initialization=True,
).cuda()
......@@ -89,7 +90,7 @@ class TensorParallelLayerTest(DistributedTestBase):
splitted_weight_torch = torch.split(
embedding_torch.weight.grad,
TensorParallelLayerTest.VOCAB_SIZE
self.VOCAB_SIZE
// tensor_model_parallel_world_size,
0,
)[parallel_state.get_tensor_model_parallel_rank()]
......@@ -112,21 +113,21 @@ class TensorParallelLayerTest(DistributedTestBase):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size
)
input_size: int = TensorParallelLayerTest.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
output_size: int = TensorParallelLayerTest.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size
input_size: int = self.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
output_size: int = self.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size
weight_shape = (
(TensorParallelLayerTest.OUTPUT_SIZE_COEFF, input_size)
(self.OUTPUT_SIZE_COEFF, input_size)
if is_column_parallel
else (output_size, TensorParallelLayerTest.INPUT_SIZE_COEFF)
else (output_size, self.INPUT_SIZE_COEFF)
)
weight = torch.empty(weight_shape)
set_random_seed(TensorParallelLayerTest.SEED)
set_random_seed(self.SEED)
sharding_dim_size = (
TensorParallelLayerTest.OUTPUT_SIZE_COEFF
self.OUTPUT_SIZE_COEFF
if is_column_parallel
else TensorParallelLayerTest.INPUT_SIZE_COEFF
else self.INPUT_SIZE_COEFF
)
if init_device == "cpu":
......@@ -144,7 +145,7 @@ class TensorParallelLayerTest(DistributedTestBase):
weight, torch.nn.init.normal_, dim
)
# Target
set_random_seed(TensorParallelLayerTest.SEED)
set_random_seed(self.SEED)
if init_device == "cpu":
main_weight = torch.empty(output_size, input_size)
nn.init.normal_(main_weight)
......@@ -180,10 +181,10 @@ class TensorParallelLayerTest(DistributedTestBase):
tensor_model_parallel_size_=tensor_model_parallel_world_size
)
input_size: int = TensorParallelLayerTest.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
output_size: int = TensorParallelLayerTest.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size
input_size: int = self.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
output_size: int = self.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size
set_random_seed(TensorParallelLayerTest.SEED)
set_random_seed(self.SEED)
linear_layer = layers.RowParallelLinear(
input_size,
output_size,
......@@ -192,12 +193,12 @@ class TensorParallelLayerTest(DistributedTestBase):
use_cpu_initialization=True,
).cuda()
loss_weight = torch.randn(
(TensorParallelLayerTest.BATCH_SIZE, output_size)
(self.BATCH_SIZE, output_size)
).cuda()
# Forward and backward
input_tensor = torch.randn(
TensorParallelLayerTest.BATCH_SIZE, input_size, requires_grad=True
self.BATCH_SIZE, input_size, requires_grad=True
).cuda()
input_tensor.retain_grad()
output, _ = linear_layer(input_tensor)
......@@ -211,13 +212,13 @@ class TensorParallelLayerTest(DistributedTestBase):
a = linear_layer.master_weight.cuda()
dlda = torch.matmul(dldy.t(), x)
dldb = torch.matmul(
torch.ones(TensorParallelLayerTest.BATCH_SIZE, 1).cuda().t(), dldy
torch.ones(self.BATCH_SIZE, 1).cuda().t(), dldy
).view(-1)
dldx = torch.matmul(dldy, a)
with torch.no_grad():
curr_dlda = torch.split(
dlda, TensorParallelLayerTest.INPUT_SIZE_COEFF, dim=1
dlda, self.INPUT_SIZE_COEFF, dim=1
)[parallel_state.get_tensor_model_parallel_rank()].clone()
self.assertEqual(linear_layer.weight.grad, curr_dlda)
self.assertEqual(input_tensor.grad, dldx)
......@@ -240,9 +241,6 @@ class TensorParallelLayerTest(DistributedTestBase):
gradient_accumulation_fusion: bool,
):
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
print(
f"tensor_model_parallel_world_size={tensor_model_parallel_world_size}"
)
with self.subTest(
tensor_model_parallel_world_size=tensor_model_parallel_world_size
):
......@@ -252,13 +250,13 @@ class TensorParallelLayerTest(DistributedTestBase):
tensor_model_parallel_size_=tensor_model_parallel_world_size,
)
feature_size_coeff = TensorParallelLayerTest.INPUT_SIZE_COEFF
feature_size_coeff = self.INPUT_SIZE_COEFF
feature_size = feature_size_coeff * tensor_model_parallel_world_size
hidden_size = feature_size
set_random_seed(TensorParallelLayerTest.SEED)
set_random_seed(self.SEED)
input_tensor = torch.randn(
TensorParallelLayerTest.BATCH_SIZE,
self.BATCH_SIZE,
hidden_size,
feature_size,
device="cuda",
......@@ -266,7 +264,7 @@ class TensorParallelLayerTest(DistributedTestBase):
)
input_tensor.retain_grad()
loss_weight = torch.randn(
(TensorParallelLayerTest.BATCH_SIZE, hidden_size, feature_size,),
(self.BATCH_SIZE, hidden_size, feature_size,),
device="cuda",
)
linear = layers.ColumnParallelLinear(
......@@ -285,7 +283,7 @@ class TensorParallelLayerTest(DistributedTestBase):
output, _ = linear(input_tensor)
self.assertEqual(
output.shape,
(TensorParallelLayerTest.BATCH_SIZE, hidden_size, feature_size,),
(self.BATCH_SIZE, hidden_size, feature_size,),
)
loss = torch.mul(output, loss_weight).sum()
loss.backward()
......@@ -296,7 +294,7 @@ class TensorParallelLayerTest(DistributedTestBase):
a = linear.master_weight.cuda().clone()
dldx = torch.matmul(dldy, a)
self.assertEqual(input_tensor.grad, dldx)
# TODO (mkozuki): Cover the other cases.
# TODO(mkozuki): Cover the other cases.
if (
tensor_model_parallel_world_size == 1
and not gradient_accumulation_fusion
......@@ -310,5 +308,13 @@ class TensorParallelLayerTest(DistributedTestBase):
parallel_state.destroy_model_parallel()
class NcclTensorParallelLayerTest(TensorParallelLayerTestBase, NcclDistributedTestBase):
pass
class UccTensorParallelLayerTest(TensorParallelLayerTestBase, UccDistributedTestBase):
pass
if __name__ == "__main__":
common_utils.run_tests()
......@@ -7,12 +7,13 @@ logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import mappings
from apex.transformer.testing.distributed_test_base import DistributedTestBase
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
logging.getLogger("apex").setLevel(logging.WARNING)
class MappingTest(DistributedTestBase):
class MappingTestBase:
def test_reduce(self):
for tensor_model_paralell_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_paralell_world_size > 0:
......@@ -80,5 +81,9 @@ class MappingTest(DistributedTestBase):
parallel_state.destroy_model_parallel()
class NcclMappingTest(MappingTestBase, NcclDistributedTestBase): pass
class UccMappingTest(MappingTestBase, UccDistributedTestBase): pass
if __name__ == "__main__":
common_utils.run_tests()
......@@ -13,12 +13,13 @@ from apex.transformer.pipeline_parallel.utils import (
get_current_global_batch_size,
update_num_microbatches,
)
from apex.transformer.testing.distributed_test_base import DistributedTestBase
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
logging.getLogger("apex").setLevel(logging.WARNING)
class MicrobatchCalculatorTest(DistributedTestBase):
class MicrobatchCalculatorTestBase:
GLOBAL_BATCH_SIZE: int = 1024
MICRO_BATCH_SIZE: int = 1
......@@ -26,8 +27,8 @@ class MicrobatchCalculatorTest(DistributedTestBase):
def _test(self, rampup_batch_size: Optional[List[int]]) -> None:
for data_parallel_size in range(1, self.world_size + 1):
expected_global_batch_size = MicrobatchCalculatorTest.GLOBAL_BATCH_SIZE
expected_micro_batch_size = MicrobatchCalculatorTest.MICRO_BATCH_SIZE
expected_global_batch_size = self.GLOBAL_BATCH_SIZE
expected_micro_batch_size = self.MICRO_BATCH_SIZE
if rampup_batch_size:
expected_global_batch_size = rampup_batch_size[0]
num_consumed_samples = 0
......@@ -48,8 +49,8 @@ class MicrobatchCalculatorTest(DistributedTestBase):
_reconfigure_microbatch_calculator(
self.rank,
rampup_batch_size,
MicrobatchCalculatorTest.GLOBAL_BATCH_SIZE,
MicrobatchCalculatorTest.MICRO_BATCH_SIZE,
self.GLOBAL_BATCH_SIZE,
self.MICRO_BATCH_SIZE,
data_parallel_size,
)
......@@ -66,7 +67,7 @@ class MicrobatchCalculatorTest(DistributedTestBase):
current_global_batch_size = get_current_global_batch_size()
update_num_microbatches(current_global_batch_size)
current_global_batch_size = get_current_global_batch_size()
self.assertEqual(get_current_global_batch_size(), MicrobatchCalculatorTest.GLOBAL_BATCH_SIZE)
self.assertEqual(get_current_global_batch_size(), self.GLOBAL_BATCH_SIZE)
parallel_state.destroy_model_parallel()
def test_constant_microbatch_calculator(self):
......@@ -76,5 +77,9 @@ class MicrobatchCalculatorTest(DistributedTestBase):
self._test(rampup_batch_size=[256, 128, 500])
class NcclMicrobatchCalculatorTest(MicrobatchCalculatorTestBase, NcclDistributedTestBase): pass
class UccMicrobatchCalculatorTest(MicrobatchCalculatorTestBase, UccDistributedTestBase): pass
if __name__ == "__main__":
common_utils.run_tests()
import logging
import unittest
import torch
from torch.testing._internal import common_utils
logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import p2p_communication
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
logging.getLogger("apex").setLevel(logging.DEBUG)
# [P2P Ops Involved in Pipeline Model Parallel forward/backward]
# **forward_backward_pipelining_without_interleaving**
# - send_forward / recv_forward
# - send_backward / recv_backward
# - send_forward_recv_backward
# - send_backward_recv_forward
# **forward_backward_pipelining_with_interleaving**
# - send_backward_recv_backward
# - recv_backward
# - recv_forward
# - send_forward_backward_recv_forward_backward
# - send_forward_recv_forward
class P2PCommTestBase:
numel = 4
shape = (2, 2)
dtype = torch.float32
@property
def world_size(self):
return min(2, torch.cuda.device_count())
def _init_model_parallel(self):
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=1,
pipeline_model_parallel_size_=self.world_size,
virtual_pipeline_model_parallel_size_=None,
)
def create_tensor(self, value: int = None):
return torch.tensor(
[value] * self.numel).view(self.shape).to(device="cuda", dtype=self.dtype)
# Brief: Simulate warm-up.
# Brief: test `recv_forward` & `send_forward`.
def test_no_interleaving_warmup(self):
self.assertEqual(self.world_size, 2)
self._init_model_parallel()
input_tensor = None
if parallel_state.is_pipeline_first_stage():
tensor = self.create_tensor(self.rank)
print(tensor)
p2p_communication.send_forward(output_tensor=tensor, tensor_shape=self.shape, dtype=self.dtype)
else:
input_tensor = p2p_communication.recv_forward(tensor_shape=self.shape, dtype=self.dtype)
if parallel_state.is_pipeline_first_stage():
self.assertIsNone(input_tensor)
else:
expected_input_tensor = self.create_tensor(self.rank - 1)
self.assertEqual(input_tensor, expected_input_tensor)
# Brief: test `send_forward`, `send_forward_recv_forward`, and `recv_forward`.
def test_send_forward_recv_forward(self):
self._init_model_parallel()
prev_tensor = None
tensor = self.create_tensor(self.rank)
if parallel_state.is_pipeline_first_stage():
p2p_communication.send_forward(output_tensor=tensor, tensor_shape=self.shape, dtype=self.dtype)
elif parallel_state.is_pipeline_last_stage():
prev_tensor = p2p_communication.recv_forward(tensor_shape=self.shape, dtype=self.dtype)
else:
prev_tensor = p2p_communication.send_forward_recv_forward(
output_tensor=tensor,
recv_prev=True,
tensor_shape=self.shape,
dtype=self.dtype,
)
if parallel_state.is_pipeline_first_stage():
self.assertIsNone(prev_tensor)
else:
expected_prev_tensor = self.create_tensor(self.rank - 1)
self.assertEqual(prev_tensor, expected_prev_tensor)
# Brief: test `send_backward`, `send_backward_recv_backward`, and `recv_backward`.
def test_send_backward_recv_backward(self):
self._init_model_parallel()
tensor = self.create_tensor(self.rank)
next_tensor = None
if parallel_state.is_pipeline_first_stage():
next_tensor = p2p_communication.recv_backward(tensor_shape=self.shape, dtype=self.dtype)
elif parallel_state.is_pipeline_last_stage():
p2p_communication.send_backward(input_tensor_grad=tensor, tensor_shape=self.shape, dtype=self.dtype)
else:
next_tensor = p2p_communication.send_backward_recv_backward(
input_tensor_grad=tensor,
recv_next=True,
tensor_shape=self.shape,
dtype=self.dtype,
)
if parallel_state.is_pipeline_last_stage():
self.assertIsNone(next_tensor)
else:
expected_next_tensor = self.create_tensor(self.rank + 1)
self.assertEqual(next_tensor, expected_next_tensor)
# n.b.(mkozuki): Intentionally skip NCCL backend tests as I trust pytorch/pytorch repo.
class UccP2PCommTest(P2PCommTestBase, UccDistributedTestBase): pass
if __name__ == "__main__":
common_utils.run_tests()
......@@ -6,7 +6,8 @@ from torch.testing._internal import common_utils
logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer.testing.distributed_test_base import DistributedTestBase
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
logging.getLogger("apex").setLevel(logging.WARNING)
......@@ -21,7 +22,7 @@ def calc_expected_tensor_model_paralell_rank(
return rank % tensor_model_parallel_world_size
class ParallelStateTest(DistributedTestBase):
class ParallelStateTestBase:
def test_initialize_model_parallel(self) -> None:
self.assertFalse(parallel_state.model_parallel_is_initialized())
......@@ -122,5 +123,9 @@ class ParallelStateTest(DistributedTestBase):
parallel_state.destroy_model_parallel()
class NcclParallelStateTest(ParallelStateTestBase, NcclDistributedTestBase): pass
class UccParallelStateTest(ParallelStateTestBase, UccDistributedTestBase): pass
if __name__ == "__main__":
common_utils.run_tests()
import logging
import itertools
import re
from typing import Optional
import unittest
import torch
from torch.testing._internal import common_utils
......@@ -24,7 +26,10 @@ from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interl
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
forward_backward_pipelining_without_interleaving,
)
from apex.transformer.testing.distributed_test_base import DistributedTestBase
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
from apex.transformer.testing.distributed_test_base import HAS_TORCH_UCC
from apex.transformer.testing.distributed_test_base import HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER
from apex.transformer.testing import commons as testing_utils
logging.getLogger("apex").setLevel(logging.WARNING)
......@@ -39,7 +44,7 @@ def init_weights(m):
m.bias.fill_(1.0)
def get_target_loss(hidden_size: int, microbatch_size: int, parallel_model_world_size: int, world_size: int) -> float:
def get_target_loss(hidden_size: int, microbatch_size: int, parallel_model_world_size: int, world_size: int) -> float:
layers_per_rank = world_size // parallel_model_world_size
data = torch.arange(start = 0, end = layers_per_rank, dtype = torch.int) + 1
......@@ -54,15 +59,16 @@ def get_target_loss(hidden_size: int, microbatch_size: int, parallel_model_world
return hidden_size * hidden_size * torch.sum(data).item() * microbatch_size / layers_per_rank
class PipelineParallelForwardBackwardTest(DistributedTestBase):
class PipelineParallelForwardBackwardTestBase:
GLOBAL_BATCH_SIZE = 16
MICRO_BATCH_SIZE = 2
HIDDEN_SIZE = 32
@property
def world_size(self) -> int:
return min(torch.cuda.device_count(), 8)
deallocate_options = (True, False)
# If :obj:`None`, (torch.float32, torch.float16, torch.bfloat16) are dtype options on Ampere.
# You can limit the options by overriding the following `dtypes`.
dtypes = None
def _forward_backward_test_impl(
self,
......@@ -71,9 +77,13 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
pipeline_model_parallel_world_size: Optional[int],
virtual_pipeline_model_parallel_size: Optional[int],
async_comm: bool = False,
*,
default_backend: Optional[str] = None,
p2p_backend: Optional[str] = None,
) -> None:
dtype_options = self.dtypes or [torch.float32] + _get_autocast_dtypes()
for dtype, deallocate_pipeline_outputs in itertools.product(
[torch.float32] + _get_autocast_dtypes(), (True, False),
dtype_options, self.deallocate_options,
):
grad_scaler = (
torch.cuda.amp.GradScaler(init_scale=4.0)
......@@ -92,29 +102,32 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
tensor_model_parallel_size_=tensor_model_parallel_world_size,
pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size,
default_backend=default_backend,
p2p_backend=p2p_backend,
)
pp_utils._reconfigure_microbatch_calculator(
rank=parallel_state.get_tensor_model_parallel_rank(),
rampup_batch_size=None,
global_batch_size=PipelineParallelForwardBackwardTest.GLOBAL_BATCH_SIZE,
micro_batch_size=PipelineParallelForwardBackwardTest.MICRO_BATCH_SIZE,
global_batch_size=self.GLOBAL_BATCH_SIZE,
micro_batch_size=self.MICRO_BATCH_SIZE,
data_parallel_size=parallel_state.get_data_parallel_world_size(),
)
global_batch_shape = (
PipelineParallelForwardBackwardTest.GLOBAL_BATCH_SIZE
self.GLOBAL_BATCH_SIZE
// parallel_state.get_data_parallel_world_size(),
PipelineParallelForwardBackwardTest.HIDDEN_SIZE,
PipelineParallelForwardBackwardTest.HIDDEN_SIZE,
self.HIDDEN_SIZE,
self.HIDDEN_SIZE,
)
batch =(((self.rank + 1) * torch.ones(global_batch_shape)).cuda(), )
model = build_model(
testing_utils.model_provider_func,
wrap_with_ddp=True,
# Use DDP only when it's better to have
wrap_with_ddp=data_parallel_size > 1,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
hidden_size=PipelineParallelForwardBackwardTest.HIDDEN_SIZE,
hidden_size=self.HIDDEN_SIZE,
)
for model_module in model:
......@@ -132,9 +145,9 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
forward_only=forward_only,
# `tensor_shape` is the shape of micro batch.
tensor_shape=(
PipelineParallelForwardBackwardTest.MICRO_BATCH_SIZE,
PipelineParallelForwardBackwardTest.HIDDEN_SIZE,
PipelineParallelForwardBackwardTest.HIDDEN_SIZE,
self.MICRO_BATCH_SIZE,
self.HIDDEN_SIZE,
self.HIDDEN_SIZE,
),
dtype=dtype,
async_comm=async_comm,
......@@ -143,8 +156,8 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
)
if dtype == torch.float32:
hidden_size = PipelineParallelForwardBackwardTest.HIDDEN_SIZE
microbatch_size = PipelineParallelForwardBackwardTest.MICRO_BATCH_SIZE
hidden_size = self.HIDDEN_SIZE
microbatch_size = self.MICRO_BATCH_SIZE
target_loss = get_target_loss(hidden_size, microbatch_size, pipeline_model_parallel_world_size, self.world_size)
for loss_item in loss:
......@@ -166,7 +179,7 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
def test_no_pipelining_inference(self):
self._forward_backward_test_impl(True, forward_backward_no_pipelining, 1, None)
def test_pipelining_default(self):
def test_pipelining_without_interleaving(self):
self._forward_backward_test_impl(
False, forward_backward_pipelining_without_interleaving, None, None
)
......@@ -176,7 +189,7 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
False, forward_backward_pipelining_without_interleaving, None, None, async_comm=True
)
def test_pipelining_inference(self):
def test_pipelining_without_interleaving_inference(self):
self._forward_backward_test_impl(
True, forward_backward_pipelining_without_interleaving, None, None
)
......@@ -197,5 +210,46 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
)
class NcclPipelineParallelForwardBackwardTest(NcclDistributedTestBase, PipelineParallelForwardBackwardTestBase):
@property
def world_size(self) -> int:
return min(torch.cuda.device_count(), 8)
def _run_hybrid_distributed_backend(self, forward_only: bool) -> None:
self._forward_backward_test_impl(
forward_only, forward_backward_pipelining_without_interleaving, None, None,
default_backend="nccl", p2p_backend="ucc",
)
@unittest.skipUnless(HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER, "Needs driver >= 470.42.01")
def _test_hybrid_backends(self, forward_only: bool) -> None:
if HAS_TORCH_UCC:
self._run_hybrid_distributed_backend(forward_only)
else:
with self.assertRaisesRegex(
ImportError,
re.escape("UCC backend requires [torch_ucc](https://github.com/facebookresearch/torch_ucc) but not found"),
):
self._run_hybrid_distributed_backend(forward_only)
def test_pipelining_without_interleaving_ucc_for_p2p(self):
self._test_hybrid_backends(False)
def test_pipelining_without_interleaving_inference_ucc_for_p2p(self):
self._test_hybrid_backends(True)
# n.b.(mkozuki): pipeline parallel w/o interleaving with UCX_TLS=tcp,sm fails.
class UccPipelineParallelForwardBackwardTest(UccDistributedTestBase, PipelineParallelForwardBackwardTestBase):
@property
def world_size(self) -> int:
return min(torch.cuda.device_count(), 8)
deallocate_options = (False,)
dtypes = (torch.float32,)
if __name__ == "__main__":
common_utils.run_tests()
......@@ -7,12 +7,13 @@ logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from apex.transformer.testing.distributed_test_base import DistributedTestBase
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
logging.getLogger("apex").setLevel(logging.WARNING)
class TransformerRandomTest(DistributedTestBase):
class TransformerRandomTestBase:
def test_set_cuda_rng_state(self):
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_parallel_world_size:
......@@ -111,5 +112,9 @@ class TransformerRandomTest(DistributedTestBase):
parallel_state.destroy_model_parallel()
class NcclTransformerRandomTest(TransformerRandomTestBase, NcclDistributedTestBase): pass
class UccTransformerRandomTest(TransformerRandomTestBase, UccDistributedTestBase): pass
if __name__ == "__main__":
common_utils.run_tests()
......@@ -63,7 +63,9 @@ def run_transformer_tests():
import torch
num_devices = torch.cuda.device_count()
test_run_cmd += f" --pipeline-model-parallel-size {num_devices}"
tensor_model_parallel_size = 1 + (1 - (num_devices % 2 and num_devices > 4))
pipeline_model_parallel_size = num_devices // tensor_model_parallel_size
test_run_cmd += f" --pipeline-model-parallel-size {pipeline_model_parallel_size} --tensor-model-parallel-size {tensor_model_parallel_size}"
else:
test_run_cmd += f" --use-cpu-initialization"
print(f"### {i} / {len(files)}: cmd: {test_run_cmd}")
......
......@@ -7,12 +7,12 @@ logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import utils
from apex.transformer.testing.distributed_test_base import DistributedTestBase
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
logging.getLogger("apex").setLevel(logging.WARNING)
class TransformerUtilsTest(DistributedTestBase):
class TransformerUtilsTest(NcclDistributedTestBase):
def test_split_tensor_along_last_dim(self):
for tensor_model_paralell_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_paralell_world_size > 0:
......