Commit 5da71bf3 authored by dongcl

bug fix

parent c0d23a67
@@ -202,16 +202,16 @@ class CoreAdaptation(MegatronAdaptationABC):
             HAS_FLUX = False
         if HAS_FLUX:
-            # MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__",
-            #                             column_parallel_linear_init_wrapper,
-            #                             apply_wrapper=True)
-            # MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward",
-            #                             ColumnParallelLinearPatch.forward)
-            MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.__init__",
-                                        row_parallel_linear_init_wrapper,
+            MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__",
+                                        column_parallel_linear_init_wrapper,
                                         apply_wrapper=True)
-            MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.forward",
-                                        RowParallelLinearPatch.forward)
+            MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward",
+                                        ColumnParallelLinearPatch.forward)
+            # MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.__init__",
+            #                             row_parallel_linear_init_wrapper,
+            #                             apply_wrapper=True)
+            # MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.forward",
+            #                             RowParallelLinearPatch.forward)
     def patch_training(self):
         from ..training.tokenizer import build_tokenizer
......
from .layers import (
parallel_linear_init_wrapper,
column_parallel_linear_init_wrapper,
row_parallel_linear_init_wrapper,
ColumnParallelLinearPatch,
RowParallelLinearPatch,
vocab_parallel_embedding_forward,
......
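For context on the first hunk: the `apply_wrapper=True` form presumably registers `column_parallel_linear_init_wrapper` as a decorator that receives the original `__init__` and returns a wrapped version, while the plain form (used for `forward`) replaces the attribute outright. The `MegatronAdaptation` internals are not part of this diff, so the following is only a minimal sketch of that registration pattern under those assumptions; `ToyRegistry` is a stand-in name, not the real class.

import importlib


class ToyRegistry:
    """Sketch of a patch registry keyed by dotted 'module.Class.attr' paths."""

    def __init__(self):
        self._patches = []

    def register(self, target, patch, apply_wrapper=False):
        # Record the patch; application is deferred until apply() is called.
        self._patches.append((target, patch, apply_wrapper))

    def apply(self):
        for target, patch, apply_wrapper in self._patches:
            module_path, cls_name, attr_name = target.rsplit(".", 2)
            cls = getattr(importlib.import_module(module_path), cls_name)
            original = getattr(cls, attr_name)
            # apply_wrapper=True: `patch` is a decorator taking the original callable;
            # otherwise `patch` simply replaces the attribute.
            setattr(cls, attr_name, patch(original) if apply_wrapper else patch)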
@@ -796,8 +796,6 @@ def column_parallel_linear_init_wrapper(fn):
             self.previous_flux_params = (None,) * 5
             self.fw_ag_gemm_op = None
             self.bw_gemm_rs_op = None
-        else:
-            return wrapper
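The wrapper above leaves the cached flux kernel handles initialized once the original `__init__` has run. A minimal sketch of that decorator shape follows; the collapsed lines of the real wrapper are not shown in this diff, so the condition-free version below is an assumption.

import functools


def column_parallel_linear_init_wrapper_sketch(fn):
    """Sketch: run the original __init__, then attach the lazily-built flux op cache."""

    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)
        # Shape/dtype key and kernel handles start empty; the patched forward
        # builds the fused kernels on first use (see the hunks below).
        self.previous_flux_params = (None,) * 5
        self.fw_ag_gemm_op = None
        self.bw_gemm_rs_op = None

    return wrapper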
@@ -880,7 +878,7 @@ class ColumnParallelLinearPatch(torch.nn.Module):
         output_hidden_size = weight.size(0)
         world_size = get_tensor_model_parallel_world_size()
         if self.sequence_parallel:
-            current_fw_params = (
+            current_flux_params = (
                 sequence_len,
                 batch_size,
                 input_hidden_size,
@@ -891,7 +889,7 @@ class ColumnParallelLinearPatch(torch.nn.Module):
             if (
                 self.fw_ag_gemm_op is None
                 or current_flux_params != self.previous_flux_params
-            )
+            ):
                 self.fw_ag_gemm_op = flux.AGKernel(
                     get_tensor_model_parallel_group(),
                     1, # torch.distributed.get_world_size() // torch.cuda.device_count(),
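The condition above rebuilds the fused all-gather + GEMM kernel only when the problem shape changes between forward calls, so steady-state iterations reuse the cached handle. A small illustration of that cache-by-shape pattern; `ensure_fw_ag_gemm_op` and `build_ag_gemm_op` are hypothetical names, the latter standing in for the `flux.AGKernel(...)` construction in the diff, and `current_flux_params` is the shape/dtype tuple built a few lines earlier.

def ensure_fw_ag_gemm_op(module, current_flux_params, build_ag_gemm_op):
    """Sketch: rebuild the fused all-gather + GEMM op only when the shape key changes."""
    if module.fw_ag_gemm_op is None or current_flux_params != module.previous_flux_params:
        # build_ag_gemm_op is a hypothetical factory standing in for flux.AGKernel(...),
        # which the diff constructs from the tensor-parallel group and the current GEMM size.
        module.fw_ag_gemm_op = build_ag_gemm_op(*current_flux_params)
        module.previous_flux_params = current_flux_params
    return module.fw_ag_gemm_op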
@@ -908,7 +906,7 @@ class ColumnParallelLinearPatch(torch.nn.Module):
                 self.bw_gemm_rs_op = flux.GemmRS(
                     get_tensor_model_parallel_group(),
                     1, # world_size // torch.cuda.device_count(),
-                    sequence_len * batch_size,
+                    sequence_len * batch_size * world_size,
                     input_hidden_size,
                     input_parallel.dtype,
                     input_parallel.dtype,
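The changed M argument above is the substance of the fix: with sequence parallelism enabled, `sequence_len` is presumably the per-rank (pre-all-gather) sequence length, so the GEMM feeding the backward reduce-scatter has to cover all `world_size` sequence shards rather than only the local one. A worked example with assumed sizes:

# Assumed sizes, for illustration only.
sequence_len = 1024     # tokens held locally by each tensor-parallel rank
batch_size = 2
world_size = 8          # tensor-parallel group size

m_before_fix = sequence_len * batch_size                # 2048: only the local shard's rows
m_after_fix = sequence_len * batch_size * world_size    # 16384: rows for the full sequence

# The reduce-scatter then returns each rank its 1/world_size slice of the rows,
# i.e. the local shard size again.
assert m_after_fix // world_size == m_before_fix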
@@ -916,7 +914,7 @@ class ColumnParallelLinearPatch(torch.nn.Module):
                     fuse_reduction=False
                 )
-                self.previous_flux_params = current_fw_params
+                self.previous_flux_params = current_flux_params
             self._forward_impl = ag_linear
         elif not weight.requires_grad:
......