"vscode:/vscode.git/clone" did not exist on "fe27bf1cb3699bd4c196126ac6594f39be2c0eb5"
Unverified commit 855fa653, authored by Hua Huang, committed by GitHub

[JAX] Support SWA in CP Ring Attn THD striped sharding (#1810)



* Support SWA in CP Ring Attn THD striped sharding

Signed-off-by: Hua Huang <huah@nvidia.com>

* Add some comments; move check to _FusedAttnCPWithP2PHelper.check_supported()

Signed-off-by: Hua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove unused check

Signed-off-by: Hua Huang <huah@nvidia.com>

---------

Signed-off-by: Hua Huang <huah@nvidia.com>
parent 4292653c
@@ -289,6 +289,7 @@ class TestDistributedContextParallelSelfAttn:
         cp_strategy,
         use_shardy,
         use_scan_ring=False,
+        window_size=None,
     ):
         if qkv_layout.is_thd():
             if cp_strategy == CPStrategy.ALL_GATHER:
@@ -333,7 +334,7 @@ class TestDistributedContextParallelSelfAttn:
             is_training,
             qkv_layout,
             bias_shape,
-            None,
+            window_size,
             SeqDescFormat.SegmentIDs,
             number_of_devices=device_count,
             mesh_shape=mesh_shape,
@@ -476,6 +477,13 @@ class TestDistributedContextParallelSelfAttn:
         "use_scan",
         [pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")],
     )
+    @pytest.mark.parametrize(
+        "window_size",
+        [
+            pytest.param((-1, -1), id="window_size(-1, -1)"),
+            pytest.param((20, 0), id="window_size(20, 0)"),
+        ],
+    )
     def test_context_parallel_ring_attn(
         self,
         device_count,
@@ -489,7 +497,15 @@ class TestDistributedContextParallelSelfAttn:
         qkv_layout,
         load_balanced,
         use_scan,
+        window_size,
     ):
+        if window_size != (-1, -1) and not qkv_layout.is_thd():
+            pytest.skip("Sliding window attention is only supported for THD layout")
+        if window_size != (-1, -1) and qkv_layout.is_thd() and use_scan:
+            pytest.skip(
+                "When context parallelism and sliding window attention are used, "
+                "scanloop is not supported"
+            )
         self.impl_test_context_parallel_attn(
             device_count,
             mesh_shape,
@@ -504,6 +520,7 @@ class TestDistributedContextParallelSelfAttn:
             CPStrategy.RING,
             use_shardy=False,
             use_scan_ring=use_scan,
+            window_size=window_size,
         )

     @pytest.mark.parametrize(
@@ -41,6 +41,7 @@ from ..sharding import (
     all_reduce_sum_along_dp_fsdp,
     get_mesh_axis_size,
     get_mesh_axis_rank,
+    get_mesh_axis_rank_host,
     get_all_mesh_axes,
     num_of_devices,
     with_sharding_constraint,
@@ -74,6 +75,7 @@ __all__ = [
         "window_size",
         "context_parallel_load_balanced",
         "cp_axis",
+        "cp_striped_window_size",
     ],
 )

 @dataclass(frozen=True)
@@ -92,6 +94,7 @@ class _FusedAttnConfig:
     window_size: Tuple[int, int]
     context_parallel_load_balanced: bool
     cp_axis: str
+    cp_striped_window_size: Tuple[int, int]  # Only for CP + Ring + THD + SWA


 @dataclass(frozen=True)
@@ -398,6 +401,13 @@ class FusedAttnFwdPrimitive(BasePrimitive):
             *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
             bias_batch = reduce(operator.mul, bias_batch_shape)

+        if config.cp_striped_window_size is not None:
+            window_size_left = config.cp_striped_window_size[0]
+            window_size_right = config.cp_striped_window_size[1]
+        else:
+            window_size_left = config.window_size[0]
+            window_size_right = config.window_size[1]
+
         return ffi.ffi_lowering(FusedAttnFwdPrimitive.name)(
             ctx,
             q,
@@ -429,8 +439,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
             qkv_layout=int(config.qkv_layout.value),
             is_training=config.is_training,
             deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
-            window_size_left=config.window_size[0],
-            window_size_right=config.window_size[1],
+            window_size_left=window_size_left,
+            window_size_right=window_size_right,
         )

     @staticmethod
@@ -790,6 +800,13 @@ class FusedAttnBwdPrimitive(BasePrimitive):
             *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
             bias_batch = reduce(operator.mul, bias_batch_shape)

+        if config.cp_striped_window_size is not None:
+            window_size_left = config.cp_striped_window_size[0]
+            window_size_right = config.cp_striped_window_size[1]
+        else:
+            window_size_left = config.window_size[0]
+            window_size_right = config.window_size[1]
+
         return ffi.ffi_lowering(FusedAttnBwdPrimitive.name)(
             ctx,
             q,
@@ -824,8 +841,8 @@ class FusedAttnBwdPrimitive(BasePrimitive):
             qkv_layout=int(config.qkv_layout.value),
             is_training=config.is_training,
             deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
-            window_size_left=config.window_size[0],
-            window_size_right=config.window_size[1],
+            window_size_left=window_size_left,
+            window_size_right=window_size_right,
         )

     @staticmethod
@@ -1177,6 +1194,7 @@ class _FusedAttnCPWithAllGatherHelper:
             window_size=self.config.window_size,
             context_parallel_load_balanced=self.config.context_parallel_load_balanced,
             cp_axis=self.config.cp_axis,
+            cp_striped_window_size=None,
         )

     def all_gather_kv(self, k, v):
@@ -1616,6 +1634,16 @@ class _FusedAttnCPWithP2PHelper:
                     " NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 in your environment"
                 )

+        # If using scanloop, idx in scan_kv_block() will be a traced device value, but
+        # adjust_cp_striped_window_size() requires all of its parameters to be host values
+        is_context_parallel = get_mesh_axis_size(self.config.cp_axis, self.mesh) > 1
+        is_thd_layout = self.config.qkv_layout.is_thd()
+        is_sliding_window = self.config.window_size[0] != -1
+        if is_context_parallel and is_thd_layout and is_sliding_window and self.use_scanloop():
+            raise ValueError(
+                f"{header} with THD format and sliding window does not support using scan loop"
+            )
+
     def get_step_config(self, attn_mask_type) -> _FusedAttnConfig:
         """Returns a _FusedAttnConfig for single CP step call to fused attention."""
         return _FusedAttnConfig(
@@ -1629,6 +1657,7 @@ class _FusedAttnCPWithP2PHelper:
             window_size=self.config.window_size,
             context_parallel_load_balanced=self.config.context_parallel_load_balanced,
             cp_axis=self.config.cp_axis,
+            cp_striped_window_size=None,
        )

     def stack_kv(self, k, v):
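The host-value requirement enforced by the check_supported() addition above is the standard JAX tracing constraint: under jax.lax.scan (the scanloop path), the loop index is a tracer, while adjust_cp_striped_window_size() (added further below) needs plain Python integers. A minimal illustration, independent of this diff; the function name and numbers here are hypothetical:

def kv_src_rank(idx, cp_rank=3, cp_size=8):
    # Same modular arithmetic the ring step uses to find which rank a KV
    # block originally came from at iteration `idx` (hypothetical example values).
    return (cp_size + cp_rank - idx) % cp_size

print(kv_src_rank(2))  # idx is a host int here, so this is plain Python -> prints 1
# Under jax.lax.scan, idx would be a tracer: the arithmetic itself still traces,
# but host-only steps such as `if idx > 0:` or `int(idx)` raise a
# ConcretizationTypeError, and adjust_cp_striped_window_size() performs exactly
# that kind of host-side integer logic. Hence the scan-loop restriction above.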
@@ -2100,6 +2129,67 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
 register_primitive(FusedRingAttnBwdPrimitive)

+def adjust_cp_striped_window_size(q_pos0, kv_pos0, cp_size, window_size):
+    """
+    Adjust the window size by cp_size for striped sharding, where both q_pos and
+    kv_pos are arithmetic sequences like [x, x+cp_size, x+2*cp_size, ...].
+
+    Example 1:
+        q_pos = kv_pos = [0, 8, 16, 24, 32], cp_size = 8, window_size = (15, 0).
+        q_pos = 32 can look at kv_pos in [24, 32]. The effective mask is:
+             0  8 16 24 32
+            ----------------
+         0 | 1  0  0  0  0
+         8 | 1  1  0  0  0
+        16 | 0  1  1  0  0
+        24 | 0  0  1  1  0
+        32 | 0  0  0  1  1
+        SequenceDescriptor outputs: {q,kv}_seqlen = [5, ...], {q,kv}_seq_offsets = [0, ...].
+        Adjusted window size = (1, 0).
+
+    Example 2:
+        q_pos = [0, 8, 16, 24, 32], kv_pos = [1, 9, 17, 25, 33], cp_size = 8,
+        window_size = (15, 0). The effective mask is:
+             1  9 17 25 33
+            ----------------
+         0 | 0  0  0  0  0
+         8 | 1  0  0  0  0
+        16 | 1  1  0  0  0
+        24 | 0  1  1  0  0
+        32 | 0  0  1  1  0
+        SequenceDescriptor outputs:
+            q_seqlen = [4, ...], q_seq_offsets = [1, ...],
+            kv_seqlen = [4, ...], kv_seq_offsets = [0, ...].
+        If the diagonal were all 1, the left window size would be 2. Since the
+        diagonal is all 0 here, we must use left window size = 2 - 1 = 1 to make
+        cuDNN work.
+
+    Example 3:
+        q_pos = [7, 15, 23, 31, 39], kv_pos = [0, 8, 16, 24, 32], cp_size = 8,
+        window_size = (22, 0). The effective mask is:
+             0  8 16 24 32
+            ----------------
+         7 | 1  0  0  0  0
+        15 | 1  1  0  0  0
+        23 | 0  1  1  0  0
+        31 | 0  0  1  1  0
+        39 | 0  0  0  1  1
+        SequenceDescriptor outputs: {q,kv}_seqlen = [5, ...], {q,kv}_seq_offsets = [0, ...].
+        Adjusted window size = (1, 0).
+    """
+    left_limit = q_pos0 - window_size[0]
+    right_limit = q_pos0 + window_size[1]
+    # Count how many left/right steps of size cp_size we can take from kv_pos0 -/+ cp_size
+    left_steps = (kv_pos0 - cp_size - left_limit) // cp_size + 1
+    right_steps = (right_limit - kv_pos0 - cp_size) // cp_size + 1
+    left_steps = max(left_steps, 0)
+    right_steps = max(right_steps, 0)
+    # If kv_pos0 > q_pos0, the diagonal is all 0, so we must reduce the left window size by 1
+    shift = 1 if kv_pos0 > q_pos0 else 0
+    left_steps = left_steps - shift
+    return left_steps, right_steps
+
+
 class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
     """
     Fused Striped Ring Attention Forward Primitive
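To sanity-check the arithmetic, the three docstring examples can be replayed directly against the helper added above. This snippet is illustrative and not part of the commit:

# Illustrative check of the docstring examples (not part of this commit).
# Example 1: identical q/kv positions, causal window (15, 0) -> one cp_size step left.
assert adjust_cp_striped_window_size(0, 0, 8, (15, 0)) == (1, 0)
# Example 2: kv_pos0 > q_pos0 zeroes the diagonal, so the left window shrinks by 1.
assert adjust_cp_striped_window_size(0, 1, 8, (15, 0)) == (1, 0)
# Example 3: q_pos0 = 7 vs kv_pos0 = 0 with window (22, 0) still yields one left step.
assert adjust_cp_striped_window_size(7, 0, 8, (22, 0)) == (1, 0)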
@@ -2108,9 +2198,6 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
     @staticmethod
     def partition(config, mesh, arg_infos, result_infos):
         is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
-        assert (
-            not is_context_parallel or config.window_size[0] == -1
-        ), "Sliding window attention is not supported when context parallelism is enabled"
         if not is_context_parallel:
             return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)
@@ -2156,6 +2243,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
         subblock_config = config

         cp_size = get_mesh_axis_size(config.cp_axis, mesh)
+        cp_rank = get_mesh_axis_rank_host(config.cp_axis, mesh)
         cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]

         batch, q_max_seqlen, head, _ = q.shape
@@ -2176,22 +2264,36 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
             kv_segment_ids_next = helper.permute_kv(kv_segment_ids, cp_perm)
             kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm)

-            output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl(
-                q,
-                kv,
-                _not_used,
-                bias,
-                seed,
-                q_seqlen,
-                kv_seqlen,
-                q_seq_offsets,
-                k_seq_offsets,
-                q_segment_ids,
-                kv_segment_ids,
-                q_segment_pos,
-                kv_segment_pos,
-                subblock_config,
-            )
+            def compute(config):
+                return FusedAttnFwdPrimitive.impl(
+                    q,
+                    kv,
+                    _not_used,
+                    bias,
+                    seed,
+                    q_seqlen,
+                    kv_seqlen,
+                    q_seq_offsets,
+                    k_seq_offsets,
+                    q_segment_ids,
+                    kv_segment_ids,
+                    q_segment_pos,
+                    kv_segment_pos,
+                    config,
+                )
+
+            if config.window_size != (-1, -1):
+                kv_src_rank = (cp_size + cp_rank - idx) % cp_size
+                # Note: all inputs of adjust_cp_striped_window_size should be host values
+                cp_striped_window_size = adjust_cp_striped_window_size(
+                    cp_rank, kv_src_rank, cp_size, config.window_size
+                )
+                current_config = replace(
+                    subblock_config, cp_striped_window_size=cp_striped_window_size
+                )
+            else:
+                current_config = subblock_config
+            output_per_step, softmax_aux_per_step, _ = compute(current_config)

             softmax_aux_per_step = softmax_aux_per_step.reshape((batch, q_max_seqlen, head, 1))
@@ -2244,9 +2346,6 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
     @staticmethod
     def partition(config, mesh, arg_infos, result_infos):
         is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
-        assert (
-            not is_context_parallel or config.window_size[0] == -1
-        ), "Sliding window attention is not supported when context parallelism is enabled"
         if not is_context_parallel:
             return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)
@@ -2290,13 +2389,15 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
         subblock_config = config

         cp_size = get_mesh_axis_size(config.cp_axis, mesh)
+        # We need cp_rank to be a host value for adjust_cp_striped_window_size()
+        cp_rank = get_mesh_axis_rank_host(config.cp_axis, mesh)
         cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]

         dq = jnp.zeros_like(q)
         dkv = jnp.zeros_like(kv)
         dbias = jnp.zeros_like(bias)

-        def scan_kv_block(_idx, carry):
+        def scan_kv_block(idx, carry):
             kv, kv_segment_ids, kv_segment_pos, dq, dkv, dbias = carry

             # Start communication that feeds the next iteration.
@@ -2306,7 +2407,7 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
             kv_segment_ids_next = helper.permute_kv(kv_segment_ids, cp_perm)
             kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm)

-            def compute():
+            def compute(config):
                 dq_per_step, dkv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl(
                     q,
                     kv,
@@ -2324,11 +2425,22 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
                     kv_segment_ids,
                     q_segment_pos,
                     kv_segment_pos,
-                    config=subblock_config,
+                    config=config,
                 )
                 return dq_per_step, dkv_per_step, dbias_per_step

-            dq_per_step, dkv_per_step, dbias_per_step = compute()
+            if config.window_size != (-1, -1):
+                kv_src_rank = (cp_size + cp_rank - idx) % cp_size
+                # Note: all inputs of adjust_cp_striped_window_size should be host values
+                cp_striped_window_size = adjust_cp_striped_window_size(
+                    cp_rank, kv_src_rank, cp_size, config.window_size
+                )
+                current_config = replace(
+                    subblock_config, cp_striped_window_size=cp_striped_window_size
+                )
+            else:
+                current_config = subblock_config
+            dq_per_step, dkv_per_step, dbias_per_step = compute(current_config)

             kv_next, dkv = jnp.unstack(kv_dkv)
             dq += dq_per_step
@@ -2462,6 +2574,7 @@ def fused_attn_fwd(
         window_size=(-1, -1) if window_size is None else window_size,
         context_parallel_load_balanced=context_parallel_causal_load_balanced,
         cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
+        cp_striped_window_size=None,
     )

     primitive = None
@@ -2583,6 +2696,7 @@ def fused_attn_bwd(
         window_size=(-1, -1) if window_size is None else window_size,
         context_parallel_load_balanced=context_parallel_causal_load_balanced,
         cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
+        cp_striped_window_size=None,
     )

     primitive = None
@@ -18,6 +18,7 @@ from jax.interpreters import pxla
 import jax
 import jax.numpy as jnp
 from jax.sharding import PartitionSpec
+import numpy as np

 _PXLA_THREAD_RESOURCES = pxla.thread_resources
@@ -201,6 +202,31 @@ def get_mesh_axis_rank(axis: str, mesh=None):
     return jax.lax.axis_index(axis_name)

+def get_mesh_axis_rank_host(axis, mesh) -> int:
+    """
+    Same as get_mesh_axis_rank(), but returns a host value instead of a
+    traced device value.
+    """
+    if axis not in mesh.axis_names:
+        raise ValueError(f"Axis {axis} not found in mesh axis names: {mesh.axis_names}")
+    axis_index = mesh.axis_names.index(axis)
+    # mesh.devices is an ndarray of Device objects
+    devices = mesh.devices
+    local_device = jax.devices()[jax.process_index()]  # Pick one device on this host
+    # Find the coordinates of local_device in mesh.devices
+    coords = np.argwhere(devices == local_device)
+    if coords.size == 0:
+        raise ValueError(f"Local device {local_device} not found in mesh.devices.")
+    coords = tuple(coords[0])  # Coordinates in the mesh array
+    # The rank along the requested axis is the coordinate at axis_index
+    rank = coords[axis_index]
+    return int(rank)
+
+
 @dataclass
 class MeshResource:
     """A data container for managing mesh resources in distributed training.