Commit 5da71bf3 authored by dongcl

bug fix

parent c0d23a67
@@ -202,16 +202,16 @@ class CoreAdaptation(MegatronAdaptationABC):
             HAS_FLUX = False
         if HAS_FLUX:
-            # MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__",
-            #                             column_parallel_linear_init_wrapper,
-            #                             apply_wrapper=True)
-            # MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward",
-            #                             ColumnParallelLinearPatch.forward)
-            MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.__init__",
-                                        row_parallel_linear_init_wrapper,
-                                        apply_wrapper=True)
-            MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.forward",
-                                        RowParallelLinearPatch.forward)
+            MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__",
+                                        column_parallel_linear_init_wrapper,
+                                        apply_wrapper=True)
+            MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward",
+                                        ColumnParallelLinearPatch.forward)
+            # MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.__init__",
+            #                             row_parallel_linear_init_wrapper,
+            #                             apply_wrapper=True)
+            # MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.forward",
+            #                             RowParallelLinearPatch.forward)
     def patch_training(self):
         from ..training.tokenizer import build_tokenizer
...
 from .layers import (
-    parallel_linear_init_wrapper,
+    column_parallel_linear_init_wrapper,
+    row_parallel_linear_init_wrapper,
     ColumnParallelLinearPatch,
     RowParallelLinearPatch,
     vocab_parallel_embedding_forward,
...
@@ -796,8 +796,6 @@ def column_parallel_linear_init_wrapper(fn):
             self.previous_flux_params = (None,) * 5
             self.fw_ag_gemm_op = None
             self.bw_gemm_rs_op = None
-        else:
     return wrapper
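
For context, the init wrapper registered in the first hunk follows the usual decorate-__init__ pattern: run the original constructor, then attach the flux cache slots shown here. A minimal sketch, assuming a plain *args/**kwargs pass-through (only the three attribute assignments and the final return wrapper are taken from this diff):

def column_parallel_linear_init_wrapper(fn):
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)  # assumed call into the original __init__
        # Cache slots for the fused flux path; (None,) * 5 presumably mirrors
        # the five-element shape/dtype tuple compared later in forward().
        self.previous_flux_params = (None,) * 5
        self.fw_ag_gemm_op = None
        self.bw_gemm_rs_op = None
    # Returned unconditionally; this commit drops a leftover "else:" just above.
    return wrapper
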
@@ -880,7 +878,7 @@ class ColumnParallelLinearPatch(torch.nn.Module):
         output_hidden_size = weight.size(0)
         world_size = get_tensor_model_parallel_world_size()
         if self.sequence_parallel:
-            current_fw_params = (
+            current_flux_params = (
                 sequence_len,
                 batch_size,
                 input_hidden_size,
@@ -891,7 +889,7 @@ class ColumnParallelLinearPatch(torch.nn.Module):
             if (
                 self.fw_ag_gemm_op is None
                 or current_flux_params != self.previous_flux_params
-            )
+            ):
                 self.fw_ag_gemm_op = flux.AGKernel(
                     get_tensor_model_parallel_group(),
                     1,  # torch.distributed.get_world_size() // torch.cuda.device_count(),
@@ -908,7 +906,7 @@ class ColumnParallelLinearPatch(torch.nn.Module):
                 self.bw_gemm_rs_op = flux.GemmRS(
                     get_tensor_model_parallel_group(),
                     1,  # world_size // torch.cuda.device_count(),
-                    sequence_len * batch_size,
+                    sequence_len * batch_size * world_size,
                     input_hidden_size,
                     input_parallel.dtype,
                     input_parallel.dtype,
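
The size change in this hunk looks like a sequence-parallel sizing fix: each rank's input_parallel only holds the local sequence_len rows, while the backward GemmRS presumably has to be built for the gathered problem, so its leading dimension becomes sequence_len * batch_size * world_size. For example, with a local sequence_len of 1024, batch_size 2 and tensor-parallel world_size 8, that is 1024 * 2 * 8 = 16384 rows rather than 2048 (this reading of the flux.GemmRS arguments is an assumption, not stated in the commit).
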
@@ -916,7 +914,7 @@ class ColumnParallelLinearPatch(torch.nn.Module):
                     fuse_reduction=False
                 )
-            self.previous_flux_params = current_fw_params
+            self.previous_flux_params = current_flux_params
             self._forward_impl = ag_linear
         elif not weight.requires_grad:
...
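
Taken together, the hunks in ColumnParallelLinearPatch.forward implement a rebuild-on-shape-change cache for the flux kernels; before this change the method built current_fw_params but compared against current_flux_params (a NameError), and the missing colon fixed at "):" suggests the path did not even parse. A rough sketch of the intended pattern, with build_fw_ag_gemm / build_bw_gemm_rs as hypothetical stand-ins for the flux.AGKernel / flux.GemmRS construction and the last two tuple members assumed (the diff only shows the first three of the five slots):

# Sketch of a forward() helper; not the actual implementation.
def _maybe_rebuild_flux_ops(self, sequence_len, batch_size,
                            input_hidden_size, output_hidden_size, dtype):
    current_flux_params = (sequence_len, batch_size, input_hidden_size,
                           output_hidden_size, dtype)
    if (
        self.fw_ag_gemm_op is None
        or current_flux_params != self.previous_flux_params
    ):
        # Rebuild only when the problem shape or dtype changed.
        self.fw_ag_gemm_op = build_fw_ag_gemm(*current_flux_params)   # hypothetical
        self.bw_gemm_rs_op = build_bw_gemm_rs(*current_flux_params)   # hypothetical
    self.previous_flux_params = current_flux_params
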