Unverified Commit 483d9594 authored by jomitchellnv's avatar jomitchellnv Committed by GitHub
Browse files

Adds context parallelism utilities: moving cp shards to diff ranks and pad...


Adds context parallelism utilities: moving cp shards to different ranks and padding sequences to a divisibility factor (#2129)

* test - adds unit tests for the cp utilities, and the utilities themselves
Signed-off-by: default avatarJonathan Mitchell <jomitchell@login-eos02.eos.clusters.nvidia.com>

* assert line change
Signed-off-by: default avatarJonathan Mitchell <jomitchell@login-eos02.eos.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarJonathan Mitchell <jomitchell@login-eos02.eos.clusters.nvidia.com>
Co-authored-by: default avatarJonathan Mitchell <jomitchell@login-eos02.eos.clusters.nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarSudhakar Singh <sudhakars@nvidia.com>
parent 4903f947
......@@ -35,6 +35,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
......
This diff is collapsed.
......@@ -4,7 +4,7 @@
"""Context Parallelism."""
import os
from typing import List, Union
from typing import List, Union, Tuple
import torch
import transformer_engine_torch as tex
......@@ -3927,3 +3927,212 @@ def attn_forward_func_with_cp(
raise ValueError(f"Unsupported communication type: {cp_comm_type}!")
return out
def pad_thd_sequences_for_cp(
    input_ids: torch.Tensor,
    labels: torch.Tensor,
    cu_seqlens: torch.Tensor,
    divisibility_factor: int,
    padding_token_id: int = 0,
    padding_label_id: int = -100,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pad every packed sequence so its length is a multiple of ``divisibility_factor``.

    Each sequence delimited by ``cu_seqlens`` is right-padded independently;
    the padded sequences are then concatenated back into flat tensors.

    Args:
        input_ids: Tensor of shape (1, N) or (N,) containing concatenated sequences.
        labels: Tensor of shape (1, N) or (N,) containing labels for each token.
        cu_seqlens: Tensor of shape (M,) containing cumulative sequence lengths.
        divisibility_factor: Each padded sequence length is a multiple of this
            factor. Must be positive.
        padding_token_id: Token ID to use for padding (default: 0).
        padding_label_id: Label ID to use for padding (default: -100).

    Returns:
        Tuple of:
            - input_ids_padded: flat padded ``input_ids`` (same dtype/device as ``input_ids``).
            - labels_padded: flat padded ``labels`` (same dtype/device as ``labels``).
            - cu_seqlens_padded: cumulative sequence lengths accounting for
              padding, on the device of ``cu_seqlens``.

        When ``cu_seqlens`` describes no sequences, the first two entries are
        empty tensors and ``cu_seqlens_padded`` holds the single value 0.

    Raises:
        ValueError: If ``divisibility_factor`` is not positive.
    """
    if divisibility_factor <= 0:
        raise ValueError(f"divisibility_factor must be positive, got {divisibility_factor}!")

    # Flatten input_ids and labels if needed
    if input_ids.dim() == 2:
        input_ids = input_ids.squeeze(0)
    if labels.dim() == 2:
        labels = labels.squeeze(0)

    # A single host transfer for all boundaries instead of one .item() per sequence.
    boundaries = cu_seqlens.tolist()

    id_chunks = []
    label_chunks = []
    padded_lengths = []
    for start, end in zip(boundaries[:-1], boundaries[1:]):
        length = end - start
        # Number of tokens needed to reach the next multiple of the factor.
        pad = (-length) % divisibility_factor
        id_chunks.append(input_ids[start:end])
        label_chunks.append(labels[start:end])
        if pad > 0:
            # new_full inherits dtype *and device* from the source tensor, so the
            # concatenation below also works for tensors that live on an accelerator.
            id_chunks.append(input_ids.new_full((pad,), padding_token_id))
            label_chunks.append(labels.new_full((pad,), padding_label_id))
        padded_lengths.append(length + pad)

    # torch.cat rejects an empty list, so handle "no sequences" explicitly.
    input_ids_padded = torch.cat(id_chunks) if id_chunks else input_ids[:0]
    labels_padded = torch.cat(label_chunks) if label_chunks else labels[:0]

    # Compute cumulative padded sequence lengths, starting from 0
    cu_seqlens_padded = torch.cumsum(
        torch.tensor([0] + padded_lengths, dtype=cu_seqlens.dtype, device=cu_seqlens.device),
        dim=0,
    )
    return input_ids_padded, labels_padded, cu_seqlens_padded
def generate_positional_ids_for_cp(
    cu_seqlens: torch.Tensor,
    divisibility_factor: int,
    dtype: torch.dtype = torch.long,
) -> torch.Tensor:
    """Generate positional IDs for sequences padded to a multiple of ``divisibility_factor``.

    Args:
        cu_seqlens: Tensor of shape (M,) containing cumulative (unpadded) sequence lengths.
        divisibility_factor: Each padded sequence length is a multiple of this
            factor. Must be positive.
        dtype: Data type for the generated positional IDs (default: torch.long).

    Returns:
        Flat positional_ids tensor, created on the device of ``cu_seqlens``,
        where each sequence starts from 0 and keeps counting through its
        padding. Empty when ``cu_seqlens`` describes no sequences.

    Raises:
        ValueError: If ``divisibility_factor`` is not positive.
    """
    if divisibility_factor <= 0:
        raise ValueError(f"divisibility_factor must be positive, got {divisibility_factor}!")

    # A single host transfer for all boundaries instead of one .item() per sequence.
    boundaries = cu_seqlens.tolist()

    per_sequence_ids = []
    for start, end in zip(boundaries[:-1], boundaries[1:]):
        length = end - start
        # Round the length up to the next multiple of the factor.
        padded_length = length + (-length) % divisibility_factor
        per_sequence_ids.append(
            torch.arange(0, padded_length, dtype=dtype, device=cu_seqlens.device)
        )

    # torch.cat rejects an empty list, so handle "no sequences" explicitly.
    if not per_sequence_ids:
        return torch.empty(0, dtype=dtype, device=cu_seqlens.device)
    return torch.cat(per_sequence_ids)
def _thd_cp_rank_token_indices(boundaries, cp_size, cp_rank):
    """Return the flat token indices (int64, CPU) that ``cp_rank`` owns.

    Every sequence delimited by ``boundaries`` is cut into ``2 * cp_size``
    equal slices; the rank takes slice ``cp_rank`` and its mirror slice
    ``2 * cp_size - cp_rank - 1``.
    """
    total_slices_of_any_sequence = 2 * cp_size
    index_chunks = []
    for seq_start, seq_end in zip(boundaries[:-1], boundaries[1:]):
        seq_len = seq_end - seq_start
        if seq_len % total_slices_of_any_sequence != 0:
            # Floor division below would silently drop the trailing tokens.
            raise ValueError(
                f"Padded sequence length {seq_len} is not divisible by 2 * cp_size ="
                f" {total_slices_of_any_sequence}. Make sure the inputs are in THD format and"
                " padded correctly."
            )
        slice_size = seq_len // total_slices_of_any_sequence
        # 1st segment
        first = seq_start + cp_rank * slice_size
        index_chunks.append(torch.arange(first, first + slice_size))
        # 2nd segment
        second = seq_start + (total_slices_of_any_sequence - cp_rank - 1) * slice_size
        index_chunks.append(torch.arange(second, second + slice_size))
    if not index_chunks:
        return torch.empty(0, dtype=torch.long)
    return torch.cat(index_chunks)


def get_batch_on_this_cp_rank(
    cu_seqlens_padded: torch.Tensor,
    input_ids_padded: torch.Tensor,
    labels_padded: torch.Tensor,
    position_ids_padded: torch.Tensor,
    cp_group: torch.distributed.ProcessGroup = None,
    qvk_format: str = "thd",
):
    """Slice a THD-format batch along the sequence dimension for this context-parallel rank.

    Each padded sequence is split into ``2 * cp_size`` equal slices and this
    rank keeps two of them per sequence: slice ``cp_rank`` and slice
    ``2 * cp_size - cp_rank - 1``.

    This function is intended for use in self attention. It will not work for
    cross attention because it does not handle the case where the sequence
    lengths of the query and key are different.

    This version works with variable-length sequences using cumulative
    sequence lengths.

    Args:
        cu_seqlens_padded: Cumulative padded sequence lengths; the last entry
            must equal the size of the sequence dimension of the other tensors.
        input_ids_padded: Padded input IDs, or None.
        labels_padded: Padded labels, or None.
        position_ids_padded: Padded position IDs, or None.
        cp_group: Process group used to query the context-parallel size and rank.
        qvk_format: Only "thd" is implemented.

    Returns:
        Tuple ``(input_ids_padded, labels_padded, position_ids_padded)``
        restricted to this rank's tokens. The inputs are returned unchanged
        when the context-parallel size is 1; None entries stay None.

    Raises:
        ValueError: If ``qvk_format`` is unknown or not implemented, if a
            padded sequence length is not divisible by ``2 * cp_size``, or if a
            tensor's shape does not match the total padded length.
    """
    if qvk_format not in ["thd", "bshd", "sbhd"]:
        raise ValueError(f"Unsupported qvk_format: {qvk_format}!")
    if qvk_format != "thd":
        raise ValueError(f"Support not implemented yet for qvk_format: {qvk_format}!")

    # Get context parallel size and rank
    cp_size = torch.distributed.get_world_size(group=cp_group)
    if cp_size <= 1:
        return input_ids_padded, labels_padded, position_ids_padded
    cp_rank = torch.distributed.get_rank(group=cp_group)

    # One host transfer for all boundaries; afterwards only Python ints are
    # used, so building the indices does not trigger a device sync per sequence.
    boundaries = cu_seqlens_padded.tolist()
    # The indices are identical for every tensor, so build them once.
    token_indices = _thd_cp_rank_token_indices(boundaries, cp_size, cp_rank)

    def process_tensor(val):
        if val is None:
            return val
        seq_len_val = boundaries[-1]
        # Determine which dimension is the sequence dimension
        if val.ndim == 1:
            # 1D tensors (like position_ids that don't have a batch dimension)
            if val.shape[0] != seq_len_val:
                raise ValueError(
                    "1D tensor shape doesn't match expected sequence length. Make sure the"
                    " inputs are in THD format and padded correctly."
                )
            current_seq_dim = 0
        elif val.ndim >= 2:
            # Prefer dim 1 (a leading batch dimension), then fall back to dim 0.
            if val.shape[1] == seq_len_val:
                current_seq_dim = 1
            elif val.shape[0] == seq_len_val:
                current_seq_dim = 0
            else:
                raise ValueError("Make sure the inputs are in THD format and padded correctly.")
        else:
            raise ValueError("Tensor must be at least 1D")
        return val.index_select(current_seq_dim, token_indices.to(val.device))

    return (
        process_tensor(input_ids_padded),
        process_tensor(labels_padded),
        process_tensor(position_ids_padded),
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment