"vscode:/vscode.git/clone" did not exist on "e5bfd387b55fe92640820c9ca62fea9bcafa2395"
Commit f25c421e authored by dongcl's avatar dongcl
Browse files

fix flux bug

Correct shape and weight-transpose handling in the flux-backed AGLinear and LinearRS paths, all-gather grad_output for the LinearRS weight gradient, and disable the flux ColumnParallelLinear patches.

parent b6eb1484
@@ -143,9 +143,9 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
                                     torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
                                     apply_wrapper=True)
-        MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
-                                    torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True}),
-                                    apply_wrapper=True)
+        # MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
+        #                             torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True}),
+        #                             apply_wrapper=True)
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute',
                                     torch.compile(mode='max-autotune-no-cudagraphs'),
                                     apply_wrapper=True)
@@ -197,18 +197,17 @@ class CoreAdaptation(MegatronAdaptationABC):
             HAS_FLUX = False
 
         if HAS_FLUX:
-            MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__",
-                                        parallel_linear_init_wrapper,
-                                        apply_wrapper=True)
-            MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward",
-                                        ColumnParallelLinearPatch.forward)
+            # MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__",
+            #                             parallel_linear_init_wrapper,
+            #                             apply_wrapper=True)
+            # MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward",
+            #                             ColumnParallelLinearPatch.forward)
             MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.__init__",
                                         parallel_linear_init_wrapper,
                                         apply_wrapper=True)
             MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.forward",
                                         RowParallelLinearPatch.forward)
 
     def patch_training(self):
         from ..training.tokenizer import build_tokenizer
         from ..training.initialize import _initialize_distributed
...
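For context on the registrations above, here is a minimal sketch of the wrapper pattern they rely on, assuming `MegatronAdaptation.register(path, wrapper, apply_wrapper=True)` replaces the callable at `path` with `wrapper(original)`; the `topk_softmax_with_capacity` body below is a simplified stand-in, not Megatron's implementation:

import torch

def topk_softmax_with_capacity(logits: torch.Tensor, k: int = 2):
    # Simplified stand-in for megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity.
    scores = torch.softmax(logits, dim=-1)
    return torch.topk(scores, k, dim=-1)

# torch.compile called with only keyword arguments returns a decorator;
# that decorator is what gets registered as the wrapper above.
wrapper = torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False})
compiled_fn = wrapper(topk_softmax_with_capacity)

device = "cuda" if torch.cuda.is_available() else "cpu"
values, indices = compiled_fn(torch.randn(4, 8, device=device))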
@@ -13,8 +13,11 @@ import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
+from megatron.training import print_rank_0
 from megatron.core.model_parallel_config import ModelParallelConfig
 from megatron.core.parallel_state import (
+    get_global_memory_buffer,
+    get_tensor_model_parallel_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
@@ -31,12 +34,15 @@ from megatron.core.tensor_parallel.mappings import (
     copy_to_tensor_model_parallel_region,
     reduce_from_tensor_model_parallel_region,
     reduce_scatter_to_sequence_parallel_region,
+    _reduce_scatter_along_first_dim,
+    _gather_along_first_dim,
 )
 from megatron.core.tensor_parallel.utils import VocabUtility
 from megatron.core.tensor_parallel.mappings import _reduce
 from megatron.core.tensor_parallel.layers import (
     custom_fwd,
     custom_bwd,
+    dist_all_gather_func,
     linear_with_frozen_weight,
     linear_with_grad_accumulation_and_async_allreduce
 )
@@ -176,26 +182,24 @@ class AGLinear(torch.autograd.Function):
         ctx.grad_output_buffer = grad_output_buffer
         ctx.transpose_weight = transpose_weight
 
-        sequence_len = input.size(0)
+        sequence_len, batch_size, input_hidden_size = input.size()
         # input: 3D tensor whose order of dimension is [sequence, batch, hidden]
         input = input.view(
-            input.shape[0] * input.shape[1], input.shape[2]
+            sequence_len * batch_size, input_hidden_size
         )
-
-        M, K = list(input.size())
-        N = weight.size(0)
-        M = M * get_tensor_model_parallel_world_size()
+        output_hidden_size = weight.size(0)
 
         if transpose_weight:
             weight = weight.t().contiguous()
 
         if sequence_parallel:
+            sequence_len = sequence_len * get_tensor_model_parallel_world_size()
             ag_gemm_kernel = flux.AGKernel(
                 get_tensor_model_parallel_group(),
-                get_tensor_model_parallel_world_size() // torch.cuda.device_count(),
-                M,
-                N,
-                K,
+                1, # torch.distributed.get_world_size() // torch.cuda.device_count(),
+                sequence_len * batch_size,
+                output_hidden_size,
+                input_hidden_size,
                 input.dtype,
                 output_dtype=input.dtype,
                 transpose_weight=transpose_weight,
@@ -206,13 +210,13 @@ class AGLinear(torch.autograd.Function):
                 input,
                 weight,
                 bias=bias,
-                input_scale=input_scale,
-                weight_scale=weight_scale,
+                input_scale=None,
+                weight_scale=None,
                 output_scale=None,
                 fast_accum=False
             )
         else:
-            output_buf = torch.empty([M, N], dtype=input.dtype, device=input.device)
+            output_buf = torch.empty([sequence_len * batch_size, output_hidden_size], dtype=input.dtype, device=input.device)
             gemm_only_op = flux.GemmOnly(
                 input_dtype=input.dtype,
                 output_dtype=input.dtype,
@@ -231,7 +235,7 @@ class AGLinear(torch.autograd.Function):
             )
 
         torch.cuda.current_stream().synchronize()
-        output = output.view(sequence_len, input.size(0) // sequence_len, -1)
+        output = output.view(sequence_len, batch_size, -1)
 
         return output
@@ -272,20 +276,17 @@ class AGLinear(torch.autograd.Function):
         if ctx.sequence_parallel:
             sequence_len, batch_size, output_hidden_size = grad_output.size()
+            input_hidden_size = weight.size(-1)
             # input: 3D tensor whose order of dimension is [sequence, batch, hidden]
             grad_output = grad_output.view(
                 sequence_len * batch_size, output_hidden_size
             )
-            if not transpose_weight:
-                weight = weight.t().contiguous()
             gemm_rs_op = flux.GemmRS(
                 get_tensor_model_parallel_group(),
-                world_size // torch.cuda.device_count(),
+                1, # world_size // torch.cuda.device_count(),
                 sequence_len * batch_size,
-                output_hidden_size,
+                input_hidden_size,
                 input.dtype,
                 input.dtype,
                 transpose_weight=transpose_weight,
@@ -293,7 +294,7 @@ class AGLinear(torch.autograd.Function):
             )
             grad_input = gemm_rs_op.forward(
                 grad_output,
-                weight,
+                weight if transpose_weight else weight.t().contiguous(),
                 bias=None,
                 input_scale=None,
                 weight_scale=None,
@@ -302,13 +303,16 @@ class AGLinear(torch.autograd.Function):
             )
             torch.cuda.current_stream().synchronize()
-            grad_input = grad_input.view(sequence_len // get_tensor_model_parallel_group(), batch_size, -1)
+            grad_input = grad_input.view(sequence_len // get_tensor_model_parallel_world_size(), batch_size, -1)
         else:
             grad_input = grad_output.matmul(weight)
 
         if ctx.sequence_parallel and wgrad_compute:
             handle.wait()
 
+        if ctx.sequence_parallel:
+            grad_output = grad_output.view(sequence_len, batch_size, output_hidden_size)
+
         if wgrad_compute:
             grad_output, total_input = prepare_input_tensors_for_wgrad_compute(
                 grad_output, total_input
@@ -503,25 +507,17 @@ class LinearRS(torch.autograd.Function):
         world_size = get_tensor_model_parallel_world_size()
 
-        input_dim = input.dim()
-        sequence_len = input.size(0)
+        sequence_len, batch_size, _ = input.size()
         # input: 3D tensor whose order of dimension is [sequence, batch, hidden]
-        input = input.view(
-            input.shape[0] * input.shape[1], input.shape[2]
-        )
-
-        M = input.size(0)
-        N = weight.size(0)
+        input = input.view(sequence_len * batch_size, -1)
+        output_hidden_size = weight.size(0)
 
         if sequence_parallel:
-            if transpose_weight:
-                weight = weight.t().contiguous()
             gemm_rs_op = flux.GemmRS(
                 get_tensor_model_parallel_group(),
-                world_size // torch.cuda.device_count(),
-                M,
-                N,
+                1, #world_size // torch.cuda.device_count(),
+                sequence_len * batch_size,
+                output_hidden_size,
                 input.dtype,
                 input.dtype,
                 transpose_weight=transpose_weight,
@@ -529,15 +525,23 @@ class LinearRS(torch.autograd.Function):
             )
             output = gemm_rs_op.forward(
                 input,
-                weight,
+                weight.t().contiguous() if transpose_weight else weight,
                 bias=bias,
                 input_scale=None,
                 weight_scale=None,
                 output_scale=None,
                 fast_accum=False,
             )
+            output = output.view(sequence_len // world_size, batch_size, -1)
+            # output = torch.matmul(input, weight.t())
+            # return _reduce_scatter_along_first_dim(output)
         else:
-            output = torch.empty([M, N], dtype=input.dtype, device=input.device)
+            output_buf = torch.empty(
+                [sequence_len * batch_size, output_hidden_size],
+                dtype=input.dtype,
+                device=input.device,
+                requires_grad=False
+            )
             gemm_only_op = flux.GemmOnly(
                 input_dtype=input.dtype,
                 output_dtype=input.dtype,
@@ -546,21 +550,18 @@ class LinearRS(torch.autograd.Function):
             )
             output = gemm_only_op.forward(
                 input,
-                weight,
+                weight.t().contiguous() if transpose_weight else weight,
                 bias=bias,
-                output_buf=output,
+                output_buf=output_buf,
                 input_scale=None,
                 weight_scale=None,
                 output_scale=None,
                 fast_accum=False,
             )
+            output = output.view(sequence_len, batch_size, -1)
+            output = _reduce(output)
 
-        torch.cuda.current_stream().synchronize()
-        output = output.view(sequence_len, input.size(0) // sequence_len, -1)
-        if not sequence_parallel:
-            _reduce(output)
+        # torch.cuda.current_stream().synchronize()
 
         return output
 
     @staticmethod
@@ -579,37 +580,45 @@ class LinearRS(torch.autograd.Function):
             grad_output_buffer.append(grad_output)
             wgrad_compute = False
 
-        if ctx.sequence_parallel:
-            world_size = get_tensor_model_parallel_world_size()
-
-            sequence_len, batch_size, _ = grad_output.size()
-            grad_output = grad_output.view(sequence_len * batch_size, -1)
-
-            M, K = list(grad_output.size())
-            M = M * world_size
-            N = weight.size(-1)
-
-            if not transpose_weight:
-                weight = weight.t().contiguous()
-            grad_input = torch.empty([M, N], dtype=input.dtype, device=input.device)
+        if wgrad_compute:
+            if ctx.sequence_parallel:
+                world_size = get_tensor_model_parallel_world_size()
+                dim_size = list(grad_output.size())
+                dim_size[0] = dim_size[0] * world_size
+
+                all_gather_buffer = get_global_memory_buffer().get_tensor(
+                    dim_size, grad_output.dtype, "mpu"
+                )
+                handle = dist_all_gather_func(
+                    all_gather_buffer, grad_output, group=get_tensor_model_parallel_group(), async_op=True
+                )
+
+                # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
+                # gather is scheduled before the input gradient computation
+                total_grad_output = all_gather_buffer
+            else:
+                total_grad_output = grad_output
+
+        if ctx.sequence_parallel:
+            world_size = get_tensor_model_parallel_world_size()
+            sequence_len, batch_size, output_hidden_size = grad_output.size()
+            input_hidden_size = weight.size(-1)
 
             ag_kernel = flux.AGKernel(
                 get_tensor_model_parallel_group(),
-                world_size // torch.cuda.device_count(),
-                M,
-                N,
-                K,
-                input.dtype,
+                1, #world_size // torch.cuda.device_count(),
+                sequence_len * batch_size * world_size,
+                input_hidden_size,
+                output_hidden_size,
+                grad_output.dtype,
                 output_dtype=input.dtype,
                 transpose_weight=transpose_weight,
                 local_copy=False,
                 ring_mode=flux.AgRingMode.Auto,
             )
-
-            output = ag_kernel.forward(
-                grad_output,
-                weight,
+            grad_input = ag_kernel.forward(
+                grad_output.view(sequence_len * batch_size, -1),
+                weight if transpose_weight else weight.t().contiguous(),
                 bias=None,
                 input_scale=None,
                 weight_scale=None,
@@ -617,24 +626,29 @@ class LinearRS(torch.autograd.Function):
                 fast_accum=False,
             )
+            torch.distributed.barrier()
             torch.cuda.current_stream().synchronize()
+            grad_input = grad_input.contiguous().view(sequence_len * world_size, batch_size, -1)
         else:
             grad_input = grad_output.matmul(weight)
 
+        if ctx.sequence_parallel and wgrad_compute:
+            handle.wait()
+
         if wgrad_compute:
-            grad_output, total_input = prepare_input_tensors_for_wgrad_compute(
-                grad_output, input
+            total_grad_output, total_input = prepare_input_tensors_for_wgrad_compute(
+                total_grad_output, input
             )
 
         if ctx.gradient_accumulation_fusion:
             if wgrad_compute:
                 if weight.main_grad.dtype == torch.float32:
                     fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
-                        total_input, grad_output, weight.main_grad
+                        total_input, total_grad_output, weight.main_grad
                     )
                 elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
                     fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
-                        total_input, grad_output, weight.main_grad
+                        total_input, total_grad_output, weight.main_grad
                     )
                 else:
                     raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
@@ -662,8 +676,8 @@ class LinearRS(torch.autograd.Function):
             else:
                 grad_weight = None
         else:
-            grad_weight = grad_output.t().matmul(total_input)
-        grad_bias = grad_output.sum(dim=0) if use_bias else None
+            grad_weight = total_grad_output.t().matmul(total_input)
+        grad_bias = total_grad_output.sum(dim=0) if use_bias else None
 
         return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
@@ -952,19 +966,20 @@ class RowParallelLinearPatch(torch.nn.Module):
         forward_params = {
             "input": input_parallel,
             "weight": self.weight,
-            "bias": None if not self.use_flux or self.skip_bias_add else self.bias,
+            "bias": self.bias if self.use_flux or not self.skip_bias_add else None,
             "gradient_accumulation_fusion": self.gradient_accumulation_fusion,
             "allreduce_dgrad": allreduce_dgrad,
             "sequence_parallel": False if not self.use_flux else self.sequence_parallel,
-            "grad_output_buffer": False,
+            "grad_output_buffer": None,
         }
 
         if self.use_flux:
             forward_params.update({"transpose_weight": self.flux_transpose_weight})
 
         output_parallel = self._forward_impl(**forward_params)
 
         if self.use_flux:
-            return output_parallel, None if skip_bias_add else self.bias
+            return output_parallel, None if not self.skip_bias_add else self.bias
 
         # All-reduce across all the partitions.
         if self.explicit_expert_comm:
...
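As a cross-check of the shape handling changed above, here is a plain-PyTorch sketch of what the sequence-parallel LinearRS.forward path computes via flux.GemmRS, mirroring the commented-out reference in the diff (torch.matmul followed by _reduce_scatter_along_first_dim). The function name and the explicit process-group argument are illustrative only; torch.distributed is assumed to be initialized.

import torch
import torch.distributed as dist

def linear_reduce_scatter_reference(inp: torch.Tensor,
                                    weight: torch.Tensor,
                                    group: dist.ProcessGroup) -> torch.Tensor:
    # inp:    [sequence, batch, hidden_in], replicated input of the row-parallel GEMM
    # weight: [hidden_out, hidden_in] local weight shard
    world_size = dist.get_world_size(group=group)
    sequence_len, batch_size, _ = inp.size()

    # Local GEMM: [S*B, H_in] x [H_in, H_out] -> [S*B, H_out]
    partial = torch.matmul(inp.view(sequence_len * batch_size, -1), weight.t())
    partial = partial.view(sequence_len, batch_size, -1)

    # Reduce-scatter along the sequence dimension across the tensor-parallel
    # group, which is what the fused GemmRS kernel performs in one shot.
    out = torch.empty(
        (sequence_len // world_size, batch_size, partial.size(-1)),
        dtype=partial.dtype, device=partial.device,
    )
    dist.reduce_scatter_tensor(out, partial.contiguous(), group=group)
    return out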