Unverified Commit 483d9594 authored by jomitchellnv, committed by GitHub

Adds context parallelism utilities: moving cp shards to diff ranks and pad sequence to divisibility factor (#2129)

* test - adds unit tests for the cp utilities and the utilities themselves
Signed-off-by: Jonathan Mitchell <jomitchell@login-eos02.eos.clusters.nvidia.com>

* assert line change
Signed-off-by: Jonathan Mitchell <jomitchell@login-eos02.eos.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Jonathan Mitchell <jomitchell@login-eos02.eos.clusters.nvidia.com>
Co-authored-by: Jonathan Mitchell <jomitchell@login-eos02.eos.clusters.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sudhakar Singh <sudhakars@nvidia.com>
parent 4903f947
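A minimal sketch of how the added utilities are meant to compose (illustrative values; `cp_group` is assumed to be an already-initialized context-parallel process group, it is not part of this diff):

    import torch
    from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
        get_batch_on_this_cp_rank,
        pad_thd_sequences_for_cp,
        generate_positional_ids_for_cp,
    )

    # Two packed (THD-format) sequences of lengths 2 and 3.
    input_ids = torch.tensor([1, 1, 2, 2, 2])
    labels = torch.tensor([10, 11, 20, 21, 22])
    cu_seqlens = torch.tensor([0, 2, 5])

    # Pad every sequence to a multiple of 2 * cp_size so it splits evenly across ranks.
    divisibility_factor = 4  # 2 * cp_size for cp_size = 2
    input_ids_p, labels_p, cu_seqlens_p = pad_thd_sequences_for_cp(
        input_ids, labels, cu_seqlens, divisibility_factor
    )
    position_ids_p = generate_positional_ids_for_cp(cu_seqlens, divisibility_factor)

    # Each rank keeps only its own shard of every sequence.
    ids, lbls, pos = get_batch_on_this_cp_rank(
        cu_seqlens_p, input_ids_p, labels_p, position_ids_p,
        cp_group=cp_group,  # assumption: an existing context-parallel process group
    )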
@@ -35,6 +35,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unit tests for context parallel utils."""
import torch
import unittest
from typing import Tuple
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
get_batch_on_this_cp_rank,
pad_thd_sequences_for_cp,
generate_positional_ids_for_cp,
)
class TestSequencePadding(unittest.TestCase):
def test_padding_with_custom_padding_values_sequences_shorter_than_divisibility_factor(self):
"""Test with custom padding values for all tensors."""
# Setup
input_ids = torch.tensor([1, 1, 1, 2, 2, 3, 3, 3, 3])
cu_seqlens = torch.tensor([0, 3, 5, 9])
labels = torch.tensor([-100, -100, -100, -100, -100, -100, -100, 13, -100])
positional_ids = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3])
divisibility_factor = 8
pid = 777
label_pad = -200
input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp(
input_ids.unsqueeze(0),
labels.unsqueeze(0),
cu_seqlens,
divisibility_factor,
padding_token_id=pid,
padding_label_id=label_pad,
)
positional_ids_padded = generate_positional_ids_for_cp(
cu_seqlens,
divisibility_factor,
)
# Expected padded layout (p = pad): [a a a p p p p p | b b p p p p p p | c c c c p p p p]
print("input_ids_padded: ", input_ids_padded)
print("labels_padded: ", labels_padded)
print("positional_ids_padded: ", positional_ids_padded)
print("cu_seqlens_padded: ", cu_seqlens_padded)
expected_input_ids = torch.tensor(
    [1, 1, 1, pid, pid, pid, pid, pid,  # seq 1: 3 tokens + 5 pads
     2, 2, pid, pid, pid, pid, pid, pid,  # seq 2: 2 tokens + 6 pads
     3, 3, 3, 3, pid, pid, pid, pid]  # seq 3: 4 tokens + 4 pads
)
expected_cu_seqlens_padded = torch.tensor([0, 8, 16, 24])
expected_labels_padded = torch.tensor(
    [-100, -100, -100, label_pad, label_pad, label_pad, label_pad, label_pad,  # seq 1
     -100, -100, label_pad, label_pad, label_pad, label_pad, label_pad, label_pad,  # seq 2
     -100, -100, 13, -100, label_pad, label_pad, label_pad, label_pad]  # seq 3
)
expected_positional_ids = torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]
)
assert torch.equal(input_ids_padded, expected_input_ids)
assert torch.equal(labels_padded, expected_labels_padded)
assert torch.equal(positional_ids_padded, expected_positional_ids)
assert torch.equal(cu_seqlens_padded, expected_cu_seqlens_padded)
def test_mixed_sequence_lengths_with_divisibility_factor(self):
"""Test with sequences both shorter and longer than divisibility factor."""
# Setup - divisibility factor 6
# Seq 1: length 2 (shorter than 6, needs 4 padding)
# Seq 2: length 7 (longer than 6, needs 5 padding to reach 12)
# Seq 3: length 4 (shorter than 6, needs 2 padding)
# Seq 4: length 10 (longer than 6, needs 2 padding to reach 12)
input_ids = torch.tensor(
[1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
)
labels = torch.tensor(
    [10, 11,  # seq 1
     20, 21, 22, 23, 24, 25, 26,  # seq 2
     30, 31, 32, 33,  # seq 3
     40, 41, 42, 43, 44, 45, 46, 47, 48, 49]  # seq 4
)
positional_ids = torch.tensor(
[0, 1, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
)
cu_seqlens = torch.tensor([0, 2, 9, 13, 23])
divisibility_factor = 6
pid = 999
label_pad = -300
# Execute
input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp(
input_ids.unsqueeze(0),
labels.unsqueeze(0),
cu_seqlens,
divisibility_factor,
padding_token_id=pid,
padding_label_id=label_pad,
)
positional_ids_padded = generate_positional_ids_for_cp(
cu_seqlens,
divisibility_factor,
)
# Assert
# Seq 1: [1,1] + 4 pads = 6 total
# Seq 2: [2,2,2,2,2,2,2] + 5 pads = 12 total
# Seq 3: [3,3,3,3] + 2 pads = 6 total
# Seq 4: [4,4,4,4,4,4,4,4,4,4] + 2 pads = 12 total
expected_input_ids = torch.tensor(
    [1, 1, pid, pid, pid, pid,  # Seq 1: 2 + 4 padding
     2, 2, 2, 2, 2, 2, 2, pid, pid, pid, pid, pid,  # Seq 2: 7 + 5 padding
     3, 3, 3, 3, pid, pid,  # Seq 3: 4 + 2 padding
     4, 4, 4, 4, 4, 4, 4, 4, 4, 4, pid, pid]  # Seq 4: 10 + 2 padding
)
expected_labels = torch.tensor(
    [10, 11, label_pad, label_pad, label_pad, label_pad,  # Seq 1
     20, 21, 22, 23, 24, 25, 26, label_pad, label_pad, label_pad, label_pad, label_pad,  # Seq 2
     30, 31, 32, 33, label_pad, label_pad,  # Seq 3
     40, 41, 42, 43, 44, 45, 46, 47, 48, 49, label_pad, label_pad]  # Seq 4
)
expected_positional_ids = torch.tensor(
    [0, 1, 2, 3, 4, 5,  # Seq 1: positions continue through padding
     0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,  # Seq 2: positions continue
     0, 1, 2, 3, 4, 5,  # Seq 3: positions continue
     0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]  # Seq 4: positions continue
)
expected_cu_seqlens_padded = torch.tensor([0, 6, 18, 24, 36])
self.assertTrue(torch.equal(input_ids_padded, expected_input_ids))
self.assertTrue(torch.equal(labels_padded, expected_labels))
self.assertTrue(torch.equal(positional_ids_padded, expected_positional_ids))
self.assertTrue(torch.equal(cu_seqlens_padded, expected_cu_seqlens_padded))
def test_sequences_longer_than_divisibility_factor(self):
"""Test with all sequences longer than the divisibility factor."""
# Setup - divisibility factor 4, all sequences longer than 4
# Seq 1: length 7 (needs 1 padding to reach 8)
# Seq 2: length 11 (needs 1 padding to reach 12)
# Seq 3: length 5 (needs 3 padding to reach 8)
input_ids = torch.tensor(
    [1, 1, 1, 1, 1, 1, 1,  # 7 tokens
     2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,  # 11 tokens
     3, 3, 3, 3, 3]  # 5 tokens
)
labels = torch.tensor(
    [100, 101, 102, 103, 104, 105, 106,  # seq 1
     200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,  # seq 2
     300, 301, 302, 303, 304]  # seq 3
)
positional_ids = torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 1, 2, 3, 4]
)
cu_seqlens = torch.tensor([0, 7, 18, 23])
divisibility_factor = 4
pid = 888
label_pad = -400
# Execute
input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp(
input_ids.unsqueeze(0),
labels.unsqueeze(0),
cu_seqlens,
divisibility_factor,
padding_token_id=pid,
padding_label_id=label_pad,
)
positional_ids_padded = generate_positional_ids_for_cp(
cu_seqlens,
divisibility_factor,
)
# Assert
# Seq 1: 7 + 1 pad = 8 (divisible by 4)
# Seq 2: 11 + 1 pad = 12 (divisible by 4)
# Seq 3: 5 + 3 pads = 8 (divisible by 4)
expected_input_ids = torch.tensor(
    [1, 1, 1, 1, 1, 1, 1, pid,  # Seq 1: 7 + 1 pad
     2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, pid,  # Seq 2: 11 + 1 pad
     3, 3, 3, 3, 3, pid, pid, pid]  # Seq 3: 5 + 3 pads
)
expected_labels = torch.tensor(
    [100, 101, 102, 103, 104, 105, 106, label_pad,  # Seq 1
     200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, label_pad,  # Seq 2
     300, 301, 302, 303, 304, label_pad, label_pad, label_pad]  # Seq 3
)
expected_positional_ids = torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7]
)
expected_cu_seqlens_padded = torch.tensor([0, 8, 20, 28])
self.assertTrue(torch.equal(input_ids_padded, expected_input_ids))
self.assertTrue(torch.equal(labels_padded, expected_labels))
self.assertTrue(torch.equal(positional_ids_padded, expected_positional_ids))
self.assertTrue(torch.equal(cu_seqlens_padded, expected_cu_seqlens_padded))
class TestContextParallelUtils(unittest.TestCase):
"""Test utilities for context parallel functionality."""
def setUp(self):
"""Set up mock distributed environment."""
# Mock torch.distributed functions
self.original_get_world_size = torch.distributed.get_world_size
self.original_get_rank = torch.distributed.get_rank
def tearDown(self):
"""Restore original torch.distributed functions."""
torch.distributed.get_world_size = self.original_get_world_size
torch.distributed.get_rank = self.original_get_rank
def _mock_distributed_env(self, cp_size, cp_rank):
"""Mock the distributed environment for testing."""
def mock_get_world_size(group=None):
return cp_size
def mock_get_rank(group=None):
return cp_rank
torch.distributed.get_world_size = mock_get_world_size
torch.distributed.get_rank = mock_get_rank
def test_cp_rank_slicing_simple_case(self):
"""Test CP rank slicing with a simple 2-rank, single sequence case."""
# Setup: Single sequence of length 8, CP size = 2
# Each sequence gets divided into 2*cp_size = 4 slices of size 2 each
# Rank 0 gets slices [0,1] and [6,7] (first and last)
# Rank 1 gets slices [2,3] and [4,5] (second and second-to-last)
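# In general, with 2*cp_size slices per sequence, rank r keeps slices r and (2*cp_size - 1 - r).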
input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]) # Shape: (1, 8) - batch first
labels = torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80]])
position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) # Shape: (8,) - 1D as expected
cu_seqlens = torch.tensor([0, 8])
# Test rank 0
self._mock_distributed_env(cp_size=2, cp_rank=0)
input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank(
cu_seqlens, input_ids, labels, position_ids
)
# Rank 0 should get indices [0,1] and [6,7]
expected_input_ids_r0 = torch.tensor([[1, 2, 7, 8]])
expected_labels_r0 = torch.tensor([[10, 20, 70, 80]])
expected_pos_ids_r0 = torch.tensor([0, 1, 6, 7])
self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0))
self.assertTrue(torch.equal(labels_r0, expected_labels_r0))
self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0))
# Test rank 1
self._mock_distributed_env(cp_size=2, cp_rank=1)
input_ids_r1, labels_r1, pos_ids_r1 = get_batch_on_this_cp_rank(
cu_seqlens, input_ids, labels, position_ids
)
# Rank 1 should get indices [2,3] and [4,5]
expected_input_ids_r1 = torch.tensor([[3, 4, 5, 6]])
expected_labels_r1 = torch.tensor([[30, 40, 50, 60]])
expected_pos_ids_r1 = torch.tensor([2, 3, 4, 5])
self.assertTrue(torch.equal(input_ids_r1, expected_input_ids_r1))
self.assertTrue(torch.equal(labels_r1, expected_labels_r1))
self.assertTrue(torch.equal(pos_ids_r1, expected_pos_ids_r1))
def test_cp_rank_slicing_multiple_sequences(self):
"""Test CP rank slicing with multiple sequences."""
# Setup: Two sequences of length 8 each, CP size = 2
# Total sequence length = 16, cu_seqlens = [0, 8, 16]
input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 18]])
labels = torch.tensor(
[[10, 20, 30, 40, 50, 60, 70, 80, 110, 120, 130, 140, 150, 160, 170, 180]]
)
position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7])
cu_seqlens = torch.tensor([0, 8, 16])
# Test rank 0
self._mock_distributed_env(cp_size=2, cp_rank=0)
input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank(
cu_seqlens, input_ids, labels, position_ids
)
# For each sequence, rank 0 gets first and last slices
# Seq 1: indices [0,1] and [6,7] -> values [1,2] and [7,8]
# Seq 2: indices [8,9] and [14,15] -> values [11,12] and [17,18]
expected_input_ids_r0 = torch.tensor([[1, 2, 7, 8, 11, 12, 17, 18]])
expected_labels_r0 = torch.tensor([[10, 20, 70, 80, 110, 120, 170, 180]])
expected_pos_ids_r0 = torch.tensor([0, 1, 6, 7, 0, 1, 6, 7])
self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0))
self.assertTrue(torch.equal(labels_r0, expected_labels_r0))
self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0))
def test_cp_rank_slicing_with_cp_size_1(self):
"""Test that CP size = 1 returns original tensors unchanged."""
input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]])
labels = torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80]])
position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
cu_seqlens = torch.tensor([0, 8])
self._mock_distributed_env(cp_size=1, cp_rank=0)
input_ids_result, labels_result, pos_ids_result = get_batch_on_this_cp_rank(
cu_seqlens, input_ids, labels, position_ids
)
# With CP size = 1, should return original tensors
self.assertTrue(torch.equal(input_ids_result, input_ids))
self.assertTrue(torch.equal(labels_result, labels))
self.assertTrue(torch.equal(pos_ids_result, position_ids))
def test_cp_rank_slicing_sequence_dim_detection(self):
"""Test that the function correctly detects sequence dimension."""
# Test with sequence dimension = 0 (sequence_length, batch_size)
input_ids = torch.tensor(
[[1, 10], [2, 20], [3, 30], [4, 40], [5, 50], [6, 60], [7, 70], [8, 80]]
) # (8, 2)
labels = torch.tensor(
[[1, 10], [2, 20], [3, 30], [4, 40], [5, 50], [6, 60], [7, 70], [8, 80]]
)
position_ids = torch.tensor(
[[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]
)
cu_seqlens = torch.tensor([0, 8])
self._mock_distributed_env(cp_size=2, cp_rank=0)
input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank(
cu_seqlens, input_ids, labels, position_ids
)
# Should get indices [0,1] and [6,7] along dimension 0
expected_input_ids_r0 = torch.tensor([[1, 10], [2, 20], [7, 70], [8, 80]])
expected_labels_r0 = torch.tensor([[1, 10], [2, 20], [7, 70], [8, 80]])
expected_pos_ids_r0 = torch.tensor([[0, 0], [1, 1], [6, 6], [7, 7]])
self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0))
self.assertTrue(torch.equal(labels_r0, expected_labels_r0))
self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0))
def test_cp_rank_slicing_mixed_dimensions(self):
"""Test CP rank slicing where input_ids/labels are 1D but position_ids has batch dimension."""
# Setup: Single sequence of length 8, CP size = 2
# This tests the opposite case from the simple test:
# - input_ids and labels: 1D (no batch dimension)
# - position_ids: 2D (has batch dimension)
input_ids = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) # Shape: (8,) - 1D
labels = torch.tensor([10, 20, 30, 40, 50, 60, 70, 80]) # Shape: (8,) - 1D
position_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]) # Shape: (1, 8) - 2D with batch
cu_seqlens = torch.tensor([0, 8])
# Test rank 0
self._mock_distributed_env(cp_size=2, cp_rank=0)
input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank(
cu_seqlens, input_ids, labels, position_ids
)
# Rank 0 should get indices [0,1] and [6,7]
expected_input_ids_r0 = torch.tensor([1, 2, 7, 8]) # 1D result
expected_labels_r0 = torch.tensor([10, 20, 70, 80]) # 1D result
expected_pos_ids_r0 = torch.tensor([[0, 1, 6, 7]]) # 2D result (preserves batch dim)
self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0))
self.assertTrue(torch.equal(labels_r0, expected_labels_r0))
self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0))
# Test rank 1
self._mock_distributed_env(cp_size=2, cp_rank=1)
input_ids_r1, labels_r1, pos_ids_r1 = get_batch_on_this_cp_rank(
cu_seqlens, input_ids, labels, position_ids
)
# Rank 1 should get indices [2,3] and [4,5]
expected_input_ids_r1 = torch.tensor([3, 4, 5, 6]) # 1D result
expected_labels_r1 = torch.tensor([30, 40, 50, 60]) # 1D result
expected_pos_ids_r1 = torch.tensor([[2, 3, 4, 5]]) # 2D result (preserves batch dim)
self.assertTrue(torch.equal(input_ids_r1, expected_input_ids_r1))
self.assertTrue(torch.equal(labels_r1, expected_labels_r1))
self.assertTrue(torch.equal(pos_ids_r1, expected_pos_ids_r1))
def test_integration_with_padding_and_cp_slicing(self):
"""Integration test: pad sequences then slice for CP ranks."""
# Start with unpadded sequences
input_ids = torch.tensor([1, 1, 2, 2, 2]) # Two sequences: [1,1] and [2,2,2]
labels = torch.tensor([10, 11, 20, 21, 22])
positional_ids = torch.tensor([0, 1, 0, 1, 2])
cu_seqlens = torch.tensor([0, 2, 5])
divisibility_factor = 4 # Will pad to lengths 4 and 4
# First, pad sequences
input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp(
input_ids.unsqueeze(0),
labels.unsqueeze(0),
cu_seqlens,
divisibility_factor,
padding_token_id=0,
padding_label_id=-100,
)
positional_ids_padded = generate_positional_ids_for_cp(
cu_seqlens,
divisibility_factor,
)
# Expected after padding: [1,1,0,0,2,2,2,0] with cu_seqlens [0,4,8]
expected_padded = torch.tensor([1, 1, 0, 0, 2, 2, 2, 0])
self.assertTrue(torch.equal(input_ids_padded, expected_padded))
# Now test CP slicing with cp_size=2
# Test rank 0
self._mock_distributed_env(cp_size=2, cp_rank=0)
input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank(
cu_seqlens_padded,
input_ids_padded.unsqueeze(0),
labels_padded.unsqueeze(0),
positional_ids_padded,
)
# Each sequence of length 4 gets divided into 4 slices of size 1
# Rank 0 gets slices [0] and [3] from each sequence
# Seq 1: indices [0] and [3] -> values [1] and [0]
# Seq 2: indices [4] and [7] -> values [2] and [0]
expected_input_ids_r0 = torch.tensor([[1, 0, 2, 0]])
self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0))
if __name__ == "__main__":
unittest.main()
@@ -4,7 +4,7 @@
"""Context Parallelism."""
import os
from typing import List, Union
from typing import List, Union, Tuple
import torch
import transformer_engine_torch as tex
@@ -3927,3 +3927,212 @@ def attn_forward_func_with_cp(
raise ValueError(f"Unsupported communication type: {cp_comm_type}!")
return out
def pad_thd_sequences_for_cp(
input_ids: torch.Tensor,
labels: torch.Tensor,
cu_seqlens: torch.Tensor,
divisibility_factor: int,
padding_token_id: int = 0,
padding_label_id: int = -100,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Pads sequences to be divisible by the divisibility factor.
Args:
input_ids: Tensor of shape (1, N) or (N,) containing concatenated sequences
labels: Tensor of shape (1, N) or (N,) containing labels for each token
cu_seqlens: Tensor of shape (M,) containing cumulative sequence lengths
divisibility_factor: Each sequence length must be divisible by this factor
padding_token_id: Token ID to use for padding (default: 0)
padding_label_id: Label ID to use for padding (default: -100)
Returns:
Tuple of:
- input_ids_padded: Padded input_ids tensor
- labels_padded: Padded labels tensor
- cu_seqlens_padded: Cumulative sequence lengths accounting for padding
"""
# Flatten input_ids and labels if needed
if input_ids.dim() == 2:
input_ids = input_ids.squeeze(0)
if labels.dim() == 2:
labels = labels.squeeze(0)
# Compute the sequence lengths from cu_seqlens
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
# List: amount of padding needed for each sequence (make length a multiple of divisibility_factor)
padding_amounts = [
((l.item() + divisibility_factor - 1) // divisibility_factor) * divisibility_factor
- l.item()
for l in seqlens
]
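# e.g. divisibility_factor=4 and a sequence of length 7 -> pad = ceil(7/4)*4 - 7 = 1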
# Extract sequences and labels for each batch item
batch_sequences = [
input_ids[start.item() : end.item()] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])
]
batch_labels = [
labels[start.item() : end.item()] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])
]
# Pad sequences and labels to required length
input_ids_padded = torch.cat(
[
(
torch.cat([seq, torch.full((pad,), padding_token_id, dtype=seq.dtype)])
if pad > 0
else seq
)
for seq, pad in zip(batch_sequences, padding_amounts)
]
)
labels_padded = torch.cat(
[
(
torch.cat([seq, torch.full((pad,), padding_label_id, dtype=seq.dtype)])
if pad > 0
else seq
)
for seq, pad in zip(batch_labels, padding_amounts)
]
)
# Compute cumulative padded sequence lengths, starting from 0
padded_lengths = seqlens + torch.tensor(padding_amounts, dtype=seqlens.dtype)
cu_seqlens_padded = torch.cumsum(
torch.cat([torch.tensor([0], dtype=cu_seqlens.dtype), padded_lengths]), dim=0
)
return input_ids_padded, labels_padded, cu_seqlens_padded
def generate_positional_ids_for_cp(
cu_seqlens: torch.Tensor,
divisibility_factor: int,
dtype: torch.dtype = torch.long,
) -> torch.Tensor:
"""Generate positional IDs for sequences padded to be divisible by divisibility_factor.
Args:
cu_seqlens: Tensor of shape (M,) containing cumulative sequence lengths
divisibility_factor: Each sequence length must be divisible by this factor
dtype: Data type for the generated positional IDs (default: torch.long)
Returns:
Generated positional_ids tensor where each sequence starts from 0 and continues through padding
"""
# Compute the sequence lengths from cu_seqlens
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
# List: amount of padding needed for each sequence
padding_amounts = [
((l.item() + divisibility_factor - 1) // divisibility_factor) * divisibility_factor
- l.item()
for l in seqlens
]
# Generate positional IDs for each padded sequence (each starts from 0)
padded_lengths = seqlens + torch.tensor(padding_amounts, dtype=seqlens.dtype)
positional_ids = torch.cat(
[torch.arange(0, int(length), dtype=dtype) for length in padded_lengths]
)
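# e.g. padded_lengths = [4, 4] -> positional_ids = [0, 1, 2, 3, 0, 1, 2, 3]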
return positional_ids
def get_batch_on_this_cp_rank(
cu_seqlens_padded: torch.Tensor,
input_ids_padded: torch.Tensor,
labels_padded: torch.Tensor,
position_ids_padded: torch.Tensor,
cp_group: torch.distributed.ProcessGroup = None,
qvk_format: str = "thd",
):
"""Slice batch input along sequence dimension into multiple chunks for THD format.
This function is inteded for use in self attention. It will not work for cross attention because
it does not handle the case where the sequence length of the query and key are different.
Which are parallelized across GPUs in a context parallel group.
This version works with variable-length sequences using cumulative sequence lengths.
"""
if qvk_format not in ["thd", "bshd", "sbhd"]:
raise ValueError(f"Unsupported qvk_format: {qvk_format}!")
if qvk_format == "thd":
# Get context parallel size and rank
cp_size = torch.distributed.get_world_size(group=cp_group)
if cp_size > 1:
cp_rank = torch.distributed.get_rank(group=cp_group)
# Calculate the chunk sizes for each sequence
total_slices_of_any_sequence = 2 * cp_size
slice_sizes = (
cu_seqlens_padded[1:] - cu_seqlens_padded[:-1]
) // total_slices_of_any_sequence
# Helper that slices a single tensor along its sequence dimension (applied to each tensor directly,
# rather than via a keys_to_change-style dict loop)
def process_tensor(val):
if val is None:
return val
# Determine which dimension is the sequence dimension
# Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor
if isinstance(cu_seqlens_padded[-1], torch.Tensor):
seq_len_val = cu_seqlens_padded[-1].item()
else:
seq_len_val = cu_seqlens_padded[-1]
# Handle 1D tensors (like position_ids that don't have batch dimension)
if val.ndim == 1:
if val.shape[0] == seq_len_val:
current_seq_dim = 0
else:
raise ValueError(
"1D tensor shape doesn't match expected sequence length. Make sure the"
" inputs are in THD format and padded correctly."
)
elif val.ndim >= 2:
if val.shape[1] == seq_len_val:
current_seq_dim = 1
elif val.shape[0] == seq_len_val:
current_seq_dim = 0
else:
raise ValueError(
"Make sure the inputs are in THD format and padded correctly."
)
else:
raise ValueError("Tensor must be at least 1D")
# On this particular rank, for each sequence, get two slices, one from the beginning
# and one from the end.
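# With 2*cp_size slices per sequence, rank r takes slice r and slice (2*cp_size - 1 - r),
# which balances the causal-attention workload across ranks.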
cp_rank_slices = []
for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]):
# 1st segment
cp_rank_slices.append(
torch.arange(
seq_start + (cp_rank * slice_size),
seq_start + ((cp_rank + 1) * slice_size),
device=val.device,
)
)
# 2nd segment
cp_rank_slices.append(
torch.arange(
seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size),
seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size),
device=val.device,
)
)
return val.index_select(current_seq_dim, torch.cat(cp_rank_slices))
# Process each tensor directly
input_ids_padded = process_tensor(input_ids_padded)
labels_padded = process_tensor(labels_padded)
position_ids_padded = process_tensor(position_ids_padded)
else:
raise ValueError(f"Support not implemented yet for qvk_format: {qvk_format}!")
return input_ids_padded, labels_padded, position_ids_padded