Commit 0d874a4e authored by wenjh

Merge branch 'nv_main' of v2.12

parents a68e5f87 dfdd3820
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -85,6 +85,7 @@ void performTestSwizzle1D(const int num_tiles_M, const int num_tiles_K, bool row
std::vector<int> scaling_mode = {SF_MODE_X, SF_MODE_Y, 0};
Tensor input("input", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING);
Tensor output("output", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING);
output.set_with_gemm_swizzled_scales(true);
fillUniform(&input);
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -284,52 +284,33 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
Tensor::Tensor(const std::string& name,
const NVTEShape &shape, const DType type,
const bool rowwise, const bool columnwise,
const NVTEScalingMode &scaling_mode) {
name_ = name;
const NVTEScalingMode &scaling_mode)
: tensor_(scaling_mode), rowwise_{rowwise}, columnwise_{columnwise}, name_{name} {
// Initialize RNG
const size_t seed = create_seed_from_tensor_name(name);
gen_.seed(seed);
rowwise_ = rowwise;
columnwise_ = columnwise;
size_t total_size = bytes(shape, type);
void *dptr_rowwise = nullptr;
void *dptr_columnwise = nullptr;
cpu_data_rowwise_ = nullptr;
cpu_data_columnwise_ = nullptr;
amax_cpu_data_ = nullptr;
scale_cpu_data_ = nullptr;
rowwise_scale_inv_cpu_data_ = nullptr;
columnwise_scale_inv_cpu_data_ = nullptr;
float *amax = nullptr, *scale = nullptr;
float *rowwise_scale_inv = nullptr, *columnwise_scale_inv = nullptr;
// Make sure shape is valid
if (columnwise) {
NVTE_CHECK(shape.ndim >= 2);
}
std::vector<size_t> normalized_shape_v = {product(shape, 0, shape.ndim - 1),
shape.data[shape.ndim - 1]};
NVTEShape normalized_shape = convertShape(normalized_shape_v);
NVTEShape columnwise_shape = {};
std::vector<size_t> columnwise_shape_vec;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING
|| scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) {
// Transpose when tensor scaling
columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]);
for (size_t i = 0; i < shape.ndim - 1; ++i) {
columnwise_shape_vec.emplace_back(shape.data[i]);
}
// Shape after flattening to 2D
NVTEShape flattened_shape;
{
std::vector<size_t> flattened_shape_vec;
if (shape.ndim > 0) {
flattened_shape_vec.push_back(product(shape, 0, shape.ndim - 1));
flattened_shape_vec.push_back(shape.data[shape.ndim - 1]);
} else {
// Same shape for MX and NVFP4
for (size_t i = 0; i < shape.ndim; ++i) {
columnwise_shape_vec.emplace_back(shape.data[i]);
flattened_shape_vec.resize(2, 1);
}
flattened_shape = convertShape(flattened_shape_vec);
}
if (columnwise) {
columnwise_shape = nvte_make_shape(columnwise_shape_vec.data(), columnwise_shape_vec.size());
}
tensor_ = TensorWrapper(scaling_mode);
// Allocate and initialize data
void *dptr_rowwise = nullptr, *dptr_columnwise = nullptr;
const size_t total_size = bytes(shape, type);
if (total_size != 0) {
if (rowwise) {
cudaMalloc((void**)&dptr_rowwise, total_size); // NOLINT(*)
......@@ -345,16 +326,57 @@ Tensor::Tensor(const std::string& name,
}
}
// Set tensor row-wise data
if (rowwise) {
#if FP4_TYPE_SUPPORTED
const DType rowwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
const DType colwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
tensor_.set_rowwise_data(dptr_rowwise, rowwise_type, shape);
tensor_.set_columnwise_data(dptr_columnwise, colwise_type, columnwise_shape);
#else
tensor_.set_rowwise_data(dptr_rowwise, type, shape);
}
// Set tensor column-wise data
if (columnwise) {
// Determine shape of column-wise data
std::vector<size_t> columnwise_shape_vec;
switch (scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING:
case NVTE_BLOCK_SCALING_1D:
case NVTE_BLOCK_SCALING_2D: {
// Column-wise data shape is transposed
if (shape.ndim > 0) {
columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]);
for (size_t i = 0; i < shape.ndim - 1; ++i) {
columnwise_shape_vec.emplace_back(shape.data[i]);
}
}
break;
}
case NVTE_MXFP8_1D_SCALING:
case NVTE_NVFP4_1D_SCALING: {
// Column-wise data matches shape
for (size_t i = 0; i < shape.ndim; ++i) {
columnwise_shape_vec.emplace_back(shape.data[i]);
}
break;
}
default:
NVTE_ERROR("Unrecognized scaling mode (", (size_t)scaling_mode, ").");
}
const auto columnwise_shape = nvte_make_shape(columnwise_shape_vec.data(),
columnwise_shape_vec.size());
#if FP4_TYPE_SUPPORTED
// Set column-wise data buffer
const DType colwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
tensor_.set_columnwise_data(dptr_columnwise, colwise_type, columnwise_shape);
#else
tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape);
#endif
}
// Configure scales, amaxes, and other tensor buffers
float *amax = nullptr, *scale = nullptr;
float *rowwise_scale_inv = nullptr, *columnwise_scale_inv = nullptr;
if (isFp8Type(type) || isFp4Type(type)) {
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*)
......@@ -386,7 +408,7 @@ Tensor::Tensor(const std::string& name,
scale_cpu_data_ = std::make_shared<float>(0);
tensor_.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
}
auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, tensor_.scaling_mode());
auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(flattened_shape, tensor_.scaling_mode());
auto rowwise_scale_size = rowwise_scale_meta.bytes();
auto columnwise_scale_size = colwise_scale_meta.bytes();
auto scale_shape = rowwise_scale_meta.shape;
......
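For reference, the shape rules the refactored constructor encodes can be summarized outside the diff: tensor-scaling and 1D/2D block-scaling modes store column-wise data as the transpose of the row-wise shape, MXFP8/NVFP4 1D scaling keeps the same shape, and the scale-inverse metadata is computed against the tensor flattened to 2D. A minimal Python sketch of those rules (illustrative only, not the test-framework API):

```python
import math

def columnwise_shape(shape, scaling_mode):
    """Column-wise data shape for a given scaling mode (mirrors the switch above)."""
    if scaling_mode in ("DELAYED_TENSOR_SCALING", "BLOCK_SCALING_1D", "BLOCK_SCALING_2D"):
        # Tensor scaling / block scaling: column-wise data is transposed
        return (shape[-1], *shape[:-1]) if shape else ()
    if scaling_mode in ("MXFP8_1D_SCALING", "NVFP4_1D_SCALING"):
        # MX and NVFP4: column-wise data matches the row-wise shape
        return tuple(shape)
    raise ValueError(f"Unrecognized scaling mode ({scaling_mode}).")

def flattened_2d_shape(shape):
    """Shape after flattening to 2D, used when computing scale-inverse metadata."""
    if not shape:
        return (1, 1)
    return (math.prod(shape[:-1]), shape[-1])

assert columnwise_shape((4, 8, 16), "DELAYED_TENSOR_SCALING") == (16, 4, 8)
assert columnwise_shape((4, 8, 16), "MXFP8_1D_SCALING") == (4, 8, 16)
assert flattened_2d_shape((4, 8, 16)) == (32, 16)
```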
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -307,6 +307,10 @@ class Tensor {
tensor_.set_amax(nullptr, DType::kFloat32, tensor_.defaultShape);
}
void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales) {
tensor_.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
}
void to_cpu() const;
void from_cpu() const;
void set_scale(float scale);
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -63,12 +63,6 @@ int main(int argc, char* argv[]) {
return ret;
}
bool IsMulticastSupported(int device_id) {
int supported = 0;
CHECK_CU(cuDeviceGetAttribute(&supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, device_id));
return supported;
}
int GetDeviceComputeCapability(int device_id) {
int major{};
int minor{};
......@@ -369,11 +363,6 @@ struct GemmAr : public CommGemmFixure {
nvte_gemm_all_reduce(ctx_, m, n, k, a, b, d, bias, pre_act_out, transa, transb, grad,
accumulate, comm_sm_count, stream, kNVTECommGemmAlgoDefault);
}
void SetUp() override {
if (!IsMulticastSupported(rank_))
GTEST_SKIP() << "Multicast is not supported on device " << rank_;
}
};
TEST_P(AgGemm, Gemm) {
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""conftest for tests/jax"""
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import operator
......@@ -12,7 +12,7 @@ from jax._src.sharding_impls import UNSPECIFIED as _UNSPECIFIED
from transformer_engine.jax.sharding import MeshResource
from utils import assert_allclose, is_devices_enough
from utils import assert_allclose, is_devices_enough, is_devices_equal
def generate_configs():
......@@ -49,7 +49,11 @@ def generate_context_parallel_configs_for_attn():
TP_sizes = (1, 2)
for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes):
ndev = cp * tp * dp
if is_devices_enough(ndev):
# Run only those (dp, cp, tp) combinations that require exactly ndev GPUs.
# For example, if num_GPUs is 8 and ndev=8, every (dp, cp, tp) combination satisfying ndev = dp * cp * tp is picked.
# However, if num_GPUs is 8 and ndev=4, those combinations are ignored.
# To explicitly pick the ndev=4 combinations, set CUDA_VISIBLE_DEVICES=0,1,2,3 so that num_GPUs becomes 4 instead of 8.
if is_devices_equal(ndev):
# Do not run cp1 case in L1 as that is already covered in TestDistributedSelfAttn and TestDistributedCrossAttn (as these do not have any cp combinations)
if cp != 1:
configsL1.append(
......
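For clarity, the filtering that `is_devices_equal` introduces in the hunk above can be shown with a small standalone sketch (pure Python; `select_mesh_configs` is a hypothetical name, not part of the test suite):

```python
from itertools import product

def select_mesh_configs(num_visible_gpus, dp_sizes=(1, 2), cp_sizes=(1, 2), tp_sizes=(1, 2)):
    """Keep only (dp, cp, tp) combinations that need exactly the visible GPU count."""
    configs = []
    for dp, cp, tp in product(dp_sizes, cp_sizes, tp_sizes):
        ndev = dp * cp * tp
        if ndev == num_visible_gpus:  # the is_devices_equal(ndev) condition
            configs.append((dp, cp, tp))
    return configs

print(select_mesh_configs(8))  # [(2, 2, 2)]
print(select_mesh_configs(4))  # [(1, 2, 2), (2, 1, 2), (2, 2, 1)] -> run with CUDA_VISIBLE_DEVICES=0,1,2,3
```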
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -18,6 +18,7 @@ from transformer_engine.jax.attention import (
is_fused_attn_kernel_available,
AttnBiasType,
AttnMaskType,
AttnSoftmaxType,
QKVLayout,
QKVFormat,
reorder_causal_load_balancing,
......@@ -66,6 +67,7 @@ class TestDistributedSelfAttn:
bias_shape,
attn_mask_type,
dtype,
softmax_type,
use_shardy,
):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
......@@ -80,6 +82,7 @@ class TestDistributedSelfAttn:
QKVLayout.BS3HD,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
num_head,
num_head,
......@@ -109,6 +112,7 @@ class TestDistributedSelfAttn:
hidden,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
......@@ -142,6 +146,14 @@ class TestDistributedSelfAttn:
],
)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize(
"softmax_type",
[
pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
],
)
def test_self_attn(
self,
device_count,
......@@ -153,6 +165,7 @@ class TestDistributedSelfAttn:
bias_shape,
attn_mask_type,
dtype,
softmax_type,
):
self.impl_test_self_attn(
device_count,
......@@ -164,6 +177,7 @@ class TestDistributedSelfAttn:
bias_shape,
attn_mask_type,
dtype,
softmax_type,
use_shardy=False,
)
......@@ -175,8 +189,23 @@ class TestDistributedSelfAttn:
pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
],
)
@pytest.mark.parametrize(
"softmax_type",
[
pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
],
)
def test_self_attn_shardy(
self, device_count, mesh_shape, mesh_axes, mesh_resource, attn_bias_type, bias_shape
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
attn_bias_type,
bias_shape,
softmax_type,
):
data_shape = (32, 512, 12, 64)
self.impl_test_self_attn(
......@@ -189,6 +218,7 @@ class TestDistributedSelfAttn:
bias_shape,
AttnMaskType.PADDING_MASK,
jnp.bfloat16,
softmax_type,
use_shardy=True,
)
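For orientation, the new `softmax_type` axis parametrizes the softmax variant used inside fused attention. A tiny NumPy sketch of the presumed distinction: VANILLA is the standard softmax, OFF_BY_ONE adds a constant 1 to the denominator so a head can attend to "nothing", and LEARNABLE is assumed here to replace that constant with a trained offset. None of this mirrors the library internals.

```python
import numpy as np

def softmax_with_denominator_offset(scores, offset=0.0):
    # offset=0.0 -> vanilla softmax; offset=1.0 -> off-by-one softmax;
    # a learned scalar in place of the constant would give a "learnable" variant.
    e = np.exp(scores)  # no max-shift; fine for toy magnitudes only
    return e / (e.sum(axis=-1, keepdims=True) + offset)

scores = np.array([[-4.0, -5.0, -6.0]])
print(softmax_with_denominator_offset(scores))       # rows sum to 1
print(softmax_with_denominator_offset(scores, 1.0))  # rows sum to < 1 when all scores are low
```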
......@@ -213,8 +243,24 @@ class TestDistributedCrossAttn:
"attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize(
"softmax_type",
[
pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
],
)
def test_cross_attn(
self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
attn_mask_type,
dtype,
softmax_type,
):
attn_bias_type = AttnBiasType.NO_BIAS
bias_shape = None
......@@ -230,6 +276,7 @@ class TestDistributedCrossAttn:
QKVLayout.BSHD_BS2HD,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
num_head,
num_head,
......@@ -252,6 +299,7 @@ class TestDistributedCrossAttn:
hidden,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
......@@ -279,14 +327,14 @@ DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS = [
]
DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES = [
# Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"),
pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
# Sequence lengths will be scaled by CP*2 so that we don't run with tiny sizes.
pytest.param([2, 128, 8, 128], id="2-128xCPx2-8-128"),
pytest.param([4, 256, 16, 64], id="4-256xCPx2-16-64"),
]
class TestDistributedContextParallelSelfAttn:
# TODO(KshitijLakhani): parametrize num_segments_per_seq for all CP tests
def impl_test_context_parallel_attn(
self,
device_count,
......@@ -303,12 +351,14 @@ class TestDistributedContextParallelSelfAttn:
use_shardy,
use_scan_ring=False,
window_size=None,
stripe_size=None,
num_segments_per_seq=None,
):
if qkv_layout.is_thd():
if cp_strategy == CPStrategy.ALL_GATHER:
pytest.skip("THD doesn't support all gather context parallelism.")
if not load_balanced and cp_strategy == CPStrategy.RING:
pytest.skip("THD + ring doesn't support unbalanced context parallelism.")
if not load_balanced and (
cp_strategy == CPStrategy.RING or cp_strategy == CPStrategy.ALL_GATHER
):
pytest.skip(f"THD + {cp_strategy=} doesn't support unbalanced context parallelism.")
assert not use_scan_ring or cp_strategy == CPStrategy.RING
......@@ -322,6 +372,8 @@ class TestDistributedContextParallelSelfAttn:
bias_shape = None
dropout_prob = 0.0
is_training = True
# Context parallel does not support softmax_offset
softmax_type = AttnSoftmaxType.VANILLA_SOFTMAX
dp_size, cp_size, tp_size = mesh_shape
batch, seqlen, num_head, hidden = data_shape
......@@ -332,7 +384,6 @@ class TestDistributedContextParallelSelfAttn:
data_shape = batch, seqlen, num_head, hidden
num_kv_heads = num_head // kv_groups
runner = FusedAttnRunner(
batch,
seqlen,
......@@ -343,6 +394,7 @@ class TestDistributedContextParallelSelfAttn:
hidden,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
......@@ -350,6 +402,8 @@ class TestDistributedContextParallelSelfAttn:
bias_shape,
window_size,
SeqDescFormat.SegmentIDs,
stripe_size=stripe_size,
num_segments_per_seq=num_segments_per_seq,
number_of_devices=device_count,
mesh_shape=mesh_shape,
mesh_axes=mesh_axes,
......@@ -366,6 +420,7 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout,
attn_bias_type,
mask_type,
softmax_type,
dropout_prob,
num_head,
num_kv_heads,
......@@ -401,7 +456,7 @@ class TestDistributedContextParallelSelfAttn:
"device_count,mesh_shape,mesh_axes,mesh_resource",
generate_context_parallel_configs_for_attn(),
)
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
......@@ -418,6 +473,8 @@ class TestDistributedContextParallelSelfAttn:
dtype,
qkv_layout,
):
if qkv_layout.is_thd():
pytest.skip("Only BSHD layout is supported for CP + AG + Dual chunk attention")
kv_groups = 8
self.impl_test_context_parallel_attn(
device_count,
......@@ -434,6 +491,72 @@ class TestDistributedContextParallelSelfAttn:
use_shardy=True,
)
@pytest_parametrize_wrapper(
"device_count,mesh_shape,mesh_axes,mesh_resource",
generate_context_parallel_configs_for_attn(),
)
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
@pytest.mark.parametrize("kv_groups", [1, 8])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(True, id="BALANCED")],
)
@pytest.mark.parametrize(
"stripe_size",
[pytest.param(64, id="STRIPE-64"), pytest.param(128, id="STRIPE-128")],
)
@pytest.mark.parametrize(
"window_size",
[
pytest.param((-1, -1), id="window_size(-1, -1)"),
pytest.param((5, 0), id="window_size(5, 0)"),
],
)
@pytest.mark.parametrize(
"num_segments_per_seq",
[pytest.param(5, id="SEG-5")],
)
def test_context_parallel_allgather_striped_attn(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
window_size,
stripe_size,
num_segments_per_seq,
):
if not qkv_layout.is_thd():
pytest.skip("Only THD layout is supported for CP + AG + Striped attention")
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
CPStrategy.ALL_GATHER,
use_shardy=False,
window_size=window_size,
stripe_size=stripe_size,
num_segments_per_seq=num_segments_per_seq,
)
@pytest_parametrize_wrapper(
"device_count,mesh_shape,mesh_axes,mesh_resource",
generate_context_parallel_configs_for_attn(),
......@@ -462,6 +585,8 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout,
load_balanced,
):
if qkv_layout.is_thd():
pytest.skip("Only BSHD layout is supported for CP + AG + Dual chunk attention")
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
......@@ -525,6 +650,8 @@ class TestDistributedContextParallelSelfAttn:
"When context parallelism and sliding window attention are used, "
"scanloop is not supported"
)
# Set the stripe size to 1 (ring attention only supports stripe_size=1)
stripe_size = 1 if qkv_layout.is_thd() else None
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
......@@ -540,6 +667,7 @@ class TestDistributedContextParallelSelfAttn:
use_shardy=False,
use_scan_ring=use_scan,
window_size=window_size,
stripe_size=stripe_size,
)
@pytest_parametrize_wrapper(
......@@ -564,6 +692,8 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout,
):
kv_groups = 8
# Set the stripe size to 1 (ring attention only supports stripe_size=1)
stripe_size = 1 if qkv_layout.is_thd() else None
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
......@@ -578,6 +708,7 @@ class TestDistributedContextParallelSelfAttn:
cp_strategy=CPStrategy.RING,
use_shardy=False,
use_scan_ring=True,
stripe_size=stripe_size,
)
......@@ -587,31 +718,39 @@ REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES = {
"L2": [[4, 32, 12, 32], [1, 16, 1, 1]],
}
REORDER_STRATEGY = [
pytest.param(ReorderStrategy.DualChunkSwap, None, id="DualChunkSwap"),
pytest.param(ReorderStrategy.Striped, 1, id="Striped-1"),
pytest.param(ReorderStrategy.Striped, 4, id="Striped-4"),
]
class TestReorderCausalLoadBalancing:
@pytest.mark.parametrize("cp_size", [2, 4, 8])
@pytest_parametrize_wrapper("shape", REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES)
@pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD])
@pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD, QKVFormat.THD])
@pytest.mark.parametrize(
"reorder_strategy",
[
pytest.param(ReorderStrategy.DualChunkSwap, id="DualChunkSwap"),
pytest.param(ReorderStrategy.Striped, id="Striped"),
],
"reorder_strategy, stripe_size",
REORDER_STRATEGY,
)
def test(self, cp_size, shape, qkv_format, reorder_strategy):
def test(self, cp_size, shape, qkv_format, reorder_strategy, stripe_size):
tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16)
seq_dim = 1
if qkv_format == QKVFormat.SBHD:
tensor = tensor.swapaxes(0, 1)
seq_dim = 0
if reorder_strategy == ReorderStrategy.Striped:
seq_lens = shape[seq_dim]
if seq_lens < (cp_size * stripe_size):
pytest.skip(f"{seq_lens=} must be at least {cp_size*stripe_size=}")
ref = tensor.copy()
reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3])
inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3])
reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3, 4])
inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3, 4])
reordered = reorder(tensor, reorder_strategy, cp_size, seq_dim)
inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim)
reordered = reorder(tensor, reorder_strategy, cp_size, seq_dim, stripe_size)
inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim, stripe_size)
assert jnp.array_equal(inversed, ref)
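To make the new `stripe_size` argument concrete, here is a self-contained sketch of one plausible striped reordering and its inverse (NumPy, hypothetical helper names; not the implementation behind `reorder_causal_load_balancing`). The sequence is cut into stripes of `stripe_size` tokens, stripes are dealt round-robin to the `cp_size` ranks, and each rank's stripes end up contiguous, which is why the test above requires seqlen >= cp_size * stripe_size and checks that the inverse restores the original order.

```python
import numpy as np

def striped_reorder(x, cp_size, stripe_size, seq_dim=1):
    """Deal stripes of `stripe_size` tokens round-robin to `cp_size` ranks."""
    seqlen = x.shape[seq_dim]
    n_stripes = seqlen // stripe_size
    assert seqlen == n_stripes * stripe_size and n_stripes % cp_size == 0
    # stripe j goes to rank j % cp_size; gather each rank's stripes contiguously
    stripe_order = np.arange(n_stripes).reshape(-1, cp_size).T.reshape(-1)
    perm = (stripe_order[:, None] * stripe_size + np.arange(stripe_size)).reshape(-1)
    return np.take(x, perm, axis=seq_dim), perm

def striped_inverse(x, perm, seq_dim=1):
    return np.take(x, np.argsort(perm), axis=seq_dim)

tokens = np.arange(16).reshape(1, 16)  # batch=1, seqlen=16
reordered, perm = striped_reorder(tokens, cp_size=2, stripe_size=2)
assert np.array_equal(striped_inverse(reordered, perm), tokens)  # round-trips, like the test
```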
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......