Unverified Commit aa06107c authored by Ming-Xu Huang, committed by GitHub

Fixing few issues with multi-process launching. (#2155)



* Fixing few issues with multi-process launching.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 5b3d65cc
@@ -12,12 +12,12 @@ XLA_BASE_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
 export XLA_FLAGS="${XLA_BASE_FLAGS}"
-NUM_RUNS=$(nvidia-smi --query-gpu=count --format=csv,noheader)
+NUM_RUNS=$(nvidia-smi -L | wc -l)
 for ((i=1; i<NUM_RUNS; i++))
 do
-CUDA_VISIBLE_DEVICES=$i python $SCRIPT_NAME 127.0.0.1:12345 $i $NUM_PROC > /dev/null 2>&1 &
+CUDA_VISIBLE_DEVICES=$i python $SCRIPT_NAME 127.0.0.1:12345 $i $NUM_RUNS > /dev/null 2>&1 &
 done
-CUDA_VISIBLE_DEVICES=0 python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_PROC
+CUDA_VISIBLE_DEVICES=0 python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_RUNS
 wait
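
For context, a minimal sketch of how a script launched this way would typically consume the three positional arguments passed above (coordinator address, process index, total process count). The argument order and the mapping onto jax.distributed.initialize are assumptions for illustration, not the repository's actual script:

# Hypothetical launch-side counterpart of the loop above; argument order assumed.
import sys

import jax

coordinator_address = sys.argv[1]   # e.g. "127.0.0.1:12345"
process_id = int(sys.argv[2])       # this process's index, 0 .. NUM_RUNS-1
num_processes = int(sys.argv[3])    # total process count ($NUM_RUNS)

# Each process sees a single GPU through CUDA_VISIBLE_DEVICES, so every
# process contributes one local device to the multi-process JAX runtime.
jax.distributed.initialize(
    coordinator_address=coordinator_address,
    num_processes=num_processes,
    process_id=process_id,
)
print(f"process {jax.process_index()} of {jax.process_count()}")

The third argument should equal the number of launched processes, which is what the switch from $NUM_PROC to $NUM_RUNS in the loop above ensures.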
@@ -6,6 +6,7 @@ from functools import partial
 import jax
 import jax.numpy as jnp
+import jax.experimental.multihost_utils as jem
 from transformer_engine.jax.dense import grouped_dense as te_grouped_dense
 from transformer_engine.jax.quantize import (
@@ -13,7 +14,7 @@ from transformer_engine.jax.quantize import (
     ScalingMode,
 )
-from utils import assert_allclose
+from utils import assert_allclose, dtype_tols
 N_GROUP = 8
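
dtype_tols is a test utility from the local utils module; judging from how it is used below (splatted into jnp.allclose), it returns rtol/atol keyword arguments for a given dtype. A rough, hypothetical stand-in, with the scaling factor chosen arbitrarily for illustration:

# Hypothetical stand-in for utils.dtype_tols; the real helper may differ.
import jax.numpy as jnp

def dtype_tols_sketch(dtype, factor=8.0):
    # Derive loose comparison tolerances from the dtype's machine epsilon.
    eps = float(jnp.finfo(dtype).eps)
    return {"rtol": factor * eps, "atol": factor * eps}

e4m3_tols = dtype_tols_sketch(jnp.float8_e4m3fn)  # eps = 2**-3 for e4m3
e5m2_tols = dtype_tols_sketch(jnp.float8_e5m2)    # eps = 2**-2 for e5m2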
@@ -137,9 +138,16 @@ def test_grouped_gemm_fp8_allgather(data_shapes, kernel_fsdp_axis):
     out, dx, dw = test_func_jitted(x, w, w_amax)
     ref_out, ref_dx, ref_dw = ref_func_jitted(x, w_global)
-    assert_allclose(out, ref_out, dtype=jnp.float8_e4m3fn)
-    assert_allclose(dx, ref_dx, dtype=jnp.float8_e5m2)
-    assert_allclose(dw, ref_dw, dtype=jnp.float8_e5m2)
+    e4m3_tols = dtype_tols(jnp.float8_e4m3fn)
+    e5m2_tols = dtype_tols(jnp.float8_e5m2)
+    out, ref_out = jem.process_allgather((out, ref_out))
+    dx, ref_dx = jem.process_allgather((dx, ref_dx))
+    dw, ref_dw = jem.process_allgather((dw, ref_dw))
+    jnp.allclose(out, ref_out, **e4m3_tols)
+    jnp.allclose(dx, ref_dx, **e5m2_tols)
+    jnp.allclose(dw, ref_dw, **e5m2_tols)
 if __name__ == "__main__":
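
For reference, a minimal sketch of the multi-host comparison pattern the new test code relies on: jax.experimental.multihost_utils.process_allgather replicates the (possibly sharded) arrays onto every host so they can be compared locally. Wrapping the jnp.allclose result in an assert is this sketch's choice, and the tolerance values shown are placeholders:

# Sketch of the gather-then-compare pattern; assumes jax.distributed.initialize()
# has already been called on every participating process.
import jax.numpy as jnp
import jax.experimental.multihost_utils as jem

def check_close(result, reference, tols):
    # process_allgather accepts a pytree, so a 2-tuple of arrays works directly;
    # it returns fully replicated host copies of both arrays.
    result_full, reference_full = jem.process_allgather((result, reference))
    assert jnp.allclose(result_full, reference_full, **tols)

# Example usage with placeholder FP8-scale tolerances:
# check_close(out, ref_out, {"rtol": 0.125, "atol": 0.125})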