Unverified commit a65ad37e, authored by Phuong Nguyen, committed by GitHub
Browse files

[JAX] Test_multiprocessing_encoder with process spawn in bash (#1394)



* Add test_multiprocessing_encoder with process spawning in bash

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 7b861e75
...@@ -12,3 +12,10 @@ def is_bf16_supported(): ...@@ -12,3 +12,10 @@ def is_bf16_supported():
"""Return if BF16 has hardware supported""" """Return if BF16 has hardware supported"""
gpu_arch = get_device_compute_capability(0) gpu_arch = get_device_compute_capability(0)
return gpu_arch >= 80 return gpu_arch >= 80
@lru_cache
def is_fp8_supported():
    """Check whether device 0 meets the FP8 compute-capability threshold.

    Returns:
        True when ``get_device_compute_capability(0)`` is at least 90,
        otherwise False. The result is memoized by ``lru_cache``, so the
        device is queried only on the first call.
    """
    # NOTE(review): a threshold of 90 rejects any device reporting a value
    # below it (e.g. 89) -- confirm that is the intended FP8 cut-off.
    fp8_min_capability = 90
    return get_device_compute_capability(0) >= fp8_min_capability
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""config for test_multiprocessing_encoder"""
import pytest
def pytest_addoption(parser):
    """Register the command-line options used by test_multiprocessing_encoder.

    Adds ``--num-process`` and ``--process-id``, both defaulting to 0. The
    values are parsed as integers so a malformed value is rejected when the
    command line is parsed rather than later, when the options are read.

    Args:
        parser: The option parser pytest passes to this hook.
    """
    parser.addoption(
        "--num-process",
        action="store",
        type=int,
        default=0,
        help="Total number of processes taking part in the multiprocessing test.",
    )
    parser.addoption(
        "--process-id",
        action="store",
        type=int,
        default=0,
        help="Index of this process among the --num-process processes.",
    )
@pytest.fixture(autouse=True)
def multiprocessing_parses(request):
    """Copy ``--num-process`` and ``--process-id`` onto the requesting test class.

    Runs automatically for every test. When the test belongs to a class, the
    two options are read from the pytest config, converted with ``int`` and
    stored on that class as ``num_process`` and ``process_id``. Tests that are
    not part of a class are left untouched.

    Args:
        request: The pytest fixture request for the current test.
    """
    test_cls = request.cls
    if not test_cls:
        return
    read_option = request.config.getoption
    test_cls.num_process = int(read_option("--num-process"))
    test_cls.process_id = int(read_option("--process-id"))
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Number of worker processes to spawn; defaults to the number of GPUs that
# nvidia-smi lists unless NUM_GPUS is already set by the caller.
NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}

# Run one TestEncoder test case as NUM_GPUS concurrent pytest processes, one
# per process id, and return non-zero if any of them fails.
#   $1: name of the test method inside TestEncoder (e.g. test_te_bf16)
run_multiprocessing_encoder_test() {
    _mpe_test_name=$1
    _mpe_pids=""
    _mpe_status=0
    for i in $(seq 0 $((NUM_GPUS - 1)))
    do
        pytest -c "$TE_PATH/tests/jax/pytest.ini" -v \
            "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::${_mpe_test_name}" \
            --num-process="$NUM_GPUS" --process-id="$i" &
        _mpe_pids="$_mpe_pids $!"
    done
    # A bare `wait` returns 0 regardless of how the children exited, which
    # would hide test failures; wait on each PID to collect its exit status.
    for _mpe_pid in $_mpe_pids
    do
        wait "$_mpe_pid" || _mpe_status=1
    done
    return "$_mpe_status"
}

# Run both test cases even if the first one fails, then report the combined
# result.
MULTIPROCESSING_ENCODER_FAILED=0
run_multiprocessing_encoder_test test_te_bf16 || MULTIPROCESSING_ENCODER_FAILED=1
run_multiprocessing_encoder_test test_te_fp8 || MULTIPROCESSING_ENCODER_FAILED=1

# `return` works when this file is sourced; fall back to `exit` when it is
# executed directly.
return "$MULTIPROCESSING_ENCODER_FAILED" 2>/dev/null || exit "$MULTIPROCESSING_ENCODER_FAILED"
...@@ -3,10 +3,10 @@ ...@@ -3,10 +3,10 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Encoder training with multi-GPU, multiprocessing, and tensor parallelism""" """Encoder training with multi-GPU, multiprocessing, and tensor parallelism"""
import argparse import argparse
import multiprocessing as mp
import os import os
import unittest import unittest
from functools import partial from functools import partial
import pytest
import flax import flax
import jax import jax
...@@ -21,10 +21,10 @@ from flax.training import train_state ...@@ -21,10 +21,10 @@ from flax.training import train_state
from jax.experimental import mesh_utils from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from common import is_bf16_supported, is_fp8_supported
import transformer_engine.jax as te import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax import transformer_engine.jax.flax as te_flax
from common import is_bf16_supported
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
DEVICE_DP_AXIS = "data" DEVICE_DP_AXIS = "data"
...@@ -252,7 +252,6 @@ def eval_model( ...@@ -252,7 +252,6 @@ def eval_model(
def data_preprocess(dataset, vocab, word_id, max_seq_len): def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers.""" """Convert tokens to numbers."""
nltk.download("punkt_tab")
dataset_size = len(dataset["sentence"]) dataset_size = len(dataset["sentence"])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
...@@ -342,6 +341,9 @@ def get_state_sharding(state, params_sharding): ...@@ -342,6 +341,9 @@ def get_state_sharding(state, params_sharding):
def train_and_evaluate(args): def train_and_evaluate(args):
"""Execute model training and evaluation loop.""" """Execute model training and evaluation loop."""
print(args) print(args)
if args.process_id == 0:
nltk.download("punkt_tab")
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len) train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
jax.distributed.initialize( jax.distributed.initialize(
...@@ -551,69 +553,41 @@ def encoder_parser(args): ...@@ -551,69 +553,41 @@ def encoder_parser(args):
return parser.parse_args(args) return parser.parse_args(args)
def query_gpu(q): @pytest.mark.usefixtures("multiprocessing_parses")
"""Query GPU info on the system"""
gpu_has_fp8, reason = te.fp8.is_fp8_available()
gpu_has_bf16 = is_bf16_supported()
num_gpu = len(jax.devices())
q.put([num_gpu, gpu_has_fp8, gpu_has_bf16, reason])
def unittest_query_gpu():
r"""
It is only used by TestEncoder.
The `jax.distributed.initialize` must be called before any other JAX or Flax API,
otherwise `jax.local_devices` will be incorrect.
Thus, fork another process to query number of GPUs and FP8 capability.
"""
q = mp.Queue()
p = mp.Process(target=query_gpu, args=(q,))
p.start()
num_gpu, gpu_has_fp8, gpu_has_bf16, reason = q.get()
p.join()
return num_gpu, gpu_has_fp8, gpu_has_bf16, reason
class TestEncoder(unittest.TestCase): class TestEncoder(unittest.TestCase):
"""Encoder unittests""" """Encoder unittests"""
num_gpu, gpu_has_fp8, gpu_has_bf16, reason = unittest_query_gpu() gpu_has_fp8 = is_fp8_supported()
gpu_has_bf16 = is_bf16_supported()
def exec(self, use_fp8): def exec(self, use_fp8):
"""Run 3 epochs for testing""" """Run 3 epochs for testing"""
num_gpu = self.num_gpu args = encoder_parser([])
num_gpu = self.num_process
tp_size = 2 if num_gpu > 1 and num_gpu % 2 == 0 else 1 tp_size = 2 if num_gpu > 1 and num_gpu % 2 == 0 else 1
dp_size = num_gpu // tp_size dp_size = num_gpu // tp_size
batch_size = 64 // dp_size batch_size = 64 // dp_size
arg_list = []
for i in range(num_gpu):
args = encoder_parser([])
args.num_process = num_gpu
args.use_fp8 = use_fp8 args.use_fp8 = use_fp8
args.batch_size = batch_size args.batch_size = batch_size
args.test_batch_size = batch_size args.test_batch_size = batch_size
args.process_id = i args.num_process = num_gpu
arg_list.append(args) args.process_id = self.process_id
with mp.Pool(self.num_gpu) as p:
results = p.map(train_and_evaluate, arg_list)
return results return train_and_evaluate(args)
@unittest.skipIf(not gpu_has_bf16, "Device compute capability 8.0+ is required for BF16") @unittest.skipIf(not gpu_has_bf16, "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self): def test_te_bf16(self):
"""Test Transformer Engine with BF16""" """Test Transformer Engine with BF16"""
results = self.exec(False) result = self.exec(False)
actual = results[0] assert result[0] < 0.45 and result[1] > 0.79
assert actual[0] < 0.45 and actual[1] > 0.79
@unittest.skipIf(not gpu_has_fp8, reason) @unittest.skipIf(not gpu_has_fp8, "Device compute capability 9.0+ is required for FP8")
def test_te_fp8(self): def test_te_fp8(self):
"""Test Transformer Engine with FP8""" """Test Transformer Engine with FP8"""
results = self.exec(True) result = self.exec(True)
actual = results[0] assert result[0] < 0.45 and result[1] > 0.79
assert actual[0] < 0.45 and actual[1] > 0.79
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -12,4 +12,4 @@ pip install -r $TE_PATH/examples/jax/encoder/requirements.txt ...@@ -12,4 +12,4 @@ pip install -r $TE_PATH/examples/jax/encoder/requirements.txt
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py . $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment