Commit 53627040 authored by dongcl's avatar dongcl
Browse files

bug fix

parent 409cdfef
...@@ -207,11 +207,11 @@ class CoreAdaptation(MegatronAdaptationABC): ...@@ -207,11 +207,11 @@ class CoreAdaptation(MegatronAdaptationABC):
apply_wrapper=True) apply_wrapper=True)
MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward", MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward",
ColumnParallelLinearPatch.forward) ColumnParallelLinearPatch.forward)
# MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.__init__", MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.__init__",
# row_parallel_linear_init_wrapper, row_parallel_linear_init_wrapper,
# apply_wrapper=True) apply_wrapper=True)
# MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.forward", MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.forward",
# RowParallelLinearPatch.forward) RowParallelLinearPatch.forward)
def patch_training(self): def patch_training(self):
from ..training.tokenizer import build_tokenizer from ..training.tokenizer import build_tokenizer
......
...@@ -546,7 +546,9 @@ class LinearRS(torch.autograd.Function): ...@@ -546,7 +546,9 @@ class LinearRS(torch.autograd.Function):
grad_output_buffer.append(grad_output) grad_output_buffer.append(grad_output)
wgrad_compute = False wgrad_compute = False
if wgrad: world_size = get_tensor_model_parallel_world_size()
if wgrad_compute:
if ctx.sequence_parallel: if ctx.sequence_parallel:
dim_size = list(grad_output.size()) dim_size = list(grad_output.size())
dim_size[0] = dim_size[0] * world_size dim_size[0] = dim_size[0] * world_size
...@@ -565,12 +567,10 @@ class LinearRS(torch.autograd.Function): ...@@ -565,12 +567,10 @@ class LinearRS(torch.autograd.Function):
total_grad_output = grad_output total_grad_output = grad_output
if ctx.sequence_parallel: if ctx.sequence_parallel:
world_size = get_tensor_model_parallel_world_size()
sequence_len, batch_size, output_hidden_size = grad_output.size() sequence_len, batch_size, output_hidden_size = grad_output.size()
input_hidden_size = weight.size(-1) input_hidden_size = weight.size(-1)
if bw_gemm_rs_op is None: if bw_ag_gemm_op is None:
bw_ag_gemm_op = flux.AGKernel( bw_ag_gemm_op = flux.AGKernel(
get_tensor_model_parallel_group(), get_tensor_model_parallel_group(),
1, #world_size // torch.cuda.device_count(), 1, #world_size // torch.cuda.device_count(),
...@@ -1013,7 +1013,7 @@ class RowParallelLinearPatch(torch.nn.Module): ...@@ -1013,7 +1013,7 @@ class RowParallelLinearPatch(torch.nn.Module):
assert HAS_FLUX, "flux is NOT installed" assert HAS_FLUX, "flux is NOT installed"
sequence_len, batch_size, input_hidden_size = input_parallel.size() sequence_len, batch_size, input_hidden_size = input_parallel.size()
output_hidden_size = weight.size(0) output_hidden_size = self.weight.size(0)
world_size = get_tensor_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
if self.sequence_parallel: if self.sequence_parallel:
current_flux_params = ( current_flux_params = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment