Commit 29271c40 authored by tabuchixiangcai3

[DCU] Fix MPI root support, enable int8 simulation, and let batched_linear handle a non-existent main_grad
Signed-off-by: Tangao <2205747538@qq.com>
parent c9eab7e7
@@ -35,7 +35,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_
 python3 -m pytest -v -s --log-cli-level=INFO --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
+NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
 # debug tests
@@ -46,7 +46,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_
 : ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml}
 : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py"
+NVTE_INT8_SIM_FP8=1 pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py"
 # standard numerics tests with initialized debug
 NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py"
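The two test-script hunks above only prepend NVTE_INT8_SIM_FP8=1 to the affected pytest invocations. As a minimal sketch of how such an environment switch is typically consumed on the Python side (the exact lookup inside Transformer Engine is an assumption here, and the helper name is hypothetical):

```python
# Hypothetical reader for the NVTE_INT8_SIM_FP8 switch exported by the
# test script above; the real parsing inside Transformer Engine may differ.
import os

def int8_sim_fp8_enabled() -> bool:
    # "NVTE_INT8_SIM_FP8=1 pytest ..." sets the variable for the test run.
    return os.getenv("NVTE_INT8_SIM_FP8", "0") == "1"
```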
@@ -432,7 +432,7 @@ def test_fuser_ops_with_userbuffers(
     command = []
     if tex.ubuf_built_with_mpi():
         python_exe = pathlib.Path(sys.executable).resolve()
-        command.extend(("mpirun", "-np", str(world_size), "--oversubscribe", "--quiet", python_exe))
+        command.extend(("mpirun", "-np", str(world_size), "--allow-run-as-root", "--oversubscribe", "--quiet", python_exe))
     else:
         command.extend(("torchrun", f"--nproc_per_node={world_size}"))
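Context for the hunk above: OpenMPI's mpirun refuses to launch when invoked as the root user unless --allow-run-as-root is passed, which breaks CI jobs that run inside containers as root. The commit adds the flag unconditionally; a sketch of a more conservative variant that adds it only when the process is actually root (mpirun_command is a hypothetical helper, not part of the test):

```python
# Sketch: build the mpirun launcher, adding --allow-run-as-root only when
# the effective user is root (e.g. inside a CI container). The diff above
# simply adds the flag unconditionally.
import os
import sys

def mpirun_command(world_size: int) -> list[str]:
    cmd = ["mpirun", "-np", str(world_size), "--oversubscribe", "--quiet"]
    if os.geteuid() == 0:
        cmd.insert(1, "--allow-run-as-root")
    cmd.append(sys.executable)
    return cmd
```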
@@ -153,7 +153,13 @@ class _BatchLinear(torch.autograd.Function):
         if cpu_offloading:
             if fuse_wgrad_accumulation:
                 for w in weights:
-                    w.main_grad.weight_offloading = True
+                    if getattr(w, "main_grad", None) is not None:
+                        w.main_grad.weight_offloading = True
+                    else:
+                        # Optional: log a warning if fuse requested but buffer missing
+                        # logger = logging.getLogger("BatchLinear")
+                        # logger.debug("fuse_wgrad_accumulation=True but weight.main_grad is missing; skipping weight_offloading for this weight.")
+                        pass
             for w in weights:
                 w.weight_offloading = True
             for t in saved_inputmats:
@@ -162,7 +168,8 @@ class _BatchLinear(torch.autograd.Function):
         for i in range(num_gemms):
             weights[i].offloading_activation = False
-            weights[i].main_grad.offloading_activation = False
+            if getattr(weights[i], "main_grad", None) is not None:
+                weights[i].main_grad.offloading_activation = False
             if weights_fp8[i] is not None:
                 weights_fp8[i].offloading_activation = False
@@ -553,8 +560,17 @@ class BatchedLinear(TransformerEngineBaseModule):
         if self.primary_weights_in_fp8:
             self.init_fp8_metadata(num_gemms=self.num_gemms)
+        # Ensure main_grad buffers exist when fuse_wgrad_accumulation is enabled.
+        # Skip allocation under meta device (deferred init).
         self.reset_parameters(defer_init=(device == "meta"))
+        if self.fuse_wgrad_accumulation and device != "meta":
+            for i in range(int(self.num_gemms)):
+                w = getattr(self, f"weight{i}")
+                if getattr(w, "main_grad", None) is None:
+                    # use float32 buffer for main_grad (tests use float32)
+                    w.main_grad = torch.empty_like(w, dtype=torch.float32, device=w.device)
+                    w.main_grad.zero_()
 
         # For RPL, bias has to be added after TP collectives
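The two getattr guards and the constructor-time allocation above follow one pattern: treat weight.main_grad as optional, and create a zeroed float32 buffer only when fused wgrad accumulation needs one and the optimizer has not attached it. A self-contained sketch of that pattern with hypothetical names (ensure_main_grad is not part of the module):

```python
# Standalone illustration of the main_grad guard/allocation pattern from
# the hunks above; ensure_main_grad is a hypothetical helper, not TE API.
import torch

def ensure_main_grad(weight: torch.nn.Parameter) -> None:
    # Optimizers with fused wgrad accumulation usually attach .main_grad
    # themselves; allocate a zeroed float32 buffer only when it is missing.
    if getattr(weight, "main_grad", None) is None:
        weight.main_grad = torch.zeros_like(weight, dtype=torch.float32)

w = torch.nn.Parameter(torch.randn(4, 4, dtype=torch.float16))
ensure_main_grad(w)
assert w.main_grad.dtype == torch.float32
```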