Commit 53627040 authored by dongcl's avatar dongcl

bug fix: enable the RowParallelLinear patches, bind world_size before its first use in LinearRS's backward pass, guard the bw_ag_gemm_op cache (not bw_gemm_rs_op) before building the flux.AGKernel, and read output_hidden_size from self.weight in RowParallelLinearPatch

parent 409cdfef
@@ -207,11 +207,11 @@ class CoreAdaptation(MegatronAdaptationABC):
                                     apply_wrapper=True)
         MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward",
                                     ColumnParallelLinearPatch.forward)
-        # MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.__init__",
-        #                             row_parallel_linear_init_wrapper,
-        #                             apply_wrapper=True)
-        # MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.forward",
-        #                             RowParallelLinearPatch.forward)
+        MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.__init__",
+                                    row_parallel_linear_init_wrapper,
+                                    apply_wrapper=True)
+        MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.forward",
+                                    RowParallelLinearPatch.forward)

     def patch_training(self):
         from ..training.tokenizer import build_tokenizer
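Note: this hunk re-enables the previously commented-out RowParallelLinear patches. For readers unfamiliar with the adaptation layer, here is a minimal sketch of what a register call of this shape could do; only the call signature (target string, replacement callable, apply_wrapper) comes from the diff above, and the resolution logic below is an assumption, not this repo's implementation.

import importlib

def register(target: str, patch, apply_wrapper: bool = False):
    # Split "pkg.module.Class.attr" into module path, class name, attribute name.
    module_path, cls_name, attr_name = target.rsplit(".", 2)
    cls = getattr(importlib.import_module(module_path), cls_name)
    original = getattr(cls, attr_name)
    # apply_wrapper=True treats the patch as a decorator around the original
    # (as with row_parallel_linear_init_wrapper); otherwise it replaces it outright.
    setattr(cls, attr_name, patch(original) if apply_wrapper else patch)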
@@ -546,7 +546,9 @@ class LinearRS(torch.autograd.Function):
                 grad_output_buffer.append(grad_output)
                 wgrad_compute = False

+        if wgrad:
+            world_size = get_tensor_model_parallel_world_size()
         if wgrad_compute:
             if ctx.sequence_parallel:
                 dim_size = list(grad_output.size())
                 dim_size[0] = dim_size[0] * world_size
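Note: before this hunk, world_size was only bound further down, inside the sequence-parallel branch cleaned up in the next hunk, yet it is already read here in dim_size[0] * world_size. A self-contained illustration of that ordering bug (not the repo's code):

def buggy():
    dim0 = 2 * world_size   # raises UnboundLocalError when called:
    world_size = 8          # this later assignment makes world_size a local name
    return dim0

def fixed():
    world_size = 8          # bound before any use, as the hunk above now does
    dim0 = 2 * world_size
    return dim0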
@@ -565,12 +567,10 @@ class LinearRS(torch.autograd.Function):
             total_grad_output = grad_output

         if ctx.sequence_parallel:
-            world_size = get_tensor_model_parallel_world_size()
             sequence_len, batch_size, output_hidden_size = grad_output.size()
             input_hidden_size = weight.size(-1)

-            if bw_gemm_rs_op is None:
+            if bw_ag_gemm_op is None:
                 bw_ag_gemm_op = flux.AGKernel(
                     get_tensor_model_parallel_group(),
                     1,  # world_size // torch.cuda.device_count(),
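Note: the old guard tested bw_gemm_rs_op while the assignment filled bw_ag_gemm_op, so the check never saw the cached kernel and the flux.AGKernel was presumably reconstructed on every backward pass. A generic sketch of the intended lazy-initialization pattern, where ExpensiveKernel is a hypothetical stand-in for flux.AGKernel:

_cached_op = None

class ExpensiveKernel:
    """Hypothetical stand-in for flux.AGKernel; costly to construct."""

def get_op():
    global _cached_op
    if _cached_op is None:          # guard and assignment target the SAME name
        _cached_op = ExpensiveKernel()
    return _cached_op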
@@ -1013,7 +1013,7 @@ class RowParallelLinearPatch(torch.nn.Module):
         assert HAS_FLUX, "flux is NOT installed"

         sequence_len, batch_size, input_hidden_size = input_parallel.size()
-        output_hidden_size = weight.size(0)
+        output_hidden_size = self.weight.size(0)
         world_size = get_tensor_model_parallel_world_size()
         if self.sequence_parallel:
             current_flux_params = (
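Note: inside a method, a bare weight is an ordinary name lookup and was presumably undefined or stale at this point; the layer's parameter must be reached through self. A minimal illustration, with a hypothetical class and sizes:

import torch

class Linearish(torch.nn.Module):
    def __init__(self, out_features=8, in_features=4):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features))

    def output_size(self):
        # return weight.size(0)     # NameError: no bare `weight` in scope
        return self.weight.size(0)  # correct: read the registered parameter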