Unverified commit 3baaf3ff authored by guyueh1, committed by GitHub

Fix split_overlap_ag `aggregate=True` chunk offset calculation (#1768)



* Fix split_overlap_ag aggregate=True chunk offset calculation
Signed-off-by: Guyue Huang <guyueh@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add unit test for aggregate=True
Signed-off-by: Guyue Huang <guyueh@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix unit test
Signed-off-by: Guyue Huang <guyueh@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 730fd115
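
A note on what the diffs below correct: in aggregate mode, each iteration of split_overlap_ag processes two communication chunks at once, so the per-iteration size constants (input_b_chunk_size, output_chunk_size) describe a double chunk, while send_chunk_id still counts single chunks. Scaling a double-chunk size by a single-chunk index lands every GEMM at twice its intended offset. Below is a minimal stand-alone sketch of that arithmetic, with made-up sizes and the assumption (implied by the "/ 2" added in the fix) that send_chunk_id takes even values:

# Minimal sketch of the offset bug (hypothetical sizes; send_chunk_id is
# assumed to advance in single-chunk units, i.e. even values, which is
# what the "/ 2" added by this fix implies).
n_chunk, k, tp_size = 128, 64, 8

input_b_chunk_size = 2 * n_chunk * k       # a DOUBLE chunk: two chunks per step
buffer_elems = tp_size * n_chunk * k       # whole all-gathered input buffer

for step in range(tp_size // 2):
    send_chunk_id = 2 * step               # single-chunk index
    old_offset = input_b_chunk_size * send_chunk_id       # buggy: lands 2x too far
    new_offset = input_b_chunk_size * send_chunk_id // 2  # fixed
    in_bounds = old_offset + input_b_chunk_size <= buffer_elems
    print(f"step {step}: old={old_offset} (in bounds: {in_bounds}) new={new_offset}")

With these sizes the buggy offset skips every other double chunk and runs past the end of the buffer from step tp_size // 4 onward, which is exactly what the halved index prevents.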
@@ -51,7 +51,7 @@ os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
 torch._dynamo.reset()


-def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, quantization):
+def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, aggregate, quantization):
     test_path = TEST_ROOT / "run_gemm_with_overlap.py"
     test_cmd = LAUNCH_CMD + [
         str(test_path),
@@ -78,6 +78,8 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, quantization):
         if torch.cuda.get_device_properties(0).major != 9:
             pytest.skip("Atomic GEMM is requires device compute capability 9.x (Hopper).")
         test_cmd.append("--atomic")
+    if aggregate:
+        test_cmd.append("--aggregate")
     result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
     if (
@@ -135,12 +137,13 @@ def _run_layer_with_overlap(
 @pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
-def test_split_all_gather_overlaps(quantization):
+@pytest.mark.parametrize("aggregate", (False, True))
+def test_split_all_gather_overlaps(quantization, aggregate):
     """
     Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or
     te.cpp_extensions.fp8_gemm.
     """
-    _run_gemm_with_overlap("AG", False, True, False, quantization)
+    _run_gemm_with_overlap("AG", False, True, False, aggregate, quantization)


 @pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
@@ -150,7 +153,7 @@ def test_split_reduce_scatter_overlaps(quantization, p2p):
     Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or
     te.cpp_extensions.fp8_gemm.
     """
-    _run_gemm_with_overlap("RS", False, p2p, False, quantization)
+    _run_gemm_with_overlap("RS", False, p2p, False, False, quantization)


 @pytest.mark.parametrize(
@@ -183,10 +186,10 @@ def test_bulk_overlaps(comm_type, quantization, connections):
             " 9.0 (HOPPER ARCH)."
         )
         os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
-        _run_gemm_with_overlap(comm_type, True, False, False, quantization)
+        _run_gemm_with_overlap(comm_type, True, False, False, False, quantization)
         os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
     else:
-        _run_gemm_with_overlap(comm_type, True, False, False, quantization)
+        _run_gemm_with_overlap(comm_type, True, False, False, False, quantization)
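
For reference, the new aggregate parametrization doubles the all-gather test matrix. A quick stand-alone illustration (not part of the PR) of the combinations pytest now generates:

# Stand-alone illustration of the test matrix produced by the two
# parametrize decorators above; pytest generates one test per combination.
from itertools import product

for quantization, aggregate in product(("none", "fp8", "mxfp8"), (False, True)):
    # Each combination calls:
    # _run_gemm_with_overlap("AG", False, True, False, aggregate, quantization)
    print(f"test_split_all_gather_overlaps[{quantization}-{aggregate}]")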
@@ -822,7 +822,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
     // Chunk dims
     std::vector<size_t> input_b_chunk_shape =
         (transb ? std::vector<size_t>{k, 2 * n_chunk} : std::vector<size_t>{2 * n_chunk, k});
-    std::vector<size_t> output_chunk_shape = {2 * n_chunk, k};
+    std::vector<size_t> output_chunk_shape = {2 * n_chunk, m};
     size_t input_b_chunk_size = 2 * n_chunk * k;
     size_t output_chunk_size = 2 * n_chunk * m;
@@ -853,12 +853,12 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
       // GEMM
       auto input_b_chunk =
-          get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
+          get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id / 2, input_b_chunk_shape);
       auto output_chunk =
-          get_tensor_chunk(D, output_chunk_size * send_chunk_id, output_chunk_shape);
-      auto aux_chunk =
-          (do_gelu)
-              ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk * 2, k})
-              : TensorWrapper(nullptr, std::vector<size_t>{0}, pre_gelu_out.dtype());
+          get_tensor_chunk(D, output_chunk_size * send_chunk_id / 2, output_chunk_shape);
+      auto aux_chunk = (do_gelu)
+                           ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id / 2,
+                                              {n_chunk * 2, k})
+                           : TensorWrapper(nullptr, std::vector<size_t>{0}, pre_gelu_out.dtype());
       auto workspace_chunk = get_tensor_chunk(
           workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
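
Two separate corrections appear in this hunk: the output chunk shape used the reduction dimension k where the GEMM output dimension m belongs, and the three chunk offsets scaled a double-chunk size by a single-chunk index. A small sanity check, under the same hypothetical sizes and even-send_chunk_id assumption as the sketch above, that the halved offsets tile the output exactly once:

# Sanity check (hypothetical sizes): with the "/ 2" fix, the double-chunk
# offsets into D cover every element exactly once across tp_size // 2 steps.
n_chunk, m, tp_size = 128, 256, 8
output_chunk_size = 2 * n_chunk * m        # elements per double chunk
total = tp_size * n_chunk * m              # elements in all of D

covered = []
for step in range(tp_size // 2):
    send_chunk_id = 2 * step               # assumed even, as the fix implies
    offset = output_chunk_size * send_chunk_id // 2
    covered.append((offset, offset + output_chunk_size))

covered.sort()
assert covered[0][0] == 0 and covered[-1][1] == total
assert all(a[1] == b[0] for a, b in zip(covered, covered[1:]))  # no gaps or overlaps
print("fixed offsets tile D exactly:", covered)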