Commit f25c421e authored by dongcl

fix flux bug

parent b6eb1484
@@ -143,9 +143,9 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
apply_wrapper=True)
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True}),
apply_wrapper=True)
# MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
# torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True}),
# apply_wrapper=True)
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute',
torch.compile(mode='max-autotune-no-cudagraphs'),
apply_wrapper=True)
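Note: the registrations above pass `torch.compile(...)` as the wrapper with `apply_wrapper=True`, which presumably means the adaptation layer replaces the target with `wrapper(original)`. A minimal sketch of that pattern, assuming this reading of `MegatronAdaptation.register` (the helper below is hypothetical, not the project's implementation):

import importlib
import torch

def register_with_wrapper(module_path: str, wrapper) -> None:
    # Hypothetical helper: resolve "pkg.mod.attr", wrap the original callable,
    # and patch it back in place (assumed behaviour of apply_wrapper=True).
    mod_path, _, attr = module_path.rpartition(".")
    mod = importlib.import_module(mod_path)
    setattr(mod, attr, wrapper(getattr(mod, attr)))

# Usage mirroring the diff above: torch.compile called without a function
# returns a decorator, so it can be used directly as the wrapper.
register_with_wrapper(
    "megatron.core.transformer.moe.moe_utils.permute",
    torch.compile(mode="max-autotune-no-cudagraphs"),
)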
@@ -197,18 +197,17 @@ class CoreAdaptation(MegatronAdaptationABC):
HAS_FLUX = False
if HAS_FLUX:
MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__",
parallel_linear_init_wrapper,
apply_wrapper=True)
MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward",
ColumnParallelLinearPatch.forward)
# MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__",
# parallel_linear_init_wrapper,
# apply_wrapper=True)
# MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward",
# ColumnParallelLinearPatch.forward)
MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.__init__",
parallel_linear_init_wrapper,
apply_wrapper=True)
MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.forward",
RowParallelLinearPatch.forward)
def patch_training(self):
from ..training.tokenizer import build_tokenizer
from ..training.initialize import _initialize_distributed
......
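For context, a sketch of how the HAS_FLUX flag used in the hunk above is presumably set; the try/except sits outside this diff, so this is an assumption rather than the project's actual code:

try:
    import flux  # fused GEMM + communication kernels used by the patches above
    HAS_FLUX = True
except ImportError:
    HAS_FLUX = False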
@@ -13,8 +13,11 @@ import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from megatron.training import print_rank_0
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.parallel_state import (
get_global_memory_buffer,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
@@ -31,12 +34,15 @@ from megatron.core.tensor_parallel.mappings import (
copy_to_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region,
reduce_scatter_to_sequence_parallel_region,
_reduce_scatter_along_first_dim,
_gather_along_first_dim,
)
from megatron.core.tensor_parallel.utils import VocabUtility
from megatron.core.tensor_parallel.mappings import _reduce
from megatron.core.tensor_parallel.layers import (
custom_fwd,
custom_bwd,
dist_all_gather_func,
linear_with_frozen_weight,
linear_with_grad_accumulation_and_async_allreduce
)
@@ -176,26 +182,24 @@ class AGLinear(torch.autograd.Function):
ctx.grad_output_buffer = grad_output_buffer
ctx.transpose_weight = transpose_weight
sequence_len = input.size(0)
sequence_len, batch_size, input_hidden_size = input.size()
# input: 3D tensor whose order of dimension is [sequence, batch, hidden]
input = input.view(
input.shape[0] * input.shape[1], input.shape[2]
sequence_len * batch_size, input_hidden_size
)
M, K = list(input.size())
N = weight.size(0)
M = M * get_tensor_model_parallel_world_size()
output_hidden_size = weight.size(0)
if transpose_weight:
weight = weight.t().contiguous()
if sequence_parallel:
sequence_len = sequence_len * get_tensor_model_parallel_world_size()
ag_gemm_kernel = flux.AGKernel(
get_tensor_model_parallel_group(),
get_tensor_model_parallel_world_size() // torch.cuda.device_count(),
M,
N,
K,
1, # torch.distributed.get_world_size() // torch.cuda.device_count(),
sequence_len * batch_size,
output_hidden_size,
input_hidden_size,
input.dtype,
output_dtype=input.dtype,
transpose_weight=transpose_weight,
@@ -206,13 +210,13 @@ class AGLinear(torch.autograd.Function):
input,
weight,
bias=bias,
input_scale=input_scale,
weight_scale=weight_scale,
input_scale=None,
weight_scale=None,
output_scale=None,
fast_accum=False
)
else:
output_buf = torch.empty([M, N], dtype=input.dtype, device=input.device)
output_buf = torch.empty([sequence_len * batch_size, output_hidden_size], dtype=input.dtype, device=input.device)
gemm_only_op = flux.GemmOnly(
input_dtype=input.dtype,
output_dtype=input.dtype,
@@ -231,7 +235,7 @@ class AGLinear(torch.autograd.Function):
)
torch.cuda.current_stream().synchronize()
output = output.view(sequence_len, input.size(0) // sequence_len, -1)
output = output.view(sequence_len, batch_size, -1)
return output
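To make the shape bookkeeping above easier to follow, here is an unfused reference of the sequence-parallel path that flux.AGKernel fuses, as the code reads: gather the sequence-sharded input across the tensor-parallel group, run the GEMM, and reshape back to [sequence, batch, hidden]. A sketch only, reusing the `_gather_along_first_dim` mapping already imported in this file:

import torch
from megatron.core.tensor_parallel.mappings import _gather_along_first_dim

def ag_linear_reference(inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # inp: [s_local, b, k], sharded along the sequence dim
    # weight: [n_local, k] shard of a column-parallel weight
    s_local, b, k = inp.size()
    total_inp = _gather_along_first_dim(inp)            # [s_local * tp, b, k]
    out = total_inp.view(-1, k).matmul(weight.t())      # [s_local * tp * b, n_local]
    return out.view(-1, b, weight.size(0))              # [s_local * tp, b, n_local]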
@@ -272,20 +276,17 @@ class AGLinear(torch.autograd.Function):
if ctx.sequence_parallel:
sequence_len, batch_size, output_hidden_size = grad_output.size()
input_hidden_size = weight.size(-1)
# input: 3D tensor whose order of dimension is [sequence, batch, hidden]
grad_output = grad_output.view(
sequence_len * batch_size, output_hidden_size
)
if not transpose_weight:
weight = weight.t().contiguous()
gemm_rs_op = flux.GemmRS(
get_tensor_model_parallel_group(),
world_size // torch.cuda.device_count(),
1, # world_size // torch.cuda.device_count(),
sequence_len * batch_size,
output_hidden_size,
input_hidden_size,
input.dtype,
input.dtype,
transpose_weight=transpose_weight,
@@ -293,7 +294,7 @@ class AGLinear(torch.autograd.Function):
)
grad_input = gemm_rs_op.forward(
grad_output,
weight,
weight if transpose_weight else weight.t().contiguous(),
bias=None,
input_scale=None,
weight_scale=None,
@@ -302,13 +303,16 @@ class AGLinear(torch.autograd.Function):
)
torch.cuda.current_stream().synchronize()
grad_input = grad_input.view(sequence_len // get_tensor_model_parallel_group(), batch_size, -1)
grad_input = grad_input.view(sequence_len // get_tensor_model_parallel_world_size(), batch_size, -1)
else:
grad_input = grad_output.matmul(weight)
if ctx.sequence_parallel and wgrad_compute:
handle.wait()
if ctx.sequence_parallel:
grad_output = grad_output.view(sequence_len, batch_size, output_hidden_size)
if wgrad_compute:
grad_output, total_input = prepare_input_tensors_for_wgrad_compute(
grad_output, total_input
@@ -503,25 +507,17 @@ class LinearRS(torch.autograd.Function):
world_size = get_tensor_model_parallel_world_size()
input_dim = input.dim()
sequence_len = input.size(0)
sequence_len, batch_size, _ = input.size()
# input: 3D tensor whose order of dimension is [sequence, batch, hidden]
input = input.view(
input.shape[0] * input.shape[1], input.shape[2]
)
M = input.size(0)
N = weight.size(0)
input = input.view(sequence_len * batch_size, -1)
output_hidden_size = weight.size(0)
if sequence_parallel:
if transpose_weight:
weight = weight.t().contiguous()
gemm_rs_op = flux.GemmRS(
get_tensor_model_parallel_group(),
world_size // torch.cuda.device_count(),
M,
N,
1, #world_size // torch.cuda.device_count(),
sequence_len * batch_size,
output_hidden_size,
input.dtype,
input.dtype,
transpose_weight=transpose_weight,
@@ -529,15 +525,23 @@ class LinearRS(torch.autograd.Function):
)
output = gemm_rs_op.forward(
input,
weight,
weight.t().contiguous() if transpose_weight else weight,
bias=bias,
input_scale=None,
weight_scale=None,
output_scale=None,
fast_accum=False,
)
output = output.view(sequence_len // world_size, batch_size, -1)
# output = torch.matmul(input, weight.t())
# return _reduce_scatter_along_first_dim(output)
else:
output = torch.empty([M, N], dtype=input.dtype, device=input.device)
output_buf = torch.empty(
[sequence_len * batch_size, output_hidden_size],
dtype=input.dtype,
device=input.device,
requires_grad=False
)
gemm_only_op = flux.GemmOnly(
input_dtype=input.dtype,
output_dtype=input.dtype,
@@ -546,21 +550,18 @@ class LinearRS(torch.autograd.Function):
)
output = gemm_only_op.forward(
input,
weight,
weight.t().contiguous() if transpose_weight else weight,
bias=bias,
output_buf=output,
output_buf=output_buf,
input_scale=None,
weight_scale=None,
output_scale=None,
fast_accum=False,
)
output = output.view(sequence_len, batch_size, -1)
output = _reduce(output)
torch.cuda.current_stream().synchronize()
output = output.view(sequence_len, input.size(0) // sequence_len, -1)
if not sequence_parallel:
_reduce(output)
# torch.cuda.current_stream().synchronize()
return output
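The two branches above correspond to a GEMM fused with a reduce-scatter (sequence parallel) or a plain GEMM followed by an all-reduce. An unfused reference sketch using the mappings already imported in this file; the commented-out lines above show the author's own reference for the first branch:

import torch
from megatron.core.tensor_parallel.mappings import (
    _reduce,
    _reduce_scatter_along_first_dim,
)

def linear_rs_reference(inp, weight, sequence_parallel):
    # inp: [s, b, k_local] activations for this rank's shard of the hidden dim
    # weight: [n, k_local] shard of a row-parallel weight, so the matmul yields
    # a partial result that still has to be summed across the TP group
    out = torch.matmul(inp, weight.t())                 # [s, b, n] partial sums
    if sequence_parallel:
        return _reduce_scatter_along_first_dim(out)     # [s / tp, b, n]
    return _reduce(out)                                  # [s, b, n]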
@staticmethod
@@ -579,37 +580,45 @@ class LinearRS(torch.autograd.Function):
grad_output_buffer.append(grad_output)
wgrad_compute = False
if ctx.sequence_parallel:
world_size = get_tensor_model_parallel_world_size()
if wgrad:
if ctx.sequence_parallel:
dim_size = list(grad_output.size())
dim_size[0] = dim_size[0] * world_size
sequence_len, batch_size, _ = grad_output.size()
grad_output = grad_output.view(sequence_len * batch_size, -1)
all_gather_buffer = get_global_memory_buffer().get_tensor(
dim_size, grad_output.dtype, "mpu"
)
handle = dist_all_gather_func(
all_gather_buffer, grad_output, group=get_tensor_model_parallel_group(), async_op=True
)
M, K = list(grad_output.size())
M = M * world_size
N = weight.size(-1)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# gather is scheduled before the input gradient computation
total_grad_output = all_gather_buffer
else:
total_grad_output = grad_output
if not transpose_weight:
weight = weight.t().contiguous()
if ctx.sequence_parallel:
world_size = get_tensor_model_parallel_world_size()
grad_input = torch.empty([M, N], dtype=input.dtype, device=input.device)
sequence_len, batch_size, output_hidden_size = grad_output.size()
input_hidden_size = weight.size(-1)
ag_kernel = flux.AGKernel(
get_tensor_model_parallel_group(),
world_size // torch.cuda.device_count(),
M,
N,
K,
input.dtype,
1, #world_size // torch.cuda.device_count(),
sequence_len * batch_size * world_size,
input_hidden_size,
output_hidden_size,
grad_output.dtype,
output_dtype=input.dtype,
transpose_weight=transpose_weight,
local_copy=False,
ring_mode=flux.AgRingMode.Auto,
)
output = ag_kernel.forward(
grad_output,
weight,
grad_input = ag_kernel.forward(
grad_output.view(sequence_len * batch_size, -1),
weight if transpose_weight else weight.t().contiguous(),
bias=None,
input_scale=None,
weight_scale=None,
@@ -617,24 +626,29 @@ class LinearRS(torch.autograd.Function):
fast_accum=False,
)
torch.distributed.barrier()
torch.cuda.current_stream().synchronize()
grad_input = grad_input.contiguous().view(sequence_len * world_size, batch_size, -1)
else:
grad_input = grad_output.matmul(weight)
if ctx.sequence_parallel and wgrad_compute:
handle.wait()
if wgrad_compute:
grad_output, total_input = prepare_input_tensors_for_wgrad_compute(
grad_output, input
total_grad_output, total_input = prepare_input_tensors_for_wgrad_compute(
total_grad_output, input
)
if ctx.gradient_accumulation_fusion:
if wgrad_compute:
if weight.main_grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
total_input, grad_output, weight.main_grad
total_input, total_grad_output, weight.main_grad
)
elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
total_input, grad_output, weight.main_grad
total_input, total_grad_output, weight.main_grad
)
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
@@ -662,8 +676,8 @@ class LinearRS(torch.autograd.Function):
else:
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
grad_weight = total_grad_output.t().matmul(total_input)
grad_bias = total_grad_output.sum(dim=0) if use_bias else None
return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
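A toy shape check for the non-fused weight and bias gradients computed above; grad_weight = total_grad_output.t() @ total_input matches the [n, k] weight layout used throughout (sizes below are illustrative only):

import torch

s_times_b, n, k = 8, 16, 32                        # toy sizes for illustration
total_grad_output = torch.randn(s_times_b, n)      # gathered over tp when SP is on
total_input = torch.randn(s_times_b, k)
grad_weight = total_grad_output.t().matmul(total_input)    # [n, k]
grad_bias = total_grad_output.sum(dim=0)                    # [n]
assert grad_weight.shape == (n, k) and grad_bias.shape == (n,)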
@@ -952,19 +966,20 @@ class RowParallelLinearPatch(torch.nn.Module):
forward_params = {
"input": input_parallel,
"weight": self.weight,
"bias": None if not self.use_flux or self.skip_bias_add else self.bias,
"bias": self.bias if self.use_flux or not self.skip_bias_add else None,
"gradient_accumulation_fusion": self.gradient_accumulation_fusion,
"allreduce_dgrad": allreduce_dgrad,
"sequence_parallel": False if not self.use_flux else self.sequence_parallel,
"grad_output_buffer": False,
"grad_output_buffer": None,
}
if self.use_flux:
forward_params.update({"transpose_weight": self.flux_transpose_weight})
output_parallel = self._forward_impl(**forward_params)
if self.use_flux:
return output_parallel, None if skip_bias_add else self.bias
return output_parallel, None if not self.skip_bias_add else self.bias
# All-reduce across all the partitions.
if self.explicit_expert_comm:
......