Unverified Commit a65ad37e authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Test_multiprocessing_encoder with process spawn in bash (#1394)



* add test_multiprocessing_encoder with processing spawning in bash

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 7b861e75
......@@ -12,3 +12,10 @@ def is_bf16_supported():
"""Return if BF16 has hardware supported"""
gpu_arch = get_device_compute_capability(0)
return gpu_arch >= 80
@lru_cache
def is_fp8_supported():
    """Return whether the local GPU has hardware FP8 support.

    FP8 requires compute capability 9.0 (Hopper) or newer; the result is
    cached because the device capability cannot change within a process.
    """
    return get_device_compute_capability(0) >= 90
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""config for test_multiprocessing_encoder"""
import pytest
def pytest_addoption(parser):
    """Pytest hook for test_multiprocessing_encoder"""
    # Both options share the same shape: a plain stored value defaulting to 0.
    # --num-process is the total worker count, --process-id this worker's rank.
    for option in ("--num-process", "--process-id"):
        parser.addoption(option, action="store", default=0)
@pytest.fixture(autouse=True)
def multiprocessing_parses(request):
    """Fixture for querying num-process and process-id"""
    # Only class-based tests get the attributes; function-style tests
    # (request.cls is None) are left untouched.
    target = request.cls
    if not target:
        return
    getoption = request.config.getoption
    target.num_process = int(getoption("--num-process"))
    target.process_id = int(getoption("--process-id"))
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# One worker process per GPU; allow the caller to override via NUM_GPUS.
NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}
# BF16 run: launch one pytest process per rank in the background, each told
# the world size (--num-process) and its own rank (--process-id).
for i in $(seq 0 $(($NUM_GPUS-1)))
do
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_bf16 --num-process=$NUM_GPUS --process-id=$i &
done
# Block until every BF16 rank has exited before starting the FP8 run.
# NOTE(review): `wait` does not propagate per-rank pytest exit codes — TODO confirm the CI harness detects failures another way.
wait
# FP8 run: same fan-out, one backgrounded pytest process per rank.
for i in $(seq 0 $(($NUM_GPUS-1)))
do
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_fp8 --num-process=$NUM_GPUS --process-id=$i &
done
wait
......@@ -3,10 +3,10 @@
# See LICENSE for license information.
"""Encoder training with multi-GPU, multiprocessing, and tensor parallelism"""
import argparse
import multiprocessing as mp
import os
import unittest
from functools import partial
import pytest
import flax
import jax
......@@ -21,10 +21,10 @@ from flax.training import train_state
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec, NamedSharding
from common import is_bf16_supported, is_fp8_supported
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from common import is_bf16_supported
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
DEVICE_DP_AXIS = "data"
......@@ -252,7 +252,6 @@ def eval_model(
def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers."""
nltk.download("punkt_tab")
dataset_size = len(dataset["sentence"])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
......@@ -342,6 +341,9 @@ def get_state_sharding(state, params_sharding):
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
if args.process_id == 0:
nltk.download("punkt_tab")
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
jax.distributed.initialize(
......@@ -551,69 +553,41 @@ def encoder_parser(args):
return parser.parse_args(args)
def query_gpu(q):
    """Query GPU info on the system"""
    # is_fp8_available() returns a (supported, reason) pair; `reason` explains
    # why FP8 is unavailable when `supported` is False.
    gpu_has_fp8, reason = te.fp8.is_fp8_available()
    gpu_has_bf16 = is_bf16_supported()
    num_gpu = len(jax.devices())
    # Runs in a child process (see unittest_query_gpu); results are handed
    # back to the parent through the multiprocessing queue.
    q.put([num_gpu, gpu_has_fp8, gpu_has_bf16, reason])
def unittest_query_gpu():
    r"""
    It is only used by TestEncoder.
    The `jax.distributed.initialize` must be called before any other JAX or Flax API,
    otherwise `jax.local_devices` will be incorrect.
    Thus, fork another process to query number of GPUs and FP8 capability.
    """
    q = mp.Queue()
    p = mp.Process(target=query_gpu, args=(q,))
    p.start()
    # q.get() blocks until the child has put its result, so the unpack is
    # safe to do before join().
    num_gpu, gpu_has_fp8, gpu_has_bf16, reason = q.get()
    p.join()
    return num_gpu, gpu_has_fp8, gpu_has_bf16, reason
@pytest.mark.usefixtures("multiprocessing_parses")
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
num_gpu, gpu_has_fp8, gpu_has_bf16, reason = unittest_query_gpu()
gpu_has_fp8 = is_fp8_supported()
gpu_has_bf16 = is_bf16_supported()
def exec(self, use_fp8):
"""Run 3 epochs for testing"""
num_gpu = self.num_gpu
args = encoder_parser([])
num_gpu = self.num_process
tp_size = 2 if num_gpu > 1 and num_gpu % 2 == 0 else 1
dp_size = num_gpu // tp_size
batch_size = 64 // dp_size
arg_list = []
for i in range(num_gpu):
args = encoder_parser([])
args.num_process = num_gpu
args.use_fp8 = use_fp8
args.batch_size = batch_size
args.test_batch_size = batch_size
args.process_id = i
arg_list.append(args)
with mp.Pool(self.num_gpu) as p:
results = p.map(train_and_evaluate, arg_list)
args.use_fp8 = use_fp8
args.batch_size = batch_size
args.test_batch_size = batch_size
args.num_process = num_gpu
args.process_id = self.process_id
return results
return train_and_evaluate(args)
@unittest.skipIf(not gpu_has_bf16, "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
results = self.exec(False)
actual = results[0]
assert actual[0] < 0.45 and actual[1] > 0.79
result = self.exec(False)
assert result[0] < 0.45 and result[1] > 0.79
@unittest.skipIf(not gpu_has_fp8, reason)
@unittest.skipIf(not gpu_has_fp8, "Device compute capability 9.0+ is required for FP8")
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
results = self.exec(True)
actual = results[0]
assert actual[0] < 0.45 and actual[1] > 0.79
result = self.exec(True)
assert result[0] < 0.45 and result[1] > 0.79
if __name__ == "__main__":
......
......@@ -12,4 +12,4 @@ pip install -r $TE_PATH/examples/jax/encoder/requirements.txt
# Deterministic XLA ops so the loss/accuracy thresholds in the tests are reproducible.
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py
# NOTE(review): the two lines below are the old/new sides of a diff — this commit
# replaces the direct pytest invocation with the sourced per-GPU runner script.
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment