Commit 390eac88 authored by dongcl's avatar dongcl
Browse files

bug fix

parent ec7c8bc3
......@@ -1012,7 +1012,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
return output, output_bias
class FluxRowParallelLinear(torch.nn.Module):
class FluxRowParallelLinear(RowParallelLinear):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along its first dimension and X
......@@ -1064,7 +1064,7 @@ class FluxRowParallelLinear(torch.nn.Module):
tp_comm_buffer_name: str = None, # Not used
):
super(FluxRowParallelLinear, self)__init__(
super(FluxRowParallelLinear, self).__init__(
input_size=input_size,
output_size=output_size,
config=config,
......@@ -1161,7 +1161,7 @@ class FluxRowParallelLinear(torch.nn.Module):
bias=self.bias if not self.skip_bias_add and self.sequence_parallel else None,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
allreduce_dgrad=False,
sequence_parallel=False if explicit_expert_comm else self.sequence_parallel,
sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
grad_output_buffer=None,
transpose_weight=self.flux_transpose_weight,
fw_gemm_rs_op=self.fw_gemm_rs_op,
......
......@@ -30,7 +30,7 @@ except ImportError:
LNImpl = WrappedTorchNorm
def get_mtp_spec(transformer_layer, use_te=False, use_flux=False):
def get_mtp_spec(transformer_layer, use_te=False):
"""
Multi Token Predication Layer Specification.
"""
......@@ -39,11 +39,11 @@ def get_mtp_spec(transformer_layer, use_te=False, use_flux=False):
module=MultiTokenPredictor,
submodules=MultiTokenPredicationSubmodules(
embedding=None,
enorm=TENorm if use_te or use_flux else LNImpl,
hnorm=TENorm if use_te or use_flux else LNImpl,
enorm=TENorm if use_te else LNImpl,
hnorm=TENorm if use_te else LNImpl,
eh_proj=TEColumnParallelLinear if use_te else ColumnParallelLinear,
transformer_layer=transformer_layer,
final_layernorm=TENorm if use_te or use_flux else LNImpl,
final_layernorm=TENorm if use_te else LNImpl,
output_layer=None,
)
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment