Commit ec7c8bc3 authored by dongcl

Replace TE (Transformer Engine) linear layers with Flux layers when Flux overlap is enabled

parent 138b70a2
@@ -190,27 +190,17 @@ class CoreAdaptation(MegatronAdaptationABC):
         # flux
         if os.getenv("USE_FLUX_OVERLAP", 0):
-            import flux
             from ..core.tensor_parallel import (
-                ColumnParallelLinearPatch,
-                RowParallelLinearPatch,
-                column_parallel_linear_init_wrapper,
-                row_parallel_linear_init_wrapper
+                FluxColumnParallelLinear,
+                FluxRowParallelLinear
             )
             from ..core.models.gpt.gpt_layer_specs import get_gpt_layer_with_flux_spec

-            MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__",
-                                        column_parallel_linear_init_wrapper,
-                                        apply_wrapper=True)
-            MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward",
-                                        ColumnParallelLinearPatch.forward)
-            MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.__init__",
-                                        row_parallel_linear_init_wrapper,
-                                        apply_wrapper=True)
-            MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.forward",
-                                        RowParallelLinearPatch.forward)
-            MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_local_spec",
+            MegatronAdaptation.register("megatron.core.extensions.transformer_engine.TEColumnParallelLinear",
+                                        FluxColumnParallelLinear)
+            MegatronAdaptation.register("megatron.core.extensions.transformer_engine.TERowParallelLinear",
+                                        FluxRowParallelLinear)
+            MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
                                         get_gpt_layer_with_flux_spec)

     def patch_training(self):
...
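Instead of wrapping the local `ColumnParallelLinear`/`RowParallelLinear` in place, the adaptation now re-targets the Transformer Engine classes and the TE layer-spec factory wholesale. Note that `os.getenv("USE_FLUX_OVERLAP", 0)` returns the raw string, so any non-empty value (including `"0"`) enables the branch. Below is a minimal sketch of the dotted-path substitution idea, assuming the real `MegatronAdaptation.register` works in a similar spirit; it is illustrative only, not the project's implementation.

```python
# Illustrative sketch: replace a module attribute named by a dotted path.
import importlib


def apply_patch(target_path: str, replacement) -> None:
    """Swap the attribute named by ``target_path`` for ``replacement``."""
    module_path, attr_name = target_path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    setattr(module, attr_name, replacement)


# Hypothetical usage mirroring the registrations above:
# apply_patch("megatron.core.extensions.transformer_engine.TEColumnParallelLinear",
#             FluxColumnParallelLinear)
```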
@@ -3,7 +3,6 @@ from typing import Optional
 from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
 from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
-from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
 from megatron.core.transformer.enums import AttnMaskType
 from megatron.core.transformer.identity_op import IdentityOp
@@ -17,6 +16,9 @@ from megatron.core.transformer.transformer_layer import (
     TransformerLayer,
     TransformerLayerSubmodules,
 )
+from dcu_megatron.core.tensor_parallel.layers import FluxColumnParallelLinear, FluxRowParallelLinear
 from megatron.core.utils import is_te_min_version

 try:
@@ -79,13 +81,13 @@ def get_gpt_layer_with_flux_spec(
             module=MLASelfAttention,
             params={"attn_mask_type": AttnMaskType.causal},
             submodules=MLASelfAttentionSubmodules(
-                linear_q_proj=ColumnParallelLinear,
-                linear_q_down_proj=ColumnParallelLinear,
-                linear_q_up_proj=ColumnParallelLinear,
-                linear_kv_down_proj=ColumnParallelLinear,
-                linear_kv_up_proj=ColumnParallelLinear,
+                linear_q_proj=FluxColumnParallelLinear,
+                linear_q_down_proj=FluxColumnParallelLinear,
+                linear_q_up_proj=FluxColumnParallelLinear,
+                linear_kv_down_proj=FluxColumnParallelLinear,
+                linear_kv_up_proj=FluxColumnParallelLinear,
                 core_attention=TEDotProductAttention,
-                linear_proj=RowParallelLinear,
+                linear_proj=FluxRowParallelLinear,
                 q_layernorm=TENorm if qk_layernorm else IdentityOp,
                 kv_layernorm=TENorm if qk_layernorm else IdentityOp,
             ),
@@ -111,9 +113,9 @@ def get_gpt_layer_with_flux_spec(
             module=SelfAttention,
             params={"attn_mask_type": AttnMaskType.causal},
             submodules=SelfAttentionSubmodules(
-                linear_qkv=ColumnParallelLinear,
+                linear_qkv=FluxColumnParallelLinear,
                 core_attention=TEDotProductAttention,
-                linear_proj=RowParallelLinear,
+                linear_proj=FluxRowParallelLinear,
                 q_layernorm=qk_norm if qk_layernorm else IdentityOp,
                 k_layernorm=qk_norm if qk_layernorm else IdentityOp,
             ),
@@ -145,8 +147,8 @@ def get_mlp_module_flux_spec(
         return ModuleSpec(
             module=MLP,
             submodules=MLPSubmodules(
-                linear_fc1=ColumnParallelLinear,
-                linear_fc2=RowParallelLinear,
+                linear_fc1=FluxColumnParallelLinear,
+                linear_fc2=FluxRowParallelLinear,
             ),
         )
     else:
...
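Because the layer spec is plain data that names which classes to instantiate, swapping in the Flux linears is a per-field change and no layer code has to be touched. A minimal, self-contained sketch of that idea follows; the names here are illustrative stand-ins, not the project's API.

```python
from dataclasses import dataclass
from typing import Type


class ColumnLinear: ...        # stand-ins for the real parallel linear classes
class RowLinear: ...
class FluxColumnLinear: ...
class FluxRowLinear: ...


@dataclass
class MLPSubmodulesSketch:
    linear_fc1: Type
    linear_fc2: Type


def make_mlp_submodules(use_flux: bool) -> MLPSubmodulesSketch:
    # The choice is made once at spec-construction time; the model builder
    # later instantiates whichever classes the spec names.
    if use_flux:
        return MLPSubmodulesSketch(linear_fc1=FluxColumnLinear, linear_fc2=FluxRowLinear)
    return MLPSubmodulesSketch(linear_fc1=ColumnLinear, linear_fc2=RowLinear)


print(make_mlp_submodules(use_flux=True))
```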
 from .layers import (
-    column_parallel_linear_init_wrapper,
-    row_parallel_linear_init_wrapper,
-    ColumnParallelLinearPatch,
-    RowParallelLinearPatch,
+    FluxColumnParallelLinear,
+    FluxRowParallelLinear,
     vocab_parallel_embedding_forward,
     vocab_parallel_embedding_init,
 )
\ No newline at end of file
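With these re-exports in place, the adaptation code in this commit can import the layers from the package rather than from `layers.py`; a usage sketch (assuming the package layout shown above):

```python
from dcu_megatron.core.tensor_parallel import (
    FluxColumnParallelLinear,
    FluxRowParallelLinear,
)
```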
@@ -5,9 +5,8 @@ from typing import Callable, List, Optional

 try:
     import flux
-    HAS_FLUX = True
 except ImportError:
-    HAS_FLUX = False
+    raise ImportError("flux is NOT installed")

 import torch
 import torch.nn.functional as F
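The optional-import flag is gone: this module now requires flux unconditionally, so a missing install fails at import time rather than at the first forward pass. A possible refinement, not part of the commit, is to chain the original exception so the underlying failure stays visible:

```python
# Sketch of a fail-fast import with exception chaining (assumption: flux here
# is the fused GEMM+communication kernel library used by these layers).
try:
    import flux
except ImportError as exc:
    raise ImportError("flux is NOT installed; this module requires it") from exc
```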
@@ -39,6 +38,10 @@ from megatron.core.tensor_parallel.mappings import (
 )
 from megatron.core.tensor_parallel.utils import VocabUtility
 from megatron.core.tensor_parallel.mappings import _reduce
+from megatron.core.tensor_parallel import (
+    ColumnParallelLinear,
+    RowParallelLinear,
+)
 from megatron.core.tensor_parallel.layers import (
     custom_fwd,
     custom_bwd,
@@ -218,6 +221,8 @@ class AGLinear(torch.autograd.Function):
             output = output.view(sequence_len * world_size, batch_size, -1)
         else:
             output = torch.matmul(input, weight.t())
+            if bias is not None:
+                output = output + bias
         return output
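For reference, the unfused semantics of this forward path: with sequence parallelism the input is sharded on the sequence dimension, so the layer all-gathers it and then runs the GEMM; Flux overlaps the two steps. The sketch below is only the unfused equivalent for reasoning about shapes and correctness, not the fused kernel, and assumes an initialized process group.

```python
import torch
import torch.distributed as dist


def ag_linear_reference(input_shard, weight, bias=None, group=None):
    """input_shard: [seq_local, batch, hidden]; weight: [out, hidden]."""
    world_size = dist.get_world_size(group)
    gathered = torch.empty(
        (input_shard.shape[0] * world_size, *input_shard.shape[1:]),
        dtype=input_shard.dtype, device=input_shard.device,
    )
    # All-gather along the sequence dimension, then a plain GEMM.
    dist.all_gather_into_tensor(gathered, input_shard.contiguous(), group=group)
    output = torch.matmul(gathered, weight.t())
    if bias is not None:
        output = output + bias
    return output
```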
@@ -232,7 +237,7 @@ class AGLinear(torch.autograd.Function):
         transpose_weight = ctx.transpose_weight
         bw_gemm_rs_op = ctx.bw_gemm_rs_op

-        wgrad_compute = True
+        wgrad_compute = weight.requires_grad

         if grad_output_buffer is not None:
             if wgrad_deferral_limit == 0 or len(grad_output_buffer) < wgrad_deferral_limit:
                 grad_output_buffer.append(grad_output)
@@ -300,10 +305,14 @@ class AGLinear(torch.autograd.Function):
             )

         if not ctx.sequence_parallel and ctx.allreduce_dgrad:
+            if weight.requires_grad:
                 # Asynchronous all-reduce
                 handle = torch.distributed.all_reduce(
                     grad_input, group=get_tensor_model_parallel_group(), async_op=True
                 )
+            else:
+                grad_input = _reduce(grad_input)
+                return grad_input, None, None, None, None, None, None, None, None, None, None

         if ctx.gradient_accumulation_fusion:
             if wgrad_compute:
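The change above ties `wgrad_compute` to `weight.requires_grad` and short-circuits the backward for frozen weights: there is no weight-gradient work to overlap with, so the input gradient is reduced synchronously and returned immediately. A minimal sketch of that shortcut, with illustrative function and argument names:

```python
import torch
import torch.distributed as dist


def backward_dgrad_only(grad_output, weight, tp_group=None):
    """Return only dL/dX when the weight is frozen (no wgrad, nothing to overlap)."""
    grad_input = grad_output.matmul(weight)       # dL/dX = dL/dY @ W
    dist.all_reduce(grad_input, group=tp_group)   # synchronous tensor-parallel reduce
    return grad_input
```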
@@ -530,7 +539,6 @@ class LinearRS(torch.autograd.Function):
             # output = _reduce_scatter_along_first_dim(output)
         else:
             output = torch.matmul(input, weight.t())
-            output = _reduce(output)
         return output
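Dropping the extra `_reduce` leaves the reduction to the caller on the non-fused path. For reference, the unfused semantics of the fused GEMM + reduce-scatter path: each rank computes a partial GEMM on its weight shard and the partial outputs are reduce-scattered along the sequence dimension. The sketch below is only that reference, not the Flux kernel, and assumes an initialized process group.

```python
import torch
import torch.distributed as dist


def linear_rs_reference(input_parallel, weight, group=None):
    """input_parallel: [seq, batch, hidden_local]; weight: [out, hidden_local]."""
    world_size = dist.get_world_size(group)
    partial = torch.matmul(input_parallel, weight.t())   # [seq, batch, out], partial sums
    out_shard = torch.empty(
        (partial.shape[0] // world_size, *partial.shape[1:]),
        dtype=partial.dtype, device=partial.device,
    )
    # Sum the partials across ranks and scatter the result along the sequence dim.
    dist.reduce_scatter_tensor(out_shard, partial.contiguous(), group=group)
    return out_shard                                      # [seq // world_size, batch, out]
```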
@@ -545,7 +553,7 @@ class LinearRS(torch.autograd.Function):
         transpose_weight = ctx.transpose_weight
         bw_ag_gemm_op = ctx.bw_ag_gemm_op

-        wgrad_compute = True
+        wgrad_compute = weight.requires_grad

         if grad_output_buffer is not None:
             if wgrad_deferral_limit == 0 or len(grad_output_buffer) < wgrad_deferral_limit:
                 grad_output_buffer.append(grad_output)
@@ -604,6 +612,9 @@ class LinearRS(torch.autograd.Function):
         else:
             grad_input = grad_output.matmul(weight)

+        if not weight.requires_grad:
+            return grad_input, None, None, None, None, None, None, None, None, None, None
+
         if ctx.sequence_parallel and wgrad_compute:
             handle.wait()
@@ -772,39 +783,99 @@ def linear_rs(
 linear_rs.warned = False

-def column_parallel_linear_init_wrapper(fn):
-    @wraps(fn)
-    def wrapper(self, *args, **kwargs):
-        fn(self, *args, **kwargs)
-
-        # flux params
-        self.use_flux = False
-        if "use_flux" in kwargs:
-            self.use_flux = kwargs["use_flux"]
-        elif hasattr(self.config, "use_flux"):
-            self.use_flux = self.config.use_flux
-
-        self.flux_transpose_weight = False
-        if "flux_transpose_weight" in kwargs:
-            self.flux_transpose_weight = kwargs["flux_transpose_weight"]
-        elif hasattr(self.config, "flux_transpose_weight"):
-            self.flux_transpose_weight = self.config.flux_transpose_weight
-
-        self.previous_flux_params = (None,) * 5
-        self.fw_ag_gemm_op = None
-        self.bw_gemm_rs_op = None
-
-    return wrapper
-
-class ColumnParallelLinearPatch(torch.nn.Module):
+class FluxColumnParallelLinear(ColumnParallelLinear):
     """Linear layer with column parallelism.

     The linear layer is defined as Y = XA + b. A is parallelized along
     its second dimension as A = [A_1, ..., A_p].

+    Args:
+        input_size:
+            first dimension of matrix A.
+        output_size:
+            second dimension of matrix A.
+        bias:
+            If true, add bias
+        gather_output:
+            If true, call all-gather on output and make Y available to all GPUs,
+            otherwise, every GPU will have its output which is Y_i = XA_i
+        init_method:
+            method to initialize weights. Note that bias is always set to zero.
+        stride:
+            For the strided linear layers.
+        keep_master_weight_for_test:
+            This was added for testing and should be set to False. It
+            returns the master weights used for initialization.
+        skip_bias_add:
+            If True, do not add the bias term, instead return it to be added by the
+            caller. This enables performance optimations where bias can be fused with other
+            elementwise operations.
+        skip_weight_param_allocation:
+            If True, weight parameter is not allocated and must be passed
+            as a keyword argument `weight` during the forward pass. Note that this does not
+            affect bias, which will be allocated if bias is True. Defaults to False.
+        embedding_activation_buffer:
+            This buffer holds the input activations of the final embedding
+            linear layer on the last pipeline stage when defer_embedding_wgrad_compute is enabled.
+        grad_output_buffer:
+            This buffer holds the gradient outputs of the final embedding linear
+            layer on the last pipeline stage when defer_embedding_wgrad_compute is enabled.
+        is_expert:
+            If True, the layer is treated as an MoE expert layer.
+        config:
+            ModelParallelConfig object
+        tp_comm_buffer_name:
+            Communication buffer name is not used in non-Transformer-Engine modules.
+        disable_grad_reduce:
+            If True, reduction of output gradients across tensor-parallel ranks
+            will be disabled. Defaults to False. This feature is used by Lora Adapter in Nemo to
+            delay and fuse reduction along with other gradients for performance optimization.
     """

+    def __init__(
+        self,
+        input_size,
+        output_size,
+        *,
+        config: ModelParallelConfig,
+        init_method: Callable,
+        bias=True,
+        gather_output=False,
+        stride=1,
+        keep_master_weight_for_test=False,
+        skip_bias_add=False,
+        skip_weight_param_allocation: bool = False,
+        embedding_activation_buffer: Optional[List[torch.Tensor]] = None,
+        grad_output_buffer: Optional[List[torch.Tensor]] = None,
+        is_expert: bool = False,
+        tp_comm_buffer_name: str = None,  # Not used
+        disable_grad_reduce: bool = False,
+    ):
+        super(FluxColumnParallelLinear, self).__init__(
+            input_size=input_size,
+            output_size=output_size,
+            config=config,
+            init_method=init_method,
+            bias=bias,
+            gather_output=gather_output,
+            stride=stride,
+            keep_master_weight_for_test=keep_master_weight_for_test,
+            skip_bias_add=skip_bias_add,
+            skip_weight_param_allocation=skip_weight_param_allocation,
+            embedding_activation_buffer=embedding_activation_buffer,
+            grad_output_buffer=grad_output_buffer,
+            is_expert=is_expert,
+            tp_comm_buffer_name=tp_comm_buffer_name,
+            disable_grad_reduce=disable_grad_reduce,
+        )
+
+        # flux params
+        self._forward_impl = ag_linear
+        self.flux_transpose_weight = getattr(self.config, "flux_transpose_weight", False)
+        self.previous_flux_params = (None,) * 5
+        self.fw_ag_gemm_op = None
+        self.bw_gemm_rs_op = None
+
     def forward(
         self,
         input_: torch.Tensor,
@@ -867,14 +938,11 @@ class ColumnParallelLinearPatch(torch.nn.Module):
         ):
             self.embedding_activation_buffer.append(input_parallel)

-        # Matrix multiply.
-        if self.use_flux:
-            assert HAS_FLUX, "flux is NOT installed"
-
+        # flux kernels.
+        if self.sequence_parallel:
             sequence_len, batch_size, input_hidden_size = input_parallel.size()
             output_hidden_size = weight.size(0)
             world_size = get_tensor_model_parallel_world_size()

-            if self.sequence_parallel:
                 current_flux_params = (
                     sequence_len,
                     batch_size,
@@ -913,32 +981,21 @@ class ColumnParallelLinearPatch(torch.nn.Module):
                 self.previous_flux_params = current_flux_params

-            self._forward_impl = ag_linear
-        elif not weight.requires_grad:
-            self._forward_impl = linear_with_frozen_weight
-        else:
-            self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
-
         allreduce_dgrad = False if self.explicit_expert_comm else self.allreduce_dgrad

-        forward_params = {
-            "input": input_parallel,
-            "weight": weight,
-            "bias": bias,
-            "gradient_accumulation_fusion": self.gradient_accumulation_fusion,
-            "allreduce_dgrad": allreduce_dgrad,
-            "sequence_parallel": False if self.explicit_expert_comm else self.sequence_parallel,
-            "grad_output_buffer": self.grad_output_buffer if self.config.defer_embedding_wgrad_compute else None,
-            "wgrad_deferral_limit": self.config.wgrad_deferral_limit if self.config.defer_embedding_wgrad_compute else None,
-        }
-        if self.use_flux:
-            forward_params.update({
-                "transpose_weight": self.flux_transpose_weight,
-                "fw_ag_gemm_op": self.fw_ag_gemm_op,
-                "bw_gemm_rs_op": self.bw_gemm_rs_op,
-            })
-
-        output_parallel = self._forward_impl(**forward_params)
+        output_parallel = self._forward_impl(
+            input=input_parallel,
+            weight=weight,
+            bias=bias,
+            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
+            allreduce_dgrad=allreduce_dgrad,
+            sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
+            grad_output_buffer=self.grad_output_buffer if self.config.defer_embedding_wgrad_compute else None,
+            wgrad_deferral_limit=self.config.wgrad_deferral_limit if self.config.defer_embedding_wgrad_compute else None,
+            transpose_weight=self.flux_transpose_weight,
+            fw_ag_gemm_op=self.fw_ag_gemm_op,
+            bw_gemm_rs_op=self.bw_gemm_rs_op
+        )

         gather_output = self.gather_output
         # Use the runtime gather output if it's set explicitly.
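The `previous_flux_params` / `fw_ag_gemm_op` / `bw_gemm_rs_op` bookkeeping above caches the fused ops per problem size, since they are rebuilt only when the shapes change. The sketch below summarizes that pattern; `make_ag_gemm_op` is a stand-in stub, not the flux API, because the actual op construction falls outside the hunks shown in this commit.

```python
def make_ag_gemm_op(*problem_shape):          # hypothetical stub, NOT the flux API
    return ("ag_gemm_op", problem_shape)


class FluxOpCache:
    def __init__(self):
        self.previous_flux_params = (None,) * 5
        self.fw_ag_gemm_op = None

    def get_or_build(self, seq_len, batch, in_hidden, out_hidden, world_size):
        current = (seq_len, batch, in_hidden, out_hidden, world_size)
        if current != self.previous_flux_params or self.fw_ag_gemm_op is None:
            # Fused ops are tied to a problem size; rebuild only when it changes.
            self.fw_ag_gemm_op = make_ag_gemm_op(*current)
            self.previous_flux_params = current
        return self.fw_ag_gemm_op
```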
@@ -955,38 +1012,79 @@ class ColumnParallelLinearPatch(torch.nn.Module):
         return output, output_bias

-def row_parallel_linear_init_wrapper(fn):
-    @wraps(fn)
-    def wrapper(self, *args, **kwargs):
-        fn(self, *args, **kwargs)
-
-        # flux params
-        self.use_flux = False
-        if "use_flux" in kwargs:
-            self.use_flux = kwargs["use_flux"]
-        elif hasattr(self.config, "use_flux"):
-            self.use_flux = self.config.use_flux
-
-        self.flux_transpose_weight = False
-        if "flux_transpose_weight" in kwargs:
-            self.flux_transpose_weight = kwargs["flux_transpose_weight"]
-        elif hasattr(self.config, "flux_transpose_weight"):
-            self.flux_transpose_weight = self.config.flux_transpose_weight
-
-        self.previous_flux_params = (None,) * 5
-        self.fw_gemm_rs_op = None
-        self.bw_ag_gemm_op = None
-
-    return wrapper
-
-class RowParallelLinearPatch(torch.nn.Module):
+class FluxRowParallelLinear(RowParallelLinear):
     """Linear layer with row parallelism.

     The linear layer is defined as Y = XA + b. A is parallelized along its first dimension and X
     along its second dimension. A = transpose([A_1 .. A_p]) X = [X_1, ..., X_p]

+    Args:
+        input_size:
+            first dimension of matrix A.
+        output_size:
+            second dimension of matrix A.
+        bias:
+            If true, add bias. Note that bias is not parallelized.
+        input_is_parallel:
+            If true, we assume that the input is already split across the GPUs
+            and we do not split again.
+        init_method:
+            method to initialize weights. Note that bias is always set to zero.
+        stride:
+            For the strided linear layers.
+        keep_master_weight_for_test:
+            This was added for testing and should be set to False. It returns the master weights
+            used for initialization.
+        skip_bias_add:
+            If True, do not add the bias term, instead return it to be added by the
+            caller. This enables performance optimations where bias can be fused with other
+            elementwise operations.
+        is_expert:
+            If True, the layer is treated as an MoE expert layer
+        tp_comm_buffer_name:
+            Communication buffer name. Not used in non-Transformer-Engine modules.
+        config:
+            ModelParallelConfig object
     """

+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        *,
+        config: ModelParallelConfig,
+        init_method: Callable,
+        bias: bool,
+        input_is_parallel: bool,
+        skip_bias_add: bool,
+        stride: int = 1,
+        keep_master_weight_for_test: bool = False,
+        is_expert: bool = False,
+        tp_comm_buffer_name: str = None,  # Not used
+    ):
+        super(FluxRowParallelLinear, self).__init__(
+            input_size=input_size,
+            output_size=output_size,
+            config=config,
+            init_method=init_method,
+            bias=bias,
+            input_is_parallel=input_is_parallel,
+            skip_bias_add=skip_bias_add,
+            stride=stride,
+            keep_master_weight_for_test=keep_master_weight_for_test,
+            is_expert=is_expert,
+            tp_comm_buffer_name=tp_comm_buffer_name
+        )
+
+        # flux params
+        self._forward_impl = linear_rs
+        self.flux_transpose_weight = getattr(self.config, "flux_transpose_weight", False)
+        self.previous_flux_params = (None,) * 5
+        self.fw_gemm_rs_op = None
+        self.bw_ag_gemm_op = None
+
     def forward(self, input_):
         """Forward of RowParallelLinear
@@ -1011,14 +1109,14 @@ class RowParallelLinearPatch(torch.nn.Module):
         else:
             assert not self.sequence_parallel
             input_parallel = scatter_to_tensor_model_parallel_region(input_)

-        # Matrix multiply.
-        if self.use_flux:
-            assert HAS_FLUX, "flux is NOT installed"
-
+        # flux kernels
+        if self.sequence_parallel:
             sequence_len, batch_size, input_hidden_size = input_parallel.size()
             output_hidden_size = self.weight.size(0)
             world_size = get_tensor_model_parallel_world_size()

-            if self.sequence_parallel:
                 current_flux_params = (
                     sequence_len,
                     batch_size,
@@ -1057,47 +1155,31 @@ class RowParallelLinearPatch(torch.nn.Module):
                 self.previous_flux_params = current_flux_params

-            self._forward_impl = linear_rs
-        elif not self.weight.requires_grad:
-            self._forward_impl = linear_with_frozen_weight
-        else:
-            self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
-
-        allreduce_dgrad = False
-
-        forward_params = {
-            "input": input_parallel,
-            "weight": self.weight,
-            "bias": self.bias if self.use_flux or not self.skip_bias_add else None,
-            "gradient_accumulation_fusion": self.gradient_accumulation_fusion,
-            "allreduce_dgrad": allreduce_dgrad,
-            "sequence_parallel": False if not self.use_flux else self.sequence_parallel,
-            "grad_output_buffer": None,
-        }
-        if self.use_flux:
-            forward_params.update({
-                "transpose_weight": self.flux_transpose_weight,
-                "fw_gemm_rs_op": self.fw_gemm_rs_op,
-                "bw_ag_gemm_op": self.bw_ag_gemm_op,
-            })
-
-        output_parallel = self._forward_impl(**forward_params)
-
-        if self.use_flux:
-            return output_parallel, None if not self.skip_bias_add else self.bias
-
-        # All-reduce across all the partitions.
+        output_parallel = self._forward_impl(
+            input=input_parallel,
+            weight=self.weight,
+            bias=self.bias if not self.skip_bias_add and self.sequence_parallel else None,
+            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
+            allreduce_dgrad=False,
+            sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
+            grad_output_buffer=None,
+            transpose_weight=self.flux_transpose_weight,
+            fw_gemm_rs_op=self.fw_gemm_rs_op,
+            bw_ag_gemm_op=self.bw_ag_gemm_op
+        )

         if self.explicit_expert_comm:
             assert self.skip_bias_add
             output_ = output_parallel
         elif self.sequence_parallel:
-            output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
+            output_ = output_parallel
         else:
             output_ = reduce_from_tensor_model_parallel_region(output_parallel)

         if not self.skip_bias_add:
-            output = (output_ + self.bias) if self.bias is not None else output_
             output_bias = None
+            if not self.sequence_parallel:
+                output = (output_ + self.bias) if self.bias is not None else output_
         else:
             output = output_
             output_bias = self.bias
...
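How the pieces fit together, as a short usage sketch: the environment variable and config attribute names below are taken from the hunks above, everything else is illustrative. Flux is only engaged on the sequence-parallel path, where the column-parallel layers use the fused all-gather+GEMM and the row-parallel layers use the fused GEMM+reduce-scatter (so the separate reduce-scatter after the GEMM is dropped).

```python
import os

# Must be set before the adaptation runs; any non-empty value enables the patch.
os.environ["USE_FLUX_OVERLAP"] = "1"

# Later, on the transformer config (illustrative attribute use):
# config.sequence_parallel = True          # required for the fused AG-GEMM / GEMM-RS paths
# config.flux_transpose_weight = False     # read via getattr(config, "flux_transpose_weight", False)
```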