Unverified Commit 3baaf3ff authored by guyueh1, Committed by GitHub
Browse files

Fix split_overlap_ag `aggregate=True` chunk offset calculation (#1768)



* Fix split_overlap_rs aggregate=True chunk offset calculation
Signed-off-by: Guyue Huang <guyueh@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add unit test for aggregate=True
Signed-off-by: Guyue Huang <guyueh@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix unit test
Signed-off-by: Guyue Huang <guyueh@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 730fd115
......@@ -51,7 +51,7 @@ os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
torch._dynamo.reset()
def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, quantization):
def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, aggregate, quantization):
test_path = TEST_ROOT / "run_gemm_with_overlap.py"
test_cmd = LAUNCH_CMD + [
str(test_path),
......@@ -78,6 +78,8 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, quantization):
if torch.cuda.get_device_properties(0).major != 9:
pytest.skip("Atomic GEMM is requires device compute capability 9.x (Hopper).")
test_cmd.append("--atomic")
if aggregate:
test_cmd.append("--aggregate")
result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
if (
......@@ -135,12 +137,13 @@ def _run_layer_with_overlap(
@pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
def test_split_all_gather_overlaps(quantization):
@pytest.mark.parametrize("aggregate", (False, True))
def test_split_all_gather_overlaps(quantization, aggregate):
"""
Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or
te.cpp_extensions.fp8_gemm.
"""
_run_gemm_with_overlap("AG", False, True, False, quantization)
_run_gemm_with_overlap("AG", False, True, False, aggregate, quantization)
@pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
......@@ -150,7 +153,7 @@ def test_split_reduce_scatter_overlaps(quantization, p2p):
Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or
te.cpp_extensions.fp8_gemm.
"""
_run_gemm_with_overlap("RS", False, p2p, False, quantization)
_run_gemm_with_overlap("RS", False, p2p, False, False, quantization)
@pytest.mark.parametrize(
......@@ -183,10 +186,10 @@ def test_bulk_overlaps(comm_type, quantization, connections):
" 9.0 (HOPPER ARCH)."
)
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
_run_gemm_with_overlap(comm_type, True, False, False, quantization)
_run_gemm_with_overlap(comm_type, True, False, False, False, quantization)
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
else:
_run_gemm_with_overlap(comm_type, True, False, False, quantization)
_run_gemm_with_overlap(comm_type, True, False, False, False, quantization)
@pytest.mark.parametrize(
......
......@@ -822,7 +822,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
// Chunk dims
std::vector<size_t> input_b_chunk_shape =
(transb ? std::vector<size_t>{k, 2 * n_chunk} : std::vector<size_t>{2 * n_chunk, k});
std::vector<size_t> output_chunk_shape = {2 * n_chunk, k};
std::vector<size_t> output_chunk_shape = {2 * n_chunk, m};
size_t input_b_chunk_size = 2 * n_chunk * k;
size_t output_chunk_size = 2 * n_chunk * m;
......@@ -853,13 +853,13 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
// GEMM
auto input_b_chunk =
get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id / 2, input_b_chunk_shape);
auto output_chunk =
get_tensor_chunk(D, output_chunk_size * send_chunk_id, output_chunk_shape);
auto aux_chunk =
(do_gelu)
? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk * 2, k})
: TensorWrapper(nullptr, std::vector<size_t>{0}, pre_gelu_out.dtype());
get_tensor_chunk(D, output_chunk_size * send_chunk_id / 2, output_chunk_shape);
auto aux_chunk = (do_gelu)
? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id / 2,
{n_chunk * 2, k})
: TensorWrapper(nullptr, std::vector<size_t>{0}, pre_gelu_out.dtype());
auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment