Unverified Commit a0ed4151 authored by Masaki Kozuki, committed by GitHub

[transformer] Format & Test Refactoring (#1325)

* try PyTorch custom TestCase class

* revert

* initial working example

* update

* data utils

* fix imports

* hardcode backend to nccl

* fix signature

* fix typo

* mapping

* set device

* init

* refactor x entropy

* remove unused import & destroy model parallel

* refactor random

* fix test

* remove migrated tests

* refactor

* init

* separate affine weight init

* init model parallel

* split more

* weight init fix part 1

* use cpu init for consistency btwn native and tensor parallel

* black

* add col parallel

* use a 3D tensor of square matrix for column parallel linear

* skip the failing cases

* migrate layers test

* pipeline parallel forward/backward

* fix typo

* fix typo

* fix

* fix pipeline world size

* black

* rm `run_pipeline_parallel_test` in favor of test_pipeline_parallel_fwd_bwd.py

* stop logging

* set log level

* black

* license and format

* fix

* skip tf32 as matrices are small

* remove potentially inappropriate license

* Apply suggestions from code review

* remove `TODO` comment

* `torch.testing.assert_allclose` -> `torch.testing.assert_close`

* remove comment-outs

* remove unused import

* minor fix
parent f10b4b89
import logging

import torch
from torch.testing._internal import common_utils

logging.getLogger("torch").setLevel(logging.WARNING)

from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from apex.transformer.testing.distributed_test_base import DistributedTestBase

logging.getLogger("apex").setLevel(logging.WARNING)


class TransformerRandomTest(DistributedTestBase):
    def test_set_cuda_rng_state(self):
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size:
                continue
            with self.subTest(
                tensor_model_parallel_world_size=tensor_model_parallel_world_size
            ):
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=tensor_model_parallel_world_size
                )

                size, seed = 123, 1234
                torch.cuda.manual_seed(seed)
                tensor = torch.cuda.FloatTensor(size)

                # Snapshot the CUDA RNG state before drawing any samples.
                rng_state = torch.cuda.get_rng_state()
                rng_state_clone = rng_state.clone()

                for _ in range(5):
                    torch.randn(size, out=tensor)
                result_1 = tensor.clone()

                # The saved snapshot is unchanged, while the live state has advanced.
                self.assertEqual(rng_state.sub(rng_state_clone).max(), 0)
                self.assertGreater(
                    torch.cuda.get_rng_state().sub(rng_state_clone).max(), 0
                )

                new_rng_state = torch.cuda.get_rng_state()
                self.assertGreater(new_rng_state.sub(rng_state).max(), 0)

                # Restoring the snapshot and replaying the draws reproduces result_1.
                tensor_parallel.random._set_cuda_rng_state(rng_state)
                for _ in range(5):
                    torch.randn(size, out=tensor)
                tensor_parallel.random._set_cuda_rng_state(rng_state)
                for _ in range(5):
                    torch.randn(size, out=tensor)
                result_2 = tensor.clone()

                torch.testing.assert_close(result_2, result_1)

                self.assertEqual(rng_state.sub(rng_state_clone).max(), 0)

                parallel_state.destroy_model_parallel()

    def test_cuda_rng_tracker(self):
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size:
                continue
            with self.subTest(
                tensor_model_parallel_world_size=tensor_model_parallel_world_size
            ):
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=tensor_model_parallel_world_size
                )

                seed_1, seed_2, size = 1234, 4321, [12, 21]
                tensor = torch.cuda.FloatTensor(size)

                # Reference draws from the default stream (seed_1) ...
                torch.cuda.manual_seed(seed_1)
                torch.randn(size, out=tensor)
                target_11 = tensor.clone()
                torch.randn(size, out=tensor)
                target_12 = tensor.clone()

                # ... and from a second stream (seed_2).
                torch.cuda.manual_seed(seed_2)
                torch.randn(size, out=tensor)
                target_21 = tensor.clone()
                torch.randn(size, out=tensor)
                target_22 = tensor.clone()

                # Interleave the two streams through the RNG tracker and check
                # that each stream still reproduces its reference values.
                torch.cuda.manual_seed(seed_1)
                tensor_parallel.random.get_cuda_rng_tracker().add("test", seed_2)

                torch.randn(size, out=tensor)
                result_11 = tensor.clone()

                with tensor_parallel.random.get_cuda_rng_tracker().fork("test"):
                    torch.randn(size, out=tensor)
                    result_21 = tensor.clone()

                torch.randn(size, out=tensor)
                result_12 = tensor.clone()

                with tensor_parallel.random.get_cuda_rng_tracker().fork("test"):
                    torch.randn(size, out=tensor)
                    result_22 = tensor.clone()

                self.assertEqual(target_11, result_11)
                self.assertEqual(target_12, result_12)
                self.assertEqual(target_21, result_21)
                self.assertEqual(target_22, result_22)
                self.assertNotEqual(result_11, result_21)
                self.assertNotEqual(result_21, result_22)

                tensor_parallel.random.get_cuda_rng_tracker().reset()
                parallel_state.destroy_model_parallel()


if __name__ == "__main__":
    common_utils.run_tests()
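For context, a minimal sketch (editor's illustration, not part of the commit) of the save/restore idea these tests exercise, using only public PyTorch APIs; apex's `_set_cuda_rng_state` and `get_cuda_rng_tracker()` build per-region seeding for tensor parallelism on top of this pattern:

import torch

# Snapshot the CUDA RNG state, draw some samples, then restore the snapshot:
# replaying the draws afterwards reproduces the same values.
torch.cuda.manual_seed(1234)
saved = torch.cuda.get_rng_state()
a = torch.randn(8, device="cuda")   # advances the generator
torch.cuda.set_rng_state(saved)     # roll the generator back
b = torch.randn(8, device="cuda")   # replays the same stream
assert torch.equal(a, b)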
@@ -5,32 +5,26 @@ import sys
 import unittest
 
-DENY_TEST = [
-    "megatron_gpt_pipeline",
-]
-MULTIGPU_TEST = [
-    "pipeline_parallel_test",
-]
 SEVERALGPU_TEST = [
     "bert_minimal_test",
     "gpt_minimal_test",
     "dynamic_batchsize_test",
 ]
 
 
 def get_multigpu_launch_option(min_gpu):
     should_skip = False
     import torch
     num_devices = torch.cuda.device_count()
     if num_devices < min_gpu:
         should_skip = True
     distributed_run_options = f"-m torch.distributed.run --nproc_per_node={num_devices}"
     return should_skip, distributed_run_options
 
 
 def get_launch_option(test_filename) -> Tuple[bool, str]:
     should_skip = False
-    for multigpu_test in MULTIGPU_TEST:
-        if multigpu_test in test_filename:
-            return get_multigpu_launch_option(2)
     for severalgpu_test in SEVERALGPU_TEST:
         if severalgpu_test in test_filename:
             return get_multigpu_launch_option(3)
@@ -43,7 +37,8 @@ def run_transformer_tests():
     # directory = os.path.abspath(os.path.join(repository_root, "tests/mpu"))
     directory = os.path.dirname(__file__)
     files = [
-        os.path.join(directory, f) for f in os.listdir(directory)
+        os.path.join(directory, f)
+        for f in os.listdir(directory)
         if f.startswith("run_") and os.path.isfile(os.path.join(directory, f))
     ]
     print("#######################################################")
@@ -53,36 +48,35 @@ def run_transformer_tests():
     errors = []
     for i, test_file in enumerate(files, 1):
         is_denied = False
-        for deny_file in DENY_TEST:
-            if deny_file in test_file:
-                is_denied = True
-        if is_denied:
-            print(f"### {i} / {len(files)}: {test_file} skipped")
-            continue
         should_skip, launch_option = get_launch_option(test_file)
         if should_skip:
-            print(f"### {i} / {len(files)}: {test_file} skipped. Requires multiple GPUs.")
+            print(
+                f"### {i} / {len(files)}: {test_file} skipped. Requires multiple GPUs."
+            )
             continue
         test_run_cmd = (
             f"{python_executable_path} {launch_option} {test_file} "
             "--micro-batch-size 2 --num-layers 16 --hidden-size 256 --num-attention-heads 8 --max-position-embeddings "
             "512 --seq-length 512 --global-batch-size 128"
         )
-        if 'bert' in test_file or 'gpt' in test_file:
+        if "bert" in test_file or "gpt" in test_file:
             import torch
             num_devices = torch.cuda.device_count()
             test_run_cmd += f" --pipeline-model-parallel-size {num_devices}"
         else:
             test_run_cmd += f" --use-cpu-initialization"
         print(f"### {i} / {len(files)}: cmd: {test_run_cmd}")
         try:
-            output = subprocess.check_output(
-                test_run_cmd, shell=True
-            ).decode(sys.stdout.encoding).strip()
+            output = (
+                subprocess.check_output(test_run_cmd, shell=True)
+                .decode(sys.stdout.encoding)
+                .strip()
+            )
         except Exception as e:
             errors.append((test_file, str(e)))
         else:
-            if '>> passed the test :-)' not in output:
+            if ">> passed the test :-)" not in output:
                 errors.append((test_file, output))
     else:
         if not errors:
@@ -97,10 +91,9 @@ def run_transformer_tests():
 class TestTransformer(unittest.TestCase):
     def test_transformer(self):
         run_transformer_tests()
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
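For orientation (editor's note, not part of the diff): on a hypothetical machine with 4 visible GPUs, the command the runner assembles for a BERT/GPT script such as run_bert_minimal_test.py (file name assumed for illustration) would look roughly like

python -m torch.distributed.run --nproc_per_node=4 run_bert_minimal_test.py --micro-batch-size 2 --num-layers 16 --hidden-size 256 --num-attention-heads 8 --max-position-embeddings 512 --seq-length 512 --global-batch-size 128 --pipeline-model-parallel-size 4

since the bert/gpt branch appends --pipeline-model-parallel-size equal to the device count, while all other scripts get --use-cpu-initialization appended instead.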
import logging

import torch
from torch.testing._internal import common_utils

logging.getLogger("torch").setLevel(logging.WARNING)

from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import utils
from apex.transformer.testing.distributed_test_base import DistributedTestBase

logging.getLogger("apex").setLevel(logging.WARNING)


class TransformerUtilsTest(DistributedTestBase):
    def test_split_tensor_along_last_dim(self):
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size > 0:
                continue
            with self.subTest(
                tensor_model_parallel_world_size=tensor_model_parallel_world_size
            ):
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=tensor_model_parallel_world_size
                )

                device = "cpu"
                input_tensor = torch.randn((100, 100, 100), device=device)
                # Splitting the last dim (size 100) into 10 partitions should
                # yield 10 chunks whose last dim is 10.
                splits = utils.split_tensor_along_last_dim(input_tensor, 10)
                last_dim_shapes = torch.tensor(
                    [int(split.size()[-1]) for split in splits]
                )
                self.assertTrue(torch.equal(last_dim_shapes, torch.full((10,), 10)))

                parallel_state.destroy_model_parallel()


if __name__ == "__main__":
    common_utils.run_tests()
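For reference, a small standalone sketch (editor's illustration; the apex helper may differ in details such as contiguity handling) of the behaviour this test checks, using plain torch.split to cut the last dimension into equal partitions:

import torch

# A (100, 100, 100) tensor split into 10 partitions along the last dimension
# yields 10 chunks, each with a last dimension of 10.
x = torch.randn(100, 100, 100)
chunks = torch.split(x, x.size(-1) // 10, dim=-1)
assert len(chunks) == 10
assert all(chunk.size(-1) == 10 for chunk in chunks)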