Commit 390eac88 authored by dongcl

bug fix: make FluxRowParallelLinear inherit from RowParallelLinear, restore the missing "." in its super(...).__init__ call, read explicit_expert_comm from self, and drop the use_flux flag from get_mtp_spec so the norm implementation is selected from use_te alone

parent ec7c8bc3
@@ -1012,7 +1012,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
         return output, output_bias
 
 
-class FluxRowParallelLinear(torch.nn.Module):
+class FluxRowParallelLinear(RowParallelLinear):
     """Linear layer with row parallelism.
 
     The linear layer is defined as Y = XA + b. A is parallelized along its first dimension and X
@@ -1064,7 +1064,7 @@ class FluxRowParallelLinear(torch.nn.Module):
         tp_comm_buffer_name: str = None,  # Not used
     ):
-        super(FluxRowParallelLinear, self)__init__(
+        super(FluxRowParallelLinear, self).__init__(
             input_size=input_size,
             output_size=output_size,
             config=config,
@@ -1161,7 +1161,7 @@ class FluxRowParallelLinear(torch.nn.Module):
             bias=self.bias if not self.skip_bias_add and self.sequence_parallel else None,
             gradient_accumulation_fusion=self.gradient_accumulation_fusion,
             allreduce_dgrad=False,
-            sequence_parallel=False if explicit_expert_comm else self.sequence_parallel,
+            sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
             grad_output_buffer=None,
             transpose_weight=self.flux_transpose_weight,
             fw_gemm_rs_op=self.fw_gemm_rs_op,
...
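For context: before this commit FluxRowParallelLinear subclassed torch.nn.Module (whose __init__ does not accept the keyword arguments being forwarded), the missing "." in super(FluxRowParallelLinear, self)__init__( was a syntax error, and explicit_expert_comm was referenced as a bare name instead of an instance attribute. The following is a minimal, self-contained sketch of the corrected pattern; the classes below are stand-ins, not the real Megatron/Flux implementations.

import torch


class RowParallelLinear(torch.nn.Module):
    # Stand-in for Megatron's RowParallelLinear; only the attributes this
    # sketch touches are modeled.
    def __init__(self, input_size, output_size):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(output_size, input_size))
        self.explicit_expert_comm = False
        self.sequence_parallel = True


class FluxRowParallelLinear(RowParallelLinear):  # was torch.nn.Module before this commit
    def __init__(self, input_size, output_size):
        # The missing "." made this line a SyntaxError, and with torch.nn.Module
        # as the base the forwarded keyword arguments would have been rejected.
        super(FluxRowParallelLinear, self).__init__(input_size, output_size)

    def forward(self, x):
        # A bare "explicit_expert_comm" raises NameError; the flag lives on self.
        sequence_parallel = False if self.explicit_expert_comm else self.sequence_parallel
        return torch.nn.functional.linear(x, self.weight), sequence_parallel


layer = FluxRowParallelLinear(8, 4)
out, sp = layer(torch.randn(2, 8))  # out: shape (2, 4); sp: True in this sketch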
@@ -30,7 +30,7 @@ except ImportError:
     LNImpl = WrappedTorchNorm
 
 
-def get_mtp_spec(transformer_layer, use_te=False, use_flux=False):
+def get_mtp_spec(transformer_layer, use_te=False):
     """
     Multi Token Predication Layer Specification.
     """
@@ -39,11 +39,11 @@ def get_mtp_spec(transformer_layer, use_te=False, use_flux=False):
         module=MultiTokenPredictor,
         submodules=MultiTokenPredicationSubmodules(
             embedding=None,
-            enorm=TENorm if use_te or use_flux else LNImpl,
-            hnorm=TENorm if use_te or use_flux else LNImpl,
+            enorm=TENorm if use_te else LNImpl,
+            hnorm=TENorm if use_te else LNImpl,
             eh_proj=TEColumnParallelLinear if use_te else ColumnParallelLinear,
             transformer_layer=transformer_layer,
-            final_layernorm=TENorm if use_te or use_flux else LNImpl,
+            final_layernorm=TENorm if use_te else LNImpl,
             output_layer=None,
         )
     )
...
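For context: get_mtp_spec no longer accepts a use_flux argument, so enorm, hnorm, and final_layernorm are chosen from use_te alone (how Flux call sites set use_te is not shown in this diff). Below is a minimal, self-contained sketch of that selection logic; the classes are placeholders, not the real TENorm/LNImpl.

class TENorm:
    """Placeholder for the Transformer Engine norm wrapper."""


class LNImpl:
    """Placeholder for the WrappedTorchNorm fallback."""


def pick_norm(use_te: bool = False):
    # Mirrors the ternaries in the hunk above: after this commit there is no
    # separate use_flux flag feeding into the choice.
    return TENorm if use_te else LNImpl


assert pick_norm(use_te=True) is TENorm
assert pick_norm() is LNImpl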