Commit 8fb50d09 authored by yuguo's avatar yuguo
Browse files

[DCU] tmp fix

parent b71ea424
......@@ -312,7 +312,7 @@ def _main(opts):
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
opts.comm_type,
num_max_streams=2 if IS_HIP_EXTENSION else 3,
num_max_streams=1 if IS_HIP_EXTENSION else 3,
set_sm_margin=opts.comm_type == tex.CommOverlapType.RS or opts.atomic,
atomic_gemm=opts.atomic,
aggregate=opts.aggregate,
......@@ -401,7 +401,7 @@ def _main(opts):
)
# Allocate cuBLAS workspace
workspace_size = 2 * get_cublas_workspace_size_bytes()
workspace_size = 1 * get_cublas_workspace_size_bytes()
workspace = torch.empty(workspace_size, dtype=torch.uint8, device="cuda")
# Gather global tensors and calculate reference result (need these first for Fp8 scales)
......@@ -773,17 +773,17 @@ def _main(opts):
"NUMERICAL CHECK FAILED: "
+ f"Outputs not close enough at index {m.item()} "
+ f"with {test_out.flatten()[m].item()} vs {ref_out.flatten()[m].item()} | "
+ f"rel. error = {rel_err} (tol = {rtol}) | "
+ f"abs. error = {abs_err} (tol = {atol})"
+ f"rel. deviation = {rel_err} (tol = {rtol}) | "
+ f"abs. deviation = {abs_err} (tol = {atol})"
)
else:
numerics_info = "NUMERICAL CHECK PASSED: "
if rel_err <= rtol:
numerics_info += f"rel. error = {rel_err} (tol = {rtol})" + (
numerics_info += f"rel. deviation = {rel_err} (tol = {rtol})" + (
" | " if abs_err < atol else ""
)
if abs_err <= atol:
numerics_info += f"abs. error = {abs_err} (tol = {atol})"
numerics_info += f"abs. deviation = {abs_err} (tol = {atol})"
dist_print(
numerics_info, src=0, section=True, info=True, error=numerics_failed, group=tp_group
......
......@@ -3,6 +3,8 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# UB_SKIPMC=1 mpirun -np 8 --allow-run-as-root --oversubscribe --quiet python3 /home/TransformerEngine/tests/pytorch/distributed/run_layer_with_overlap.py --seed=42 --seq-length=4096 --batch-size=2 --num-heads=96 --head-dim=128 --layer-type LayerNormLinear --linear-parallel-mode column --num-layers 1 --overlap-rs-dgrad
# NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=1 UB_SKIPMC=1 mpirun -np 8 --allow-run-as-root --oversubscribe --quiet python3 /home/TransformerEngine/tests/pytorch/distributed/run_layer_with_overlap.py --seed=42 --seq-length=4096 --batch-size=2 --num-heads=96 --head-dim=128 --layer-type MultiheadAttention --num-layers 1 --overlap-rs-dgrad
import os
import sys
......@@ -266,17 +268,17 @@ def _compare_tensors(name, test, ref, rtol, atol):
"NUMERICAL CHECK FAILED: "
+ f"{name} not close enough at index {m.item()} "
+ f"with {test.flatten()[m].item()} vs {ref.flatten()[m].item()} | "
+ f"rel. error = {rel_err} (tol = {rtol}) | "
+ f"abs. error = {abs_err} (tol = {atol})"
+ f"rel. deviation = {rel_err} (tol = {rtol}) | "
+ f"abs. deviation = {abs_err} (tol = {atol})"
)
else:
numerics_info = f"NUMERICAL CHECK PASSED: {name} | "
if rel_err <= rtol:
numerics_info += f"rel. error = {rel_err} (tol = {rtol})" + (
numerics_info += f"rel. deviation = {rel_err} (tol = {rtol})" + (
" | " if abs_err <= atol else "."
)
if abs_err <= atol:
numerics_info += f" abs. error = {abs_err} (tol = {atol})"
numerics_info += f" abs. deviation = {abs_err} (tol = {atol})"
return numerics_failed, numerics_info
......
......@@ -47,7 +47,7 @@ _multi_stream_cublas_workspace = []
_multi_stream_cublas_batchgemm_workspace = []
_cublas_workspace = None
_ub_communicators = None
_NUM_MAX_UB_STREAMS = 2 if IS_HIP_EXTENSION else 3
_NUM_MAX_UB_STREAMS = 1 if IS_HIP_EXTENSION else 3
_MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
layers_atomic_ring_exchange = []
......@@ -357,7 +357,7 @@ def initialize_ub(
helper, # Helper for torch.distributed callbacks during bootstrapping
tp_size, # Tensor-parallel group size (may be different than local_size)
num_splits=num_splits,
num_max_streams=_NUM_MAX_UB_STREAMS - 1 if IS_HIP_EXTENSION else _NUM_MAX_UB_STREAMS,
num_max_streams=_NUM_MAX_UB_STREAMS,
comm_cga_size=cga_size,
num_comm_sm=num_sm,
set_sm_margin=set_sm_margin,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment