Commit ec7c8bc3 authored by dongcl

replace te with flux when using flux

parent 138b70a2
......@@ -190,27 +190,17 @@ class CoreAdaptation(MegatronAdaptationABC):
# flux
if os.getenv("USE_FLUX_OVERLAP", 0):
import flux
from ..core.tensor_parallel import (
ColumnParallelLinearPatch,
RowParallelLinearPatch,
column_parallel_linear_init_wrapper,
row_parallel_linear_init_wrapper
FluxColumnParallelLinear,
FluxRowParallelLinear
)
from ..core.models.gpt.gpt_layer_specs import get_gpt_layer_with_flux_spec
MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__",
column_parallel_linear_init_wrapper,
apply_wrapper=True)
MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward",
ColumnParallelLinearPatch.forward)
MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.__init__",
row_parallel_linear_init_wrapper,
apply_wrapper=True)
MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.forward",
RowParallelLinearPatch.forward)
MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_local_spec",
MegatronAdaptation.register("megatron.core.extensions.transformer_engine.TEColumnParallelLinear",
FluxColumnParallelLinear)
MegatronAdaptation.register("megatron.core.extensions.transformer_engine.TERowParallelLinear",
FluxRowParallelLinear)
MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
get_gpt_layer_with_flux_spec)
def patch_training(self):
......
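Note that os.getenv("USE_FLUX_OVERLAP", 0) returns the raw environment string (or the default), so any non-empty value, including "0", takes the flux branch. A minimal, hypothetical guard (not part of this commit) that treats "0"/"false" as disabled could look like:

import os

def flux_overlap_enabled() -> bool:
    # Hypothetical helper: only enable the flux overlap path for truthy values.
    value = os.getenv("USE_FLUX_OVERLAP", "0")
    return value.strip().lower() not in ("", "0", "false", "no")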
......@@ -3,7 +3,6 @@ from typing import Optional
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
......@@ -17,6 +16,9 @@ from megatron.core.transformer.transformer_layer import (
TransformerLayer,
TransformerLayerSubmodules,
)
from dcu_megatron.core.tensor_parallel.layers import FluxColumnParallelLinear, FluxRowParallelLinear
from megatron.core.utils import is_te_min_version
try:
......@@ -79,13 +81,13 @@ def get_gpt_layer_with_flux_spec(
module=MLASelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=MLASelfAttentionSubmodules(
linear_q_proj=ColumnParallelLinear,
linear_q_down_proj=ColumnParallelLinear,
linear_q_up_proj=ColumnParallelLinear,
linear_kv_down_proj=ColumnParallelLinear,
linear_kv_up_proj=ColumnParallelLinear,
linear_q_proj=FluxColumnParallelLinear,
linear_q_down_proj=FluxColumnParallelLinear,
linear_q_up_proj=FluxColumnParallelLinear,
linear_kv_down_proj=FluxColumnParallelLinear,
linear_kv_up_proj=FluxColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=RowParallelLinear,
linear_proj=FluxRowParallelLinear,
q_layernorm=TENorm if qk_layernorm else IdentityOp,
kv_layernorm=TENorm if qk_layernorm else IdentityOp,
),
......@@ -111,9 +113,9 @@ def get_gpt_layer_with_flux_spec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
linear_qkv=FluxColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=RowParallelLinear,
linear_proj=FluxRowParallelLinear,
q_layernorm=qk_norm if qk_layernorm else IdentityOp,
k_layernorm=qk_norm if qk_layernorm else IdentityOp,
),
......@@ -145,8 +147,8 @@ def get_mlp_module_flux_spec(
return ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=ColumnParallelLinear,
linear_fc2=RowParallelLinear,
linear_fc1=FluxColumnParallelLinear,
linear_fc2=FluxRowParallelLinear,
),
)
else:
......
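For reference, a minimal sketch of how the Flux linear classes slot into a Megatron ModuleSpec, mirroring get_mlp_module_flux_spec above (import paths follow the ones visible in this diff; the snippet is illustrative, not part of the commit):

from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from dcu_megatron.core.tensor_parallel.layers import (
    FluxColumnParallelLinear,
    FluxRowParallelLinear,
)

# Dense MLP spec with flux-fused linears:
# fc1 is column-parallel (all-gather + GEMM), fc2 is row-parallel (GEMM + reduce-scatter).
mlp_spec = ModuleSpec(
    module=MLP,
    submodules=MLPSubmodules(
        linear_fc1=FluxColumnParallelLinear,
        linear_fc2=FluxRowParallelLinear,
    ),
)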
from .layers import (
column_parallel_linear_init_wrapper,
row_parallel_linear_init_wrapper,
ColumnParallelLinearPatch,
RowParallelLinearPatch,
FluxColumnParallelLinear,
FluxRowParallelLinear,
vocab_parallel_embedding_forward,
vocab_parallel_embedding_init,
)
\ No newline at end of file
......@@ -5,9 +5,8 @@ from typing import Callable, List, Optional
try:
import flux
HAS_FLUX = True
except ImportError:
HAS_FLUX = False
raise ImportError("flux is NOT installed")
import torch
import torch.nn.functional as F
......@@ -39,6 +38,10 @@ from megatron.core.tensor_parallel.mappings import (
)
from megatron.core.tensor_parallel.utils import VocabUtility
from megatron.core.tensor_parallel.mappings import _reduce
from megatron.core.tensor_parallel import (
ColumnParallelLinear,
RowParallelLinear,
)
from megatron.core.tensor_parallel.layers import (
custom_fwd,
custom_bwd,
......@@ -218,6 +221,8 @@ class AGLinear(torch.autograd.Function):
output = output.view(sequence_len * world_size, batch_size, -1)
else:
output = torch.matmul(input, weight.t())
if bias is not None:
output = output + bias
return output
......@@ -232,7 +237,7 @@ class AGLinear(torch.autograd.Function):
transpose_weight = ctx.transpose_weight
bw_gemm_rs_op = ctx.bw_gemm_rs_op
wgrad_compute = True
wgrad_compute = weight.requires_grad
if grad_output_buffer is not None:
if wgrad_deferral_limit == 0 or len(grad_output_buffer) < wgrad_deferral_limit:
grad_output_buffer.append(grad_output)
......@@ -300,10 +305,14 @@ class AGLinear(torch.autograd.Function):
)
if not ctx.sequence_parallel and ctx.allreduce_dgrad:
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True
)
if weight.requires_grad:
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True
)
else:
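# Weight is frozen: there is no wgrad work to overlap with, so reduce the dgrad
# synchronously and return early.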
grad_input = _reduce(grad_input)
return grad_input, None, None, None, None, None, None, None, None, None, None
if ctx.gradient_accumulation_fusion:
if wgrad_compute:
......@@ -530,7 +539,6 @@ class LinearRS(torch.autograd.Function):
# output = _reduce_scatter_along_first_dim(output)
else:
output = torch.matmul(input, weight.t())
output = _reduce(output)
return output
......@@ -545,7 +553,7 @@ class LinearRS(torch.autograd.Function):
transpose_weight = ctx.transpose_weight
bw_ag_gemm_op = ctx.bw_ag_gemm_op
wgrad_compute = True
wgrad_compute = weight.requires_grad
if grad_output_buffer is not None:
if wgrad_deferral_limit == 0 or len(grad_output_buffer) < wgrad_deferral_limit:
grad_output_buffer.append(grad_output)
......@@ -604,6 +612,9 @@ class LinearRS(torch.autograd.Function):
else:
grad_input = grad_output.matmul(weight)
if not weight.requires_grad:
return grad_input, None, None, None, None, None, None, None, None, None, None
if ctx.sequence_parallel and wgrad_compute:
handle.wait()
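For orientation, the two flux autograd functions mirror each other in which fused kernels they cache (names as they appear in this diff):

# AGLinear (used by the column-parallel layer):
#   forward  -> flux.AGKernel  (all-gather of the input fused with the GEMM)
#   backward -> flux.GemmRS    (dgrad GEMM fused with a reduce-scatter)
# LinearRS (used by the row-parallel layer):
#   forward  -> flux.GemmRS    (GEMM fused with a reduce-scatter of the output)
#   backward -> flux.AGKernel  (all-gather fused with the dgrad GEMM)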
......@@ -772,39 +783,99 @@ def linear_rs(
linear_rs.warned = False
def column_parallel_linear_init_wrapper(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
fn(self, *args, **kwargs)
# flux params
self.use_flux = False
if "use_flux" in kwargs:
self.use_flux = kwargs["use_flux"]
elif hasattr(self.config, "use_flux"):
self.use_flux = self.config.use_flux
self.flux_transpose_weight = False
if "flux_transpose_weight" in kwargs:
self.flux_transpose_weight = kwargs["flux_transpose_weight"]
elif hasattr(self.config, "flux_transpose_weight"):
self.flux_transpose_weight = self.config.flux_transpose_weight
self.previous_flux_params = (None,) * 5
self.fw_ag_gemm_op = None
self.bw_gemm_rs_op = None
return wrapper
class ColumnParallelLinearPatch(torch.nn.Module):
class FluxColumnParallelLinear(ColumnParallelLinear):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
Args:
input_size:
first dimension of matrix A.
output_size:
second dimension of matrix A.
bias:
If true, add bias
gather_output:
If true, call all-gather on output and make Y available to all GPUs,
otherwise, every GPU will have its output which is Y_i = XA_i
init_method:
method to initialize weights. Note that bias is always set to zero.
stride:
For the strided linear layers.
keep_master_weight_for_test:
This was added for testing and should be set to False. It
returns the master weights used for initialization.
skip_bias_add:
If True, do not add the bias term, instead return it to be added by the
caller. This enables performance optimizations where bias can be fused with other
elementwise operations.
skip_weight_param_allocation:
If True, weight parameter is not allocated and must be passed
as a keyword argument `weight` during the forward pass. Note that this does not
affect bias, which will be allocated if bias is True. Defaults to False.
embedding_activation_buffer:
This buffer holds the input activations of the final embedding
linear layer on the last pipeline stage when defer_embedding_wgrad_compute is enabled.
grad_output_buffer:
This buffer holds the gradient outputs of the final embedding linear
layer on the last pipeline stage when defer_embedding_wgrad_compute is enabled.
is_expert:
If True, the layer is treated as an MoE expert layer.
config:
ModelParallelConfig object
tp_comm_buffer_name:
Communication buffer name is not used in non-Transformer-Engine modules.
disable_grad_reduce:
If True, reduction of output gradients across tensor-parallel ranks
will be disabled. Defaults to False. This feature is used by Lora Adapter in Nemo to
delay and fuse reduction along with other gradients for performance optimization.
"""
def __init__(
self,
input_size,
output_size,
*,
config: ModelParallelConfig,
init_method: Callable,
bias=True,
gather_output=False,
stride=1,
keep_master_weight_for_test=False,
skip_bias_add=False,
skip_weight_param_allocation: bool = False,
embedding_activation_buffer: Optional[List[torch.Tensor]] = None,
grad_output_buffer: Optional[List[torch.Tensor]] = None,
is_expert: bool = False,
tp_comm_buffer_name: str = None, # Not used
disable_grad_reduce: bool = False,
):
super(FluxColumnParallelLinear, self).__init__(
input_size=input_size,
output_size=output_size,
config=config,
init_method=init_method,
bias=bias,
gather_output=gather_output,
stride=stride,
keep_master_weight_for_test=keep_master_weight_for_test,
skip_bias_add=skip_bias_add,
skip_weight_param_allocation=skip_weight_param_allocation,
embedding_activation_buffer=embedding_activation_buffer,
grad_output_buffer=grad_output_buffer,
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name,
disable_grad_reduce=disable_grad_reduce,
)
# flux params
self._forward_impl = ag_linear
self.flux_transpose_weight = getattr(self.config, "flux_transpose_weight", False)
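# Shape/dtype signature of the last forward: (seq_len, batch, in_hidden, out_hidden, dtype).
# The flux kernels are rebuilt only when this signature changes.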
self.previous_flux_params = (None,) * 5
self.fw_ag_gemm_op = None
self.bw_gemm_rs_op = None
def forward(
self,
input_: torch.Tensor,
......@@ -867,78 +938,64 @@ class ColumnParallelLinearPatch(torch.nn.Module):
):
self.embedding_activation_buffer.append(input_parallel)
# Matrix multiply.
if self.use_flux:
assert HAS_FLUX, "flux is NOT installed"
# flux kernels.
if self.sequence_parallel:
sequence_len, batch_size, input_hidden_size = input_parallel.size()
output_hidden_size = weight.size(0)
world_size = get_tensor_model_parallel_world_size()
if self.sequence_parallel:
current_flux_params = (
sequence_len,
batch_size,
input_hidden_size,
current_flux_params = (
sequence_len,
batch_size,
input_hidden_size,
output_hidden_size,
input_parallel.dtype
)
if (
self.fw_ag_gemm_op is None
or current_flux_params != self.previous_flux_params
):
self.fw_ag_gemm_op = flux.AGKernel(
get_tensor_model_parallel_group(),
1, # torch.distributed.get_world_size() // torch.cuda.device_count(),
sequence_len * batch_size * world_size,
output_hidden_size,
input_parallel.dtype
input_hidden_size,
input_parallel.dtype,
output_dtype=input_parallel.dtype,
transpose_weight=self.flux_transpose_weight,
local_copy=False,
ring_mode=flux.AgRingMode.Auto,
)
if (
self.fw_ag_gemm_op is None
or current_flux_params != self.previous_flux_params
):
self.fw_ag_gemm_op = flux.AGKernel(
get_tensor_model_parallel_group(),
1, # torch.distributed.get_world_size() // torch.cuda.device_count(),
sequence_len * batch_size * world_size,
output_hidden_size,
input_hidden_size,
input_parallel.dtype,
output_dtype=input_parallel.dtype,
transpose_weight=self.flux_transpose_weight,
local_copy=False,
ring_mode=flux.AgRingMode.Auto,
)
self.bw_gemm_rs_op = flux.GemmRS(
get_tensor_model_parallel_group(),
1, # world_size // torch.cuda.device_count(),
sequence_len * batch_size * world_size,
input_hidden_size,
input_parallel.dtype,
input_parallel.dtype,
transpose_weight=self.flux_transpose_weight,
fuse_reduction=False
)
self.previous_flux_params = current_flux_params
self.bw_gemm_rs_op = flux.GemmRS(
get_tensor_model_parallel_group(),
1, # world_size // torch.cuda.device_count(),
sequence_len * batch_size * world_size,
input_hidden_size,
input_parallel.dtype,
input_parallel.dtype,
transpose_weight=self.flux_transpose_weight,
fuse_reduction=False
)
self._forward_impl = ag_linear
elif not weight.requires_grad:
self._forward_impl = linear_with_frozen_weight
else:
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
self.previous_flux_params = current_flux_params
allreduce_dgrad = False if self.explicit_expert_comm else self.allreduce_dgrad
forward_params = {
"input": input_parallel,
"weight": weight,
"bias": bias,
"gradient_accumulation_fusion": self.gradient_accumulation_fusion,
"allreduce_dgrad": allreduce_dgrad,
"sequence_parallel": False if self.explicit_expert_comm else self.sequence_parallel,
"grad_output_buffer": self.grad_output_buffer if self.config.defer_embedding_wgrad_compute else None,
"wgrad_deferral_limit": self.config.wgrad_deferral_limit if self.config.defer_embedding_wgrad_compute else None,
}
if self.use_flux:
forward_params.update({
"transpose_weight": self.flux_transpose_weight,
"fw_ag_gemm_op": self.fw_ag_gemm_op,
"bw_gemm_rs_op": self.bw_gemm_rs_op,
})
output_parallel = self._forward_impl(**forward_params)
output_parallel = self._forward_impl(
input=input_parallel,
weight=weight,
bias=bias,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
allreduce_dgrad=allreduce_dgrad,
sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
grad_output_buffer=self.grad_output_buffer if self.config.defer_embedding_wgrad_compute else None,
wgrad_deferral_limit=self.config.wgrad_deferral_limit if self.config.defer_embedding_wgrad_compute else None,
transpose_weight=self.flux_transpose_weight,
fw_ag_gemm_op=self.fw_ag_gemm_op,
bw_gemm_rs_op=self.bw_gemm_rs_op
)
gather_output = self.gather_output
# Use the runtime gather output if it's set explicitly.
......@@ -955,38 +1012,79 @@ class ColumnParallelLinearPatch(torch.nn.Module):
return output, output_bias
def row_parallel_linear_init_wrapper(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
fn(self, *args, **kwargs)
class FluxRowParallelLinear(RowParallelLinear):
"""Linear layer with row parallelism.
# flux params
self.use_flux = False
if "use_flux" in kwargs:
self.use_flux = kwargs["use_flux"]
elif hasattr(self.config, "use_flux"):
self.use_flux = self.config.use_flux
self.flux_transpose_weight = False
if "flux_transpose_weight" in kwargs:
self.flux_transpose_weight = kwargs["flux_transpose_weight"]
elif hasattr(self.config, "flux_transpose_weight"):
self.flux_transpose_weight = self.config.flux_transpose_weight
The linear layer is defined as Y = XA + b. A is parallelized along its first dimension and X
along its second dimension. A = transpose([A_1 .. A_p]) X = [X_1, ..., X_p]
self.previous_flux_params = (None,) * 5
self.fw_gemm_rs_op = None
self.bw_ag_gemm_op = None
Args:
input_size:
first dimension of matrix A.
output_size:
second dimension of matrix A.
bias:
If true, add bias. Note that bias is not parallelized.
input_is_parallel:
If true, we assume that the input is already split across the GPUs
and we do not split again.
init_method:
method to initialize weights. Note that bias is always set to zero.
stride:
For the strided linear layers.
keep_master_weight_for_test:
This was added for testing and should be set to False. It returns the master weights
used for initialization.
skip_bias_add:
If True, do not add the bias term, instead return it to be added by the
caller. This enables performance optimizations where bias can be fused with other
elementwise operations.
is_expert:
If True, the layer is treated as an MoE expert layer
tp_comm_buffer_name:
Communication buffer name. Not used in non-Transformer-Engine modules.
config:
ModelParallelConfig object
return wrapper
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
input_is_parallel: bool,
skip_bias_add: bool,
stride: int = 1,
keep_master_weight_for_test: bool = False,
is_expert: bool = False,
tp_comm_buffer_name: str = None, # Not used
):
class RowParallelLinearPatch(torch.nn.Module):
"""Linear layer with row parallelism.
super(FluxRowParallelLinear, self).__init__(
input_size=input_size,
output_size=output_size,
config=config,
init_method=init_method,
bias=bias,
input_is_parallel=input_is_parallel,
skip_bias_add=skip_bias_add,
stride=stride,
keep_master_weight_for_test=keep_master_weight_for_test,
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name
)
The linear layer is defined as Y = XA + b. A is parallelized along its first dimension and X
along its second dimension. A = transpose([A_1 .. A_p]) X = [X_1, ..., X_p]
# flux params
self._forward_impl = linear_rs
self.flux_transpose_weight = getattr(self.config, "flux_transpose_weight", False)
self.previous_flux_params = (None,) * 5
self.fw_gemm_rs_op = None
self.bw_ag_gemm_op = None
"""
def forward(self, input_):
"""Forward of RowParallelLinear
......@@ -1011,93 +1109,77 @@ class RowParallelLinearPatch(torch.nn.Module):
else:
assert not self.sequence_parallel
input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply.
if self.use_flux:
assert HAS_FLUX, "flux is NOT installed"
# flux kernels
if self.sequence_parallel:
sequence_len, batch_size, input_hidden_size = input_parallel.size()
output_hidden_size = self.weight.size(0)
world_size = get_tensor_model_parallel_world_size()
if self.sequence_parallel:
current_flux_params = (
sequence_len,
batch_size,
input_hidden_size,
output_hidden_size,
input_parallel.dtype
)
if (
self.fw_gemm_rs_op is None
or current_flux_params != self.previous_flux_params
):
self.fw_gemm_rs_op = flux.GemmRS(
get_tensor_model_parallel_group(),
1, # world_size // torch.cuda.device_count(),
sequence_len * batch_size,
output_hidden_size,
input_parallel.dtype,
input_parallel.dtype,
transpose_weight=self.flux_transpose_weight,
fuse_reduction=False
)
self.bw_ag_gemm_op = flux.AGKernel(
get_tensor_model_parallel_group(),
1, # torch.distributed.get_world_size() // torch.cuda.device_count(),
sequence_len * batch_size,
input_hidden_size,
output_hidden_size,
input_parallel.dtype,
output_dtype=input_parallel.dtype,
transpose_weight=self.flux_transpose_weight,
local_copy=False,
ring_mode=flux.AgRingMode.Auto,
)
self.previous_flux_params = current_flux_params
self._forward_impl = linear_rs
elif not self.weight.requires_grad:
self._forward_impl = linear_with_frozen_weight
else:
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
allreduce_dgrad = False
forward_params = {
"input": input_parallel,
"weight": self.weight,
"bias": self.bias if self.use_flux or not self.skip_bias_add else None,
"gradient_accumulation_fusion": self.gradient_accumulation_fusion,
"allreduce_dgrad": allreduce_dgrad,
"sequence_parallel": False if not self.use_flux else self.sequence_parallel,
"grad_output_buffer": None,
}
current_flux_params = (
sequence_len,
batch_size,
input_hidden_size,
output_hidden_size,
input_parallel.dtype
)
if self.use_flux:
forward_params.update({
"transpose_weight": self.flux_transpose_weight,
"fw_gemm_rs_op": self.fw_gemm_rs_op,
"bw_ag_gemm_op": self.bw_ag_gemm_op,
})
if (
self.fw_gemm_rs_op is None
or current_flux_params != self.previous_flux_params
):
self.fw_gemm_rs_op = flux.GemmRS(
get_tensor_model_parallel_group(),
1, # world_size // torch.cuda.device_count(),
sequence_len * batch_size,
output_hidden_size,
input_parallel.dtype,
input_parallel.dtype,
transpose_weight=self.flux_transpose_weight,
fuse_reduction=False
)
output_parallel = self._forward_impl(**forward_params)
self.bw_ag_gemm_op = flux.AGKernel(
get_tensor_model_parallel_group(),
1, # torch.distributed.get_world_size() // torch.cuda.device_count(),
sequence_len * batch_size,
input_hidden_size,
output_hidden_size,
input_parallel.dtype,
output_dtype=input_parallel.dtype,
transpose_weight=self.flux_transpose_weight,
local_copy=False,
ring_mode=flux.AgRingMode.Auto,
)
if self.use_flux:
return output_parallel, None if not self.skip_bias_add else self.bias
self.previous_flux_params = current_flux_params
output_parallel = self._forward_impl(
input=input_parallel,
weight=self.weight,
bias=self.bias if not self.skip_bias_add and self.sequence_parallel else None,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
allreduce_dgrad=False,
sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
grad_output_buffer=None,
transpose_weight=self.flux_transpose_weight,
fw_gemm_rs_op=self.fw_gemm_rs_op,
bw_ag_gemm_op=self.bw_ag_gemm_op
)
# All-reduce across all the partitions.
if self.explicit_expert_comm:
assert self.skip_bias_add
output_ = output_parallel
elif self.sequence_parallel:
output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
output_ = output_parallel
else:
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if not self.skip_bias_add:
output = (output_ + self.bias) if self.bias is not None else output_
output_bias = None
if not self.sequence_parallel:
output = (output_ + self.bias) if self.bias is not None else output_
else:
output = output_
output_bias = self.bias
......
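Both forward paths rebuild their flux GEMM operators only when the (sequence length, batch size, hidden sizes, dtype) signature changes, and otherwise reuse the cached operators across iterations. A minimal, dependency-free sketch of that caching pattern (hypothetical names, assuming the kernel constructor is expensive relative to a tuple comparison):

class ShapeKeyedKernelCache:
    # Rebuild an expensive kernel object only when its shape/dtype key changes.
    def __init__(self, build_fn):
        self._build_fn = build_fn   # callable that constructs the kernel for a given key
        self._key = None            # last (seq_len, batch, in_hidden, out_hidden, dtype)
        self._kernel = None

    def get(self, key):
        if self._kernel is None or key != self._key:
            self._kernel = self._build_fn(key)
            self._key = key
        return self._kernel

# Usage sketch: cache = ShapeKeyedKernelCache(lambda key: flux.AGKernel(...)); op = cache.get(current_flux_params)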