Commit 138b70a2 authored by dongcl

fix flux bug

parent 72aeb0f3
@@ -91,7 +91,7 @@ def get_gpt_layer_with_flux_spec(
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
pre_mlp_layernorm=TENorm,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
@@ -119,7 +119,7 @@ def get_gpt_layer_with_flux_spec(
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
pre_mlp_layernorm=TENorm,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
@@ -213,12 +213,12 @@ class AGLinear(torch.autograd.Function):
output_scale=None,
fast_accum=False
)
torch.distributed.barrier()
torch.cuda.current_stream().synchronize()
output = output.view(sequence_len * world_size, batch_size, -1)
else:
output = torch.matmul(input, weight.t())
torch.cuda.current_stream().synchronize()
return output
@staticmethod
@@ -260,31 +260,34 @@ class AGLinear(torch.autograd.Function):
if ctx.sequence_parallel:
sequence_len, batch_size, _ = grad_output.size()
if bw_gemm_rs_op is None:
input_hidden_size = weight.size(-1)
bw_gemm_rs_op = flux.GemmRS(
get_tensor_model_parallel_group(),
1, # world_size // torch.cuda.device_count(),
sequence_len * batch_size,
input_hidden_size,
input.dtype,
input.dtype,
transpose_weight=transpose_weight,
fuse_reduction=False
)
grad_input = bw_gemm_rs_op.forward(
grad_output.view(sequence_len * batch_size, -1),
weight if transpose_weight else weight.t().contiguous(),
bias=None,
input_scale=None,
weight_scale=None,
output_scale=None,
fast_accum=False
)
torch.cuda.current_stream().synchronize()
grad_input = grad_input.view(sequence_len // world_size, batch_size, -1)
# if bw_gemm_rs_op is None:
# input_hidden_size = weight.size(-1)
# bw_gemm_rs_op = flux.GemmRS(
# get_tensor_model_parallel_group(),
# 1, # world_size // torch.cuda.device_count(),
# sequence_len * batch_size,
# input_hidden_size,
# input.dtype,
# input.dtype,
# transpose_weight=transpose_weight,
# fuse_reduction=False
# )
# grad_input = bw_gemm_rs_op.forward(
# grad_output.view(sequence_len * batch_size, -1),
# weight if transpose_weight else weight.t().contiguous(),
# bias=None,
# input_scale=None,
# weight_scale=None,
# output_scale=None,
# fast_accum=False
# )
# torch.distributed.barrier()
# torch.cuda.current_stream().synchronize()
# grad_input = grad_input.view(sequence_len // world_size, batch_size, -1)
grad_input = grad_output.matmul(weight)
grad_input = _reduce_scatter_along_first_dim(grad_input)
else:
grad_input = grad_output.matmul(weight)
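For context, a minimal illustrative sketch (not part of this commit) of the sequence-parallel backward fallback the hunk above switches to: the flux GemmRS path is left commented out and grad_input is instead computed with a plain matmul followed by a reduce-scatter along the sequence (first) dimension. The helper below is a stand-in for Megatron's _reduce_scatter_along_first_dim, it assumes an already-initialized torch.distributed process group, and the function name ag_linear_backward_input_fallback is hypothetical.

import torch
import torch.distributed as dist

def reduce_scatter_along_first_dim(t, group=None):
    # Stand-in for Megatron's _reduce_scatter_along_first_dim: sum across ranks,
    # then keep this rank's 1/world_size slice of the first dimension.
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        return t
    output = torch.empty((t.shape[0] // world_size,) + tuple(t.shape[1:]),
                         dtype=t.dtype, device=t.device)
    dist.reduce_scatter_tensor(output, t.contiguous(), group=group)
    return output

def ag_linear_backward_input_fallback(grad_output, weight, sequence_parallel, group=None):
    # grad_output: [s * world_size, b, out] (the input was all-gathered in forward),
    # weight: [out, in].  Returns the grad w.r.t. this rank's local input, [s, b, in].
    grad_input = grad_output.matmul(weight)  # [s * world_size, b, in], partial sums
    if sequence_parallel:
        # Sum partial grads across tensor-parallel ranks and scatter the sequence dim.
        grad_input = reduce_scatter_along_first_dim(grad_input, group=group)
    return grad_input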
@@ -496,8 +499,6 @@ class LinearRS(torch.autograd.Function):
world_size = get_tensor_model_parallel_world_size()
sequence_len, batch_size, _ = input.size()
# input: 3D tensor whose order of dimension is [sequence, batch, hidden]
input = input.view(sequence_len * batch_size, -1)
output_hidden_size = weight.size(0)
if sequence_parallel:
@@ -513,7 +514,7 @@
fuse_reduction=False,
)
output = fw_gemm_rs_op.forward(
input,
input.view(sequence_len * batch_size, -1),
weight.t().contiguous() if transpose_weight else weight,
bias=bias,
input_scale=None,
@@ -521,12 +522,16 @@
output_scale=None,
fast_accum=False,
)
torch.distributed.barrier()
torch.cuda.current_stream().synchronize()
output = output.view(sequence_len // world_size, batch_size, -1)
# output = torch.matmul(input, weight.t())
# output = _reduce_scatter_along_first_dim(output)
else:
output = torch.matmul(input, weight.t())
output = _reduce(output)
# torch.cuda.current_stream().synchronize()
return output
@staticmethod
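Similarly, a small illustrative sketch (assumptions, not this file's code) of the two non-flux paths that LinearRS falls back to in the hunk above: with sequence parallelism the matmul result is reduce-scattered along the sequence dimension (the commented-out lines), otherwise it is all-reduced across the tensor-parallel group. The function name and the direct torch.distributed collectives stand in for Megatron's _reduce_scatter_along_first_dim and _reduce helpers.

import torch
import torch.distributed as dist

def linear_rs_forward_fallback(input, weight, sequence_parallel, group=None):
    # input: [s, b, in_partition], weight: [out, in_partition] (this rank's shard).
    output = torch.matmul(input, weight.t())  # [s, b, out], partial sums per rank
    if sequence_parallel:
        # Sum partial results across ranks and scatter the sequence dimension.
        world_size = dist.get_world_size(group=group)
        reduced = torch.empty((output.shape[0] // world_size,) + tuple(output.shape[1:]),
                              dtype=output.dtype, device=output.device)
        dist.reduce_scatter_tensor(reduced, output.contiguous(), group=group)
        return reduced
    # No sequence parallelism: plain all-reduce of the partial results.
    dist.all_reduce(output, group=group)
    return output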
@@ -785,10 +790,9 @@ def column_parallel_linear_init_wrapper(fn):
elif hasattr(self.config, "flux_transpose_weight"):
self.flux_transpose_weight = self.config.flux_transpose_weight
if self.sequence_parallel:
self.previous_flux_params = (None,) * 5
self.fw_ag_gemm_op = None
self.bw_gemm_rs_op = None
self.previous_flux_params = (None,) * 5
self.fw_ag_gemm_op = None
self.bw_gemm_rs_op = None
return wrapper
@@ -969,10 +973,9 @@ def row_parallel_linear_init_wrapper(fn):
elif hasattr(self.config, "flux_transpose_weight"):
self.flux_transpose_weight = self.config.flux_transpose_weight
if self.sequence_parallel:
self.previous_flux_params = (None,) * 5
self.fw_gemm_rs_op = None
self.bw_ag_gemm_op = None
self.previous_flux_params = (None,) * 5
self.fw_gemm_rs_op = None
self.bw_ag_gemm_op = None
return wrapper
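Finally, a hypothetical sketch of how the attributes initialized above (previous_flux_params, fw_gemm_rs_op) might be used to lazily (re)build the fused kernel. The flux.GemmRS constructor arguments mirror the ones that appear earlier in this diff, but the exact five values tracked in previous_flux_params, the helper name get_fw_gemm_rs_op, and the import path are assumptions.

import flux  # fused GEMM + communication kernels, as used elsewhere in this file
from megatron.core.parallel_state import get_tensor_model_parallel_group  # assumed import path

def get_fw_gemm_rs_op(self, sequence_len, batch_size, output_hidden_size, dtype,
                      transpose_weight):
    # Rebuild the cached flux.GemmRS op only when the problem signature changes.
    params = (sequence_len, batch_size, output_hidden_size, dtype, transpose_weight)
    if self.fw_gemm_rs_op is None or params != self.previous_flux_params:
        self.fw_gemm_rs_op = flux.GemmRS(
            get_tensor_model_parallel_group(),
            1,                              # nnodes
            sequence_len * batch_size,      # M
            output_hidden_size,             # N
            dtype,                          # input dtype
            dtype,                          # output dtype
            transpose_weight=transpose_weight,
            fuse_reduction=False,
        )
        self.previous_flux_params = params
    return self.fw_gemm_rs_op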