Commit 138b70a2 authored by dongcl

fix flux bug

parent 72aeb0f3
@@ -91,7 +91,7 @@ def get_gpt_layer_with_flux_spec(
                 ),
             ),
             self_attn_bda=get_bias_dropout_add,
-            pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
+            pre_mlp_layernorm=TENorm,
             mlp=mlp,
             mlp_bda=get_bias_dropout_add,
         ),
@@ -119,7 +119,7 @@ def get_gpt_layer_with_flux_spec(
                 ),
             ),
             self_attn_bda=get_bias_dropout_add,
-            pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
+            pre_mlp_layernorm=TENorm,
             mlp=mlp,
             mlp_bda=get_bias_dropout_add,
         ),
...
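For reference, a dense (non-MoE) layer spec with the pre-MLP layernorm applied explicitly can be assembled roughly as below. This is a minimal sketch assuming megatron-core's spec utilities and Transformer Engine wrappers (import paths vary across megatron-core versions); the flux spec changed above differs in that it plugs in the flux-backed linear layers.

    from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
    from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
    from megatron.core.transformer.custom_layers.transformer_engine import (
        TEColumnParallelLinear,
        TEDotProductAttention,
        TENorm,
        TERowParallelLinear,
    )
    from megatron.core.transformer.enums import AttnMaskType
    from megatron.core.transformer.mlp import MLP, MLPSubmodules
    from megatron.core.transformer.spec_utils import ModuleSpec
    from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules


    def dense_layer_spec_sketch() -> ModuleSpec:
        """Dense GPT layer spec with the pre-MLP layernorm applied explicitly."""
        mlp = ModuleSpec(
            module=MLP,
            submodules=MLPSubmodules(
                linear_fc1=TEColumnParallelLinear,
                linear_fc2=TERowParallelLinear,
            ),
        )
        return ModuleSpec(
            module=TransformerLayer,
            submodules=TransformerLayerSubmodules(
                input_layernorm=TENorm,
                self_attention=ModuleSpec(
                    module=SelfAttention,
                    params={"attn_mask_type": AttnMaskType.causal},
                    submodules=SelfAttentionSubmodules(
                        linear_qkv=TEColumnParallelLinear,
                        core_attention=TEDotProductAttention,
                        linear_proj=TERowParallelLinear,
                    ),
                ),
                self_attn_bda=get_bias_dropout_add,
                # Pinned to TENorm for the dense case instead of falling back to IdentityOp.
                pre_mlp_layernorm=TENorm,
                mlp=mlp,
                mlp_bda=get_bias_dropout_add,
            ),
        )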
@@ -213,12 +213,12 @@ class AGLinear(torch.autograd.Function):
                 output_scale=None,
                 fast_accum=False
             )
+            torch.distributed.barrier()
+            torch.cuda.current_stream().synchronize()
             output = output.view(sequence_len * world_size, batch_size, -1)
         else:
             output = torch.matmul(input, weight.t())
-        torch.cuda.current_stream().synchronize()

         return output

     @staticmethod
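The forward change above inserts a process-group barrier and a current-stream synchronize between the fused AllGather+GEMM and the reshape of its output. A minimal sketch of that pattern, assuming an already-built flux op handle (`ag_gemm_op` stands in for `fw_ag_gemm_op`; the `.forward` keyword arguments mirror the call in the hunk):

    import torch
    import torch.distributed as dist


    def ag_gemm_with_sync(ag_gemm_op, inp, weight, world_size):
        # inp: [seq, batch, hidden] local shard; weight: [out_hidden, hidden]
        sequence_len, batch_size, _ = inp.size()
        out = ag_gemm_op.forward(
            inp.view(sequence_len * batch_size, -1),
            weight,
            bias=None,
            input_scale=None,
            weight_scale=None,
            output_scale=None,
            fast_accum=False,
        )
        # Block until every rank has reached this point ...
        dist.barrier()
        # ... and until all work queued on the current CUDA stream has finished,
        # so later host-side code only sees a fully written output buffer.
        torch.cuda.current_stream().synchronize()
        return out.view(sequence_len * world_size, batch_size, -1)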
@@ -260,31 +260,34 @@ class AGLinear(torch.autograd.Function):
         if ctx.sequence_parallel:
             sequence_len, batch_size, _ = grad_output.size()
-            if bw_gemm_rs_op is None:
-                input_hidden_size = weight.size(-1)
-                bw_gemm_rs_op = flux.GemmRS(
-                    get_tensor_model_parallel_group(),
-                    1,  # world_size // torch.cuda.device_count(),
-                    sequence_len * batch_size,
-                    input_hidden_size,
-                    input.dtype,
-                    input.dtype,
-                    transpose_weight=transpose_weight,
-                    fuse_reduction=False
-                )
-            grad_input = bw_gemm_rs_op.forward(
-                grad_output.view(sequence_len * batch_size, -1),
-                weight if transpose_weight else weight.t().contiguous(),
-                bias=None,
-                input_scale=None,
-                weight_scale=None,
-                output_scale=None,
-                fast_accum=False
-            )
-            torch.cuda.current_stream().synchronize()
-            grad_input = grad_input.view(sequence_len // world_size, batch_size, -1)
+            # if bw_gemm_rs_op is None:
+            #     input_hidden_size = weight.size(-1)
+            #     bw_gemm_rs_op = flux.GemmRS(
+            #         get_tensor_model_parallel_group(),
+            #         1,  # world_size // torch.cuda.device_count(),
+            #         sequence_len * batch_size,
+            #         input_hidden_size,
+            #         input.dtype,
+            #         input.dtype,
+            #         transpose_weight=transpose_weight,
+            #         fuse_reduction=False
+            #     )
+            # grad_input = bw_gemm_rs_op.forward(
+            #     grad_output.view(sequence_len * batch_size, -1),
+            #     weight if transpose_weight else weight.t().contiguous(),
+            #     bias=None,
+            #     input_scale=None,
+            #     weight_scale=None,
+            #     output_scale=None,
+            #     fast_accum=False
+            # )
+            # torch.distributed.barrier()
+            # torch.cuda.current_stream().synchronize()
+            # grad_input = grad_input.view(sequence_len // world_size, batch_size, -1)
+            grad_input = grad_output.matmul(weight)
+            grad_input = _reduce_scatter_along_first_dim(grad_input)
         else:
             grad_input = grad_output.matmul(weight)
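The backward path now falls back to an un-fused GEMM followed by a reduce-scatter over the sequence (first) dimension. A self-contained sketch of that fallback is below; `reduce_scatter_along_first_dim` is an illustrative stand-in for Megatron's `_reduce_scatter_along_first_dim` helper, assuming `torch.distributed.reduce_scatter_tensor` and the tensor-model-parallel group passed as `group`.

    import torch
    import torch.distributed as dist


    def reduce_scatter_along_first_dim(t, group=None):
        world_size = dist.get_world_size(group=group)
        if world_size == 1:
            return t
        assert t.size(0) % world_size == 0, "first dim must be divisible by world size"
        out = torch.empty((t.size(0) // world_size,) + tuple(t.size()[1:]),
                          dtype=t.dtype, device=t.device)
        # Sum the full tensor across ranks and hand each rank its shard of dim 0.
        dist.reduce_scatter_tensor(out, t.contiguous(), group=group)
        return out


    def ag_linear_backward_fallback(grad_output, weight, group=None):
        # grad_output: [seq, batch, out_hidden]; weight: [out_hidden, in_hidden]
        grad_input = grad_output.matmul(weight)  # [seq, batch, in_hidden]
        return reduce_scatter_along_first_dim(grad_input, group)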
@@ -496,8 +499,6 @@ class LinearRS(torch.autograd.Function):
         world_size = get_tensor_model_parallel_world_size()
         sequence_len, batch_size, _ = input.size()

-        # input: 3D tensor whose order of dimension is [sequence, batch, hidden]
-        input = input.view(sequence_len * batch_size, -1)
         output_hidden_size = weight.size(0)

         if sequence_parallel:
@@ -513,7 +514,7 @@ class LinearRS(torch.autograd.Function):
                 fuse_reduction=False,
             )
             output = fw_gemm_rs_op.forward(
-                input,
+                input.view(sequence_len * batch_size, -1),
                 weight.t().contiguous() if transpose_weight else weight,
                 bias=bias,
                 input_scale=None,
@@ -521,12 +522,16 @@ class LinearRS(torch.autograd.Function):
                 output_scale=None,
                 fast_accum=False,
             )
+            torch.distributed.barrier()
+            torch.cuda.current_stream().synchronize()
             output = output.view(sequence_len // world_size, batch_size, -1)
+            # output = torch.matmul(input, weight.t())
+            # output = _reduce_scatter_along_first_dim(output)
         else:
             output = torch.matmul(input, weight.t())
             output = _reduce(output)
-        torch.cuda.current_stream().synchronize()
+        # torch.cuda.current_stream().synchronize()

         return output

     @staticmethod
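For comparison with the fused GemmRS path above, an un-fused row-parallel forward computes the partial product locally and then either reduce-scatters it along the sequence dimension (sequence parallel) or all-reduces it, which is what the commented-out fallback and the `_reduce` branch correspond to. A hedged sketch:

    import torch
    import torch.distributed as dist


    def row_parallel_forward_reference(inp, weight, sequence_parallel, group=None):
        # inp: [seq, batch, hidden_per_rank]; weight: [out_hidden, hidden_per_rank]
        partial = torch.matmul(inp, weight.t())  # [seq, batch, out_hidden], partial sum
        if sequence_parallel:
            world_size = dist.get_world_size(group=group)
            out = torch.empty((partial.size(0) // world_size,) + tuple(partial.size()[1:]),
                              dtype=partial.dtype, device=partial.device)
            # Sum partial products across ranks, keep this rank's sequence shard.
            dist.reduce_scatter_tensor(out, partial.contiguous(), group=group)
            return out
        # Without sequence parallelism every rank keeps the full summed output.
        dist.all_reduce(partial, group=group)
        return partial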
@@ -785,7 +790,6 @@ def column_parallel_linear_init_wrapper(fn):
         elif hasattr(self.config, "flux_transpose_weight"):
             self.flux_transpose_weight = self.config.flux_transpose_weight

-        if self.sequence_parallel:
-            self.previous_flux_params = (None,) * 5
-            self.fw_ag_gemm_op = None
-            self.bw_gemm_rs_op = None
+        self.previous_flux_params = (None,) * 5
+        self.fw_ag_gemm_op = None
+        self.bw_gemm_rs_op = None
@@ -969,7 +973,6 @@ def row_parallel_linear_init_wrapper(fn):
         elif hasattr(self.config, "flux_transpose_weight"):
             self.flux_transpose_weight = self.config.flux_transpose_weight

-        if self.sequence_parallel:
-            self.previous_flux_params = (None,) * 5
-            self.fw_gemm_rs_op = None
-            self.bw_ag_gemm_op = None
+        self.previous_flux_params = (None,) * 5
+        self.fw_gemm_rs_op = None
+        self.bw_ag_gemm_op = None
...
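Both init wrappers now create the cached flux parameters and op handles unconditionally rather than only when sequence parallelism is on. A minimal sketch of that lazy-caching idea, with `build_flux_op` as a hypothetical factory standing in for the flux constructors (not a real flux API):

    class FluxOpCache:
        def __init__(self):
            # Mirrors the wrapper init: always present, regardless of sequence_parallel.
            self.previous_flux_params = (None,) * 5
            self.fw_op = None
            self.bw_op = None

        def get_fw_op(self, build_flux_op, *params):
            # Rebuild only when the parameters the op was specialized for change.
            if self.fw_op is None or params != self.previous_flux_params:
                self.fw_op = build_flux_op(*params)
                self.previous_flux_params = params
            return self.fw_op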