"tests/vscode:/vscode.git/clone" did not exist on "8b3260599a0e2fab382717e58860bc7184717fbf"
Unverified Commit 39c0e709 authored by wdykas, committed by GitHub

Redo symmetric memory merge request (#1682)



* re merge request
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add docstring
Signed-off-by: Peter Dykas <wdykas@nvidia.com>

---------
Signed-off-by: Peter Dykas <wdykas@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 4e036c8c
@@ -31,6 +31,13 @@ from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor

try:
    import torch.distributed._symmetric_memory as symm_mem

    HAS_TORCH_SYMMETRIC = True
except ImportError:
    HAS_TORCH_SYMMETRIC = False

__all__ = ["checkpoint", "CudaRNGStatesTracker"]
@@ -1260,6 +1267,152 @@ def gather_along_first_dim(
    return out, handle

# Global cache to store symmetric memory tensors
symmetric_mem_cache = {}


def get_symmetric_memory_tensor(tensor_numel, tensor_dtype, tensor_device, tp_group, tag=None):
    """
    Gets or creates a symmetric memory tensor with specified properties.

    Reuses cached tensors when available to avoid redundant creation and rendezvous operations.
    Note: This function always returns a 1D tensor.

    Parameters
    ----------
    tensor_numel : int
        Number of elements in the tensor.
    tensor_dtype : torch.dtype
        Data type of the tensor.
    tensor_device : torch.device
        Device on which to allocate the tensor.
    tp_group : dist_group_type
        Process group for rendezvous operation.
    tag : Any, optional
        Optional identifier to further distinguish tensors.

    Returns
    -------
    torch.Tensor
        A symmetric memory tensor with the specified properties.
    """
    # Create a cache key based on tensor properties and group
    cache_key = (tensor_numel, tensor_dtype, tensor_device, tp_group.group_name, tag)

    # Check if we already have a symmetric memory tensor for this configuration
    if cache_key not in symmetric_mem_cache:
        # Create a new symmetric memory tensor if not in cache
        msg = symm_mem.empty(
            tensor_numel,
            dtype=tensor_dtype,
            device=tensor_device,
        )
        # Perform the rendezvous once for this tensor
        symm_mem.rendezvous(msg, group=tp_group)
        # Store in cache
        symmetric_mem_cache[cache_key] = msg
    else:
        # Reuse the existing symmetric memory tensor
        msg = symmetric_mem_cache[cache_key]

    return msg

def symmetric_all_reduce(
    inp: torch.Tensor,
    tp_group: Optional[dist_group_type] = None,
    async_op: bool = False,
    all_reduce_type: str = "multimem_all_reduce",
):
    """
    Performs an all-reduce operation across multiple processes using symmetric memory.

    If the input tensor is already in the symmetric memory cache, we can avoid copy
    overheads by using the input tensor directly for the all-reduce. Externally
    created symmetric memory tensors that are not in the cache currently cannot
    avoid the extra copies.

    Parameters
    ----------
    inp : torch.Tensor
        The input tensor to be reduced. The operation is performed in-place.
    tp_group : Optional[dist_group_type], default=None
        The process group over which to perform the all-reduce operation.
        If None, the default process group is used.
    async_op : bool, default=False
        Whether to perform the operation asynchronously.
        Note: Currently only synchronous operations are supported for symmetric memory variants.
    all_reduce_type : str, default="multimem_all_reduce"
        The type of all-reduce implementation to use. Options include:
        - "nccl": Standard PyTorch distributed all-reduce
        - "multimem_all_reduce": Multimem symmetric all-reduce
        - "two_shot": Two-shot symmetric all-reduce
        - "one_shot": One-shot symmetric all-reduce

    Returns
    -------
    Tuple[torch.Tensor, Optional[torch.distributed.Work]]
        - The first element is the input tensor with the all-reduce result.
        - The second element is the async work handle if async_op=True,
          otherwise None.
    """
    assert async_op is False, "Async symmetric ops not supported yet"
    assert HAS_TORCH_SYMMETRIC, "Could not import symmetric memory from torch"

    if get_distributed_world_size(tp_group) == 1:
        return inp, None

    if all_reduce_type == "nccl":
        # Standard all-reduce implementation
        handle = torch.distributed.all_reduce(inp, group=tp_group, async_op=async_op)
        return inp, handle

    all_reduce_impl = None
    if all_reduce_type == "multimem_all_reduce":
        all_reduce_impl = torch.ops.symm_mem.multimem_all_reduce_
    elif all_reduce_type == "two_shot":
        all_reduce_impl = torch.ops.symm_mem.two_shot_all_reduce_
    elif all_reduce_type == "one_shot":
        all_reduce_impl = torch.ops.symm_mem.one_shot_all_reduce
    else:
        raise TypeError(f"All reduce type {all_reduce_type} is not supported.")

    group_name = tp_group.group_name
    tensor_shape = inp.shape
    tensor_numel = inp.numel()
    tensor_dtype = inp.dtype
    tensor_device = inp.device

    # Check if the input tensor is already in the symmetric memory cache.
    # If it is, we can avoid copy overheads.
    input_id = id(inp)
    is_cached = any(id(cached_tensor) == input_id for cached_tensor in symmetric_mem_cache.values())
    if is_cached:
        all_reduce_impl(
            inp,
            "sum",
            group_name,
        )
    else:
        # Get symmetric memory tensor. Build or retrieve from cache.
        msg = get_symmetric_memory_tensor(tensor_numel, tensor_dtype, tensor_device, tp_group)
        msg.copy_(inp.reshape(-1))
        all_reduce_impl(
            msg,
            "sum",
            group_name,
        )
        # Copy the result back to the input tensor
        inp.copy_(msg.reshape(tensor_shape))

    return inp, None

def allreduce(
    inp: torch.Tensor,
    tp_group: Optional[dist_group_type] = None,
...
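For context, here is a minimal usage sketch of the new helper (not part of the diff); the launch setup, tensor shape, and import path are illustrative assumptions based on the file modified above.

import os

import torch
import torch.distributed as dist

# Assumption: the helper lives in transformer_engine/pytorch/distributed.py,
# the file this diff extends.
from transformer_engine.pytorch.distributed import symmetric_all_reduce

# Assumes a torchrun launch with one process per GPU on a single node.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
tp_group = dist.new_group(backend="nccl")

# Partial result produced by this rank, e.g. the output of a row-parallel GEMM.
out = torch.randn(4096, 1024, device="cuda", dtype=torch.bfloat16)

# The first call allocates a symmetric-memory buffer for this
# (numel, dtype, device, group) key and performs the rendezvous; later calls
# with the same key reuse the cached buffer and skip the rendezvous.
# "multimem_all_reduce" needs hardware multicast support; "two_shot" or
# "one_shot" are alternatives on other systems.
out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type="multimem_all_reduce")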
@@ -15,6 +15,7 @@ from torch.nn import init
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
from .base import (
    get_workspace,
    get_ub,
@@ -41,6 +42,7 @@ from ..distributed import (
    set_tensor_model_parallel_attributes,
    get_distributed_world_size,
    allreduce,
    symmetric_all_reduce,
    reduce_scatter_along_first_dim,
    gather_along_first_dim,
    in_fp8_activation_recompute_phase,
@@ -120,6 +122,7 @@ class _LayerNormLinear(torch.autograd.Function):
        fsdp_group: Union[dist_group_type, None],
        module: torch.nn.Module,
        skip_fp8_weight_update: bool,
        symmetric_ar_type: str,
        debug: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        # pylint: disable=missing-function-docstring
@@ -445,7 +448,10 @@ class _LayerNormLinear(torch.autograd.Function):
            if sequence_parallel:
                out, _ = reduce_scatter_along_first_dim(out, tp_group)
            elif tensor_parallel:
                if symmetric_ar_type is not None:
                    out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
                else:
                    out, _ = allreduce(out, tp_group)
            nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")

        # [*, in_features] -> [*, out_features] except first dimension changes for SP
@@ -896,6 +902,7 @@ class _LayerNormLinear(torch.autograd.Function):
            None, # debug
            None, # module
            None, # skip_fp8_weight_update
            None, # symmetric_ar_type
        )
@@ -985,6 +992,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
                 it controls the type used to allocate the initial parameters. Useful when
                 the model is trained with lower precision and the original FP32 parameters
                 would not fit in GPU memory.
    symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
                 Type of symmetric memory all-reduce to use during the forward pass.
                 This can help in latency-bound communication situations.
                 Requires PyTorch version 2.7.0 or higher. When set to None, the standard
                 all-reduce is used.
    """
    def __init__(
@@ -1014,6 +1026,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
        ub_bulk_wgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_name: Optional[str] = None,
        symmetric_ar_type: Optional[str] = None,
        name: str = None,
    ) -> None:
        super().__init__()
@@ -1030,6 +1043,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
        self.return_layernorm_output = return_layernorm_output
        self.return_layernorm_output_gathered = return_layernorm_output_gathered
        self.zero_centered_gamma = zero_centered_gamma
        self.symmetric_ar_type = symmetric_ar_type
        self.name = name

        if TEDebugState.debug_enabled:
@@ -1099,6 +1113,13 @@ class LayerNormLinear(TransformerEngineBaseModule):
            assert ub_name is not None, "Userbuffer name [string] is not set."
        self.ub_name = ub_name

        if self.symmetric_ar_type is not None:
            assert torch_version() >= (
                2,
                7,
                0,
            ), "Torch version must be at least 2.7 to use symmetric memory"

        self.eps = eps
        layer_norm_weight = torch.nn.Parameter(
            torch.empty(self.in_features, device=device, dtype=params_dtype)
@@ -1433,6 +1454,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
            self.fsdp_group,
            self,
            skip_fp8_weight_update,
            self.symmetric_ar_type,
            debug,
        )
        out = fwd_fn(*args)
...
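As a usage sketch for the LayerNormLinear changes above (illustrative sizes and group setup, reusing the tp_group from the earlier sketch; not part of the diff):

import torch.distributed as dist
import transformer_engine.pytorch as te

# Row-parallel projection: the forward all-reduce over the GEMM output is the
# call switched to symmetric_all_reduce when symmetric_ar_type is set.
# Requires PyTorch >= 2.7; the constructor asserts this.
ln_linear = te.LayerNormLinear(
    4096,
    4096,
    parallel_mode="row",
    tp_group=tp_group,  # assumed to exist from the earlier sketch
    tp_size=dist.get_world_size(tp_group),
    symmetric_ar_type="multimem_all_reduce",
    device="cuda",
)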
@@ -16,6 +16,7 @@ from torch.nn import init
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
from .base import (
    get_workspace,
    _ub_communicators,
@@ -47,6 +48,7 @@ from ..distributed import (
    set_tensor_model_parallel_attributes,
    get_distributed_world_size,
    allreduce,
    symmetric_all_reduce,
    reduce_scatter_along_first_dim,
    gather_along_first_dim,
    use_reentrant_activation_recompute,
@@ -191,6 +193,7 @@ class _LayerNormMLP(torch.autograd.Function):
        fsdp_group: Union[dist_group_type, None],
        module: torch.nn.Module,
        skip_fp8_weight_update: bool,
        symmetric_ar_type: str,
        debug: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        # pylint: disable=missing-function-docstring
@@ -590,7 +593,12 @@ class _LayerNormMLP(torch.autograd.Function):
        elif set_parallel_mode and sequence_parallel:
            fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group)
        elif set_parallel_mode and tensor_parallel:
            if symmetric_ar_type is not None:
                fc2_out, _ = symmetric_all_reduce(
                    fc2_out, tp_group, all_reduce_type=symmetric_ar_type
                )
            else:
                fc2_out, _ = allreduce(fc2_out, tp_group)

        # [*, in_features] -> [*, out_features] except first dimension changes for SP
        fc2_out = fc2_out.view(-1, *inp_shape[1:-1], fc2_out.shape[-1])
@@ -1190,6 +1198,7 @@ class _LayerNormMLP(torch.autograd.Function):
            None, # fsdp_group
            None, # module
            None, # skip_fp8_weight_update
            None, # symmetric_ar_type
            None, # debug
        )
@@ -1287,6 +1296,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 batch size per training step. Needed for JIT Warmup, a technique where jit
                 fused functions are warmed up before training to ensure same kernels are
                 used for forward propogation and activation recompute phase.
    symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
                 Type of symmetric memory all-reduce to use during the forward pass.
                 This can help in latency-bound communication situations.
                 Requires PyTorch version 2.7.0 or higher. When set to None, the standard
                 all-reduce is used.
    """
    def __init__(
@@ -1319,6 +1333,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_bulk_wgrad: bool = False,
        symmetric_ar_type: Optional[str] = None,
    ) -> None:
        super().__init__()
@@ -1337,6 +1352,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
        )
        self.set_parallel_mode = set_parallel_mode
        self.zero_centered_gamma = zero_centered_gamma
        self.symmetric_ar_type = symmetric_ar_type

        # GEMM-GELU fusion is currently only supported with split GEMM-AG overlap
        self.gemm_gelu_fusion = (
@@ -1376,6 +1392,13 @@ class LayerNormMLP(TransformerEngineBaseModule):
            ub_bulk_dgrad and self.sequence_parallel and not self.ub_overlap_rs_dgrad
        )

        if self.symmetric_ar_type is not None:
            assert torch_version() >= (
                2,
                7,
                0,
            ), "Torch version must be at least 2.7 to use symmetric memory"

        # Initialize params in FP8
        with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()
@@ -1651,6 +1674,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
            self.fsdp_group,
            self,
            skip_fp8_weight_update,
            self.symmetric_ar_type,
            debug,
        )
        out = fwd_fn(*args)
...
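The same option on LayerNormMLP, where FC2's row-parallel output reduction is the call being replaced (again an illustrative sketch, not part of the diff):

import torch.distributed as dist
import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(
    4096,    # hidden size (illustrative)
    16384,   # FFN hidden size (illustrative)
    set_parallel_mode=True,
    tp_group=tp_group,  # assumed to exist from the earlier sketch
    tp_size=dist.get_world_size(tp_group),
    symmetric_ar_type="two_shot",
    device="cuda",
)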
@@ -12,6 +12,7 @@ import torch
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
from .base import (
    get_workspace,
    get_ub,
@@ -39,6 +40,7 @@ from ..distributed import (
    set_tensor_model_parallel_attributes,
    get_distributed_world_size,
    allreduce,
    symmetric_all_reduce,
    reduce_scatter_along_first_dim,
    gather_along_first_dim,
    is_fp8_activation_recompute_enabled,
@@ -66,7 +68,6 @@ from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled

__all__ = ["Linear"]
@@ -110,6 +111,7 @@ class _Linear(torch.autograd.Function):
        fsdp_group: Union[dist_group_type, None],
        module: torch.nn.Module,
        skip_fp8_weight_update: bool,
        symmetric_ar_type: str,
        debug: Optional[bool] = False,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
@@ -387,7 +389,10 @@ class _Linear(torch.autograd.Function):
            if sequence_parallel:
                out, _ = reduce_scatter_along_first_dim(out, tp_group)
            elif tensor_parallel:
                if symmetric_ar_type is not None:
                    out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
                else:
                    out, _ = allreduce(out, tp_group)
            nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")

        out = out.view(-1, *inp_shape[1:-1], out_features)
@@ -782,6 +787,7 @@ class _Linear(torch.autograd.Function):
            None, # fsdp_group
            None, # module
            None, # skip_fp8_weight_update
            None, # symmetric_ar_type
            None, # debug
        )
@@ -855,7 +861,11 @@ class Linear(TransformerEngineBaseModule):
                 it controls the type used to allocate the initial parameters. Useful when
                 the model is trained with lower precision and the original FP32 parameters
                 would not fit in GPU memory.
    symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
                 Type of symmetric memory all-reduce to use during the forward pass.
                 This can help in latency-bound communication situations.
                 Requires PyTorch version 2.7.0 or higher. When set to None, the standard
                 all-reduce is used.
    """
    def __init__(
@@ -881,6 +891,7 @@ class Linear(TransformerEngineBaseModule):
        ub_bulk_dgrad: bool = False,
        ub_bulk_wgrad: bool = False,
        ub_name: Optional[str] = None,
        symmetric_ar_type: Optional[str] = None,
        name: Optional[str] = None,
    ) -> None:
        super().__init__()
@@ -894,6 +905,7 @@ class Linear(TransformerEngineBaseModule):
        self.apply_bias = bias and not return_bias
        self.get_rng_state_tracker = get_rng_state_tracker
        self.rng_tracker_name = rng_tracker_name
        self.symmetric_ar_type = symmetric_ar_type
        self.name = name

        if TEDebugState.debug_enabled:
@@ -963,6 +975,13 @@ class Linear(TransformerEngineBaseModule):
            assert ub_name is not None, f"Comm+GEMM overlap layer '{ub_name}' is not initialized."
        self.ub_name = ub_name

        if self.symmetric_ar_type is not None:
            assert torch_version() >= (
                2,
                7,
                0,
            ), "Torch version must be at least 2.7 to use symmetric memory"

        # Initialize params in FP8
        with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()
@@ -1248,6 +1267,7 @@ class Linear(TransformerEngineBaseModule):
            self.fsdp_group,
            self,
            skip_fp8_weight_update,
            self.symmetric_ar_type,
            debug,
        )
        out = linear_fn(*args)
...
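And a sketch for Linear, including a forward pass; the local input shard size assumes the usual row-parallel convention of splitting in_features across the group, and the names are illustrative rather than part of the diff.

import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

tp_size = dist.get_world_size(tp_group)  # tp_group assumed from the earlier sketch
out_proj = te.Linear(
    4096,
    4096,
    parallel_mode="row",
    tp_group=tp_group,
    tp_size=tp_size,
    symmetric_ar_type="one_shot",
    device="cuda",
)

# Row-parallel input shard: last dimension is in_features // tp_size.
x_local = torch.randn(8, 4096 // tp_size, device="cuda")
y = out_proj(x_local)  # local GEMM, then symmetric all-reduce across tp_group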