Commit c0d23a67 authored by dongcl

move flux kernels outside

parent 2d1ebf8f
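
This commit caches the flux communication kernels (flux.AGKernel for the forward all-gather GEMM, flux.GemmRS for the backward GEMM reduce-scatter) on the module instead of rebuilding them inside AGLinear.forward/backward on every call; the cached ops are rebuilt only when the shapes or dtype they were built for change, and are passed into ag_linear as fw_ag_gemm_op / bw_gemm_rs_op. A minimal sketch of that caching pattern, with make_kernel standing in for the flux constructors (the names here are illustrative, not from the patch):

class CachedKernel:
    """Rebuild a kernel only when the parameters it was built for change."""

    def __init__(self, make_kernel):
        self._make_kernel = make_kernel  # e.g. a lambda wrapping flux.AGKernel(...)
        self._params = None              # shapes/dtype the cached kernel was built for
        self._kernel = None

    def get(self, *params):
        if self._kernel is None or params != self._params:
            self._kernel = self._make_kernel(*params)
            self._params = params
        return self._kernel
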
......@@ -168,7 +168,12 @@ class CoreAdaptation(MegatronAdaptationABC):
def patch_tensor_parallel(self):
from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
from ..core.tensor_parallel import ColumnParallelLinearPatch, RowParallelLinearPatch, parallel_linear_init_wrapper
from ..core.tensor_parallel import (
ColumnParallelLinearPatch,
RowParallelLinearPatch,
column_parallel_linear_init_wrapper,
row_parallel_linear_init_wrapper
)
# VocabParallelEmbedding
MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
......@@ -198,12 +203,12 @@ class CoreAdaptation(MegatronAdaptationABC):
if HAS_FLUX:
# MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__",
# parallel_linear_init_wrapper,
# column_parallel_linear_init_wrapper,
# apply_wrapper=True)
# MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward",
# ColumnParallelLinearPatch.forward)
MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.__init__",
parallel_linear_init_wrapper,
row_parallel_linear_init_wrapper,
apply_wrapper=True)
MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.forward",
RowParallelLinearPatch.forward)
......
......@@ -5,9 +5,9 @@ from typing import Callable, List, Optional
try:
import flux
HAS_FLUX = True
except ImportError:
from megatron.training import print_rank_0
print_rank_0(f"flux is NOT installed")
HAS_FLUX = False
import torch
import torch.nn.functional as F
......@@ -171,6 +171,8 @@ class AGLinear(torch.autograd.Function):
grad_output_buffer,
wgrad_deferral_limit,
transpose_weight=False,
fw_ag_gemm_op=None,
bw_gemm_rs_op=None,
):
"""Forward."""
ctx.save_for_backward(input, weight)
......@@ -181,61 +183,41 @@ class AGLinear(torch.autograd.Function):
ctx.wgrad_deferral_limit = wgrad_deferral_limit
ctx.grad_output_buffer = grad_output_buffer
ctx.transpose_weight = transpose_weight
ctx.bw_gemm_rs_op = bw_gemm_rs_op
sequence_len, batch_size, input_hidden_size = input.size()
# input: 3D tensor whose order of dimension is [sequence, batch, hidden]
input = input.view(
sequence_len * batch_size, input_hidden_size
)
output_hidden_size = weight.size(0)
if sequence_parallel:
sequence_len, batch_size, input_hidden_size = input.size()
output_hidden_size = weight.size(0)
world_size = get_tensor_model_parallel_world_size()
if transpose_weight:
weight = weight.t().contiguous()
if fw_ag_gemm_op is None:
fw_ag_gemm_op = flux.AGKernel(
get_tensor_model_parallel_group(),
1, # torch.distributed.get_world_size() // torch.cuda.device_count(),
sequence_len * batch_size * world_size,
output_hidden_size,
input_hidden_size,
input.dtype,
output_dtype=input.dtype,
transpose_weight=transpose_weight,
local_copy=False,
ring_mode=flux.AgRingMode.Auto,
)
if sequence_parallel:
sequence_len = sequence_len * get_tensor_model_parallel_world_size()
ag_gemm_kernel = flux.AGKernel(
get_tensor_model_parallel_group(),
1, # torch.distributed.get_world_size() // torch.cuda.device_count(),
sequence_len * batch_size,
output_hidden_size,
input_hidden_size,
input.dtype,
output_dtype=input.dtype,
transpose_weight=transpose_weight,
local_copy=False,
ring_mode=flux.AgRingMode.Auto,
)
output = ag_gemm_kernel.forward(
input,
weight,
output = fw_ag_gemm_op.forward(
input.view(sequence_len * batch_size, -1),
weight.t().contiguous() if transpose_weight else weight,
bias=bias,
input_scale=None,
weight_scale=None,
output_scale=None,
fast_accum=False
)
output = output.view(sequence_len * world_size, batch_size, -1)
else:
output_buf = torch.empty([sequence_len * batch_size, output_hidden_size], dtype=input.dtype, device=input.device)
gemm_only_op = flux.GemmOnly(
input_dtype=input.dtype,
output_dtype=input.dtype,
transpose_weight=transpose_weight,
use_fp8_gemm=False,
)
output = gemm_only_op.forward(
input,
weight,
bias=bias,
output_buf=output_buf,
input_scale=None,
weight_scale=None,
output_scale=None,
fast_accum=False,
)
output = torch.matmul(input, weight.t())
torch.cuda.current_stream().synchronize()
output = output.view(sequence_len, batch_size, -1)
return output
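
For reference, a single-process toy check (no flux, no distributed group) of what the sequence-parallel branch computes: the fused AGKernel amounts to an all-gather of the sequence shards followed by a GEMM against weight.t(), with the same view() reshapes as above. Sizes below are made up for illustration.

import torch

world_size, seq, batch, hidden_in, hidden_out = 2, 3, 2, 8, 16
shards = [torch.randn(seq * batch, hidden_in) for _ in range(world_size)]  # one shard per rank
weight = torch.randn(hidden_out, hidden_in)

gathered = torch.cat(shards, dim=0)        # stands in for the all-gather
output = gathered @ weight.t()             # the GEMM part of the fused kernel
output = output.view(seq * world_size, batch, hidden_out)
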
......@@ -248,6 +230,7 @@ class AGLinear(torch.autograd.Function):
grad_output_buffer = ctx.grad_output_buffer
wgrad_deferral_limit = ctx.wgrad_deferral_limit
transpose_weight = ctx.transpose_weight
bw_gemm_rs_op = ctx.bw_gemm_rs_op
wgrad_compute = True
if grad_output_buffer is not None:
......@@ -275,25 +258,23 @@ class AGLinear(torch.autograd.Function):
total_input = input
if ctx.sequence_parallel:
sequence_len, batch_size, output_hidden_size = grad_output.size()
input_hidden_size = weight.size(-1)
sequence_len, batch_size, _ = grad_output.size()
if bw_gemm_rs_op is None:
input_hidden_size = weight.size(-1)
bw_gemm_rs_op = flux.GemmRS(
get_tensor_model_parallel_group(),
1, # world_size // torch.cuda.device_count(),
sequence_len * batch_size,
input_hidden_size,
input.dtype,
input.dtype,
transpose_weight=transpose_weight,
fuse_reduction=False
)
# input: 3D tensor whose order of dimension is [sequence, batch, hidden]
grad_output = grad_output.view(
sequence_len * batch_size, output_hidden_size
)
gemm_rs_op = flux.GemmRS(
get_tensor_model_parallel_group(),
1, # world_size // torch.cuda.device_count(),
sequence_len * batch_size,
input_hidden_size,
input.dtype,
input.dtype,
transpose_weight=transpose_weight,
fuse_reduction=False
)
grad_input = gemm_rs_op.forward(
grad_output,
grad_input = bw_gemm_rs_op.forward(
grad_output.view(sequence_len * batch_size, -1),
weight if transpose_weight else weight.t().contiguous(),
bias=None,
input_scale=None,
......@@ -310,9 +291,6 @@ class AGLinear(torch.autograd.Function):
if ctx.sequence_parallel and wgrad_compute:
handle.wait()
if ctx.sequence_parallel:
grad_output = grad_output.view(sequence_len, batch_size, output_hidden_size)
if wgrad_compute:
grad_output, total_input = prepare_input_tensors_for_wgrad_compute(
grad_output, total_input
......@@ -323,8 +301,6 @@ class AGLinear(torch.autograd.Function):
handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True
)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# all-reduce is scheduled before the weight gradient computation
if ctx.gradient_accumulation_fusion:
if wgrad_compute:
......@@ -365,10 +341,10 @@ class AGLinear(torch.autograd.Function):
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.allreduce_dgrad:
if not ctx.sequence_parallel and ctx.allreduce_dgrad:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None, None
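
The two trailing Nones added to the return exist because torch.autograd.Function.backward must return one gradient per forward argument, and the prebuilt flux ops are non-tensor arguments that receive None. A toy illustration (unrelated to flux):

import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor, unused_op=None):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # one slot each for x, factor and unused_op; non-tensor args get None
        return grad_out * ctx.factor, None, None

x = torch.ones(2, requires_grad=True)
Scale.apply(x, 2.0, None).sum().backward()
assert torch.equal(x.grad, torch.full((2,), 2.0))
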
def ag_linear(
......@@ -381,6 +357,8 @@ def ag_linear(
grad_output_buffer: Optional[List[torch.Tensor]] = None,
wgrad_deferral_limit: Optional[int] = 0,
transpose_weight: Optional[bool] = False,
fw_ag_gemm_op=None,
bw_gemm_rs_op=None
) -> torch.Tensor:
"""Linear layer execution with asynchronous communication and
gradient accumulation fusion in backprop.
......@@ -442,6 +420,11 @@ def ag_linear(
deferred. Disable by setting this to 0. Defaults to 0.
transpose_weight: transpose weight.
fw_ag_gemm_op: flux AGKernel for forward.
bw_gemm_rs_op: flux GemmRS for backward.
"""
args = [
......@@ -454,6 +437,8 @@ def ag_linear(
grad_output_buffer,
wgrad_deferral_limit,
transpose_weight,
fw_ag_gemm_op,
bw_gemm_rs_op,
]
if not ag_linear.warned:
......@@ -533,8 +518,6 @@ class LinearRS(torch.autograd.Function):
fast_accum=False,
)
output = output.view(sequence_len // world_size, batch_size, -1)
# output = torch.matmul(input, weight.t())
# return _reduce_scatter_along_first_dim(output)
else:
output_buf = torch.empty(
[sequence_len * batch_size, output_hidden_size],
......@@ -791,7 +774,7 @@ def linear_rs(
linear_rs.warned = False
def parallel_linear_init_wrapper(fn):
def column_parallel_linear_init_wrapper(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
fn(self, *args, **kwargs)
......@@ -809,6 +792,13 @@ def parallel_linear_init_wrapper(fn):
elif hasattr(self.config, "flux_transpose_weight"):
self.flux_transpose_weight = self.config.flux_transpose_weight
if self.sequence_parallel:
self.previous_flux_params = (None,) * 5
self.fw_ag_gemm_op = None
self.bw_gemm_rs_op = None
else:
return wrapper
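
Presumably the registrations with apply_wrapper=True shown earlier end up monkey-patching __init__ with these wrappers; a self-contained sketch of that mechanism on a stand-in class (SomeLinear and its attributes are illustrative only):

from functools import wraps

def init_wrapper(fn):
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)
        # start with empty kernel caches, as the patched wrappers do
        self.fw_ag_gemm_op = None
        self.bw_gemm_rs_op = None
    return wrapper

class SomeLinear:
    def __init__(self, in_features, out_features):
        self.in_features, self.out_features = in_features, out_features

SomeLinear.__init__ = init_wrapper(SomeLinear.__init__)
layer = SomeLinear(4, 8)
assert layer.fw_ag_gemm_op is None and layer.bw_gemm_rs_op is None
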
......@@ -884,6 +874,50 @@ class ColumnParallelLinearPatch(torch.nn.Module):
# Matrix multiply.
if self.use_flux:
assert HAS_FLUX, "flux is NOT installed"
sequence_len, batch_size, input_hidden_size = input_parallel.size()
output_hidden_size = weight.size(0)
world_size = get_tensor_model_parallel_world_size()
if self.sequence_parallel:
current_fw_params = (
sequence_len,
batch_size,
input_hidden_size,
output_hidden_size,
input_parallel.dtype
)
if (
self.fw_ag_gemm_op is None
or current_fw_params != self.previous_flux_params
):
self.fw_ag_gemm_op = flux.AGKernel(
get_tensor_model_parallel_group(),
1, # torch.distributed.get_world_size() // torch.cuda.device_count(),
sequence_len * batch_size * world_size,
output_hidden_size,
input_hidden_size,
input_parallel.dtype,
output_dtype=input_parallel.dtype,
transpose_weight=self.flux_transpose_weight,
local_copy=False,
ring_mode=flux.AgRingMode.Auto,
)
self.bw_gemm_rs_op = flux.GemmRS(
get_tensor_model_parallel_group(),
1, # world_size // torch.cuda.device_count(),
sequence_len * batch_size,
input_hidden_size,
input_parallel.dtype,
input_parallel.dtype,
transpose_weight=self.flux_transpose_weight,
fuse_reduction=False
)
self.previous_flux_params = current_fw_params
self._forward_impl = ag_linear
elif not weight.requires_grad:
self._forward_impl = linear_with_frozen_weight
......@@ -903,7 +937,11 @@ class ColumnParallelLinearPatch(torch.nn.Module):
"wgrad_deferral_limit": self.config.wgrad_deferral_limit if self.config.defer_embedding_wgrad_compute else None,
}
if self.use_flux:
forward_params.update({"transpose_weight": self.flux_transpose_weight})
forward_params.update({
"transpose_weight": self.flux_transpose_weight,
"fw_ag_gemm_op": self.fw_ag_gemm_op,
"bw_gemm_rs_op": self.bw_gemm_rs_op,
})
output_parallel = self._forward_impl(**forward_params)
......@@ -922,6 +960,27 @@ class ColumnParallelLinearPatch(torch.nn.Module):
return output, output_bias
def row_parallel_linear_init_wrapper(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
fn(self, *args, **kwargs)
# flux params
self.use_flux = False
if "use_flux" in kwargs:
self.use_flux = kwargs["use_flux"]
elif hasattr(self.config, "use_flux"):
self.use_flux = self.config.use_flux
self.flux_transpose_weight = False
if "flux_transpose_weight" in kwargs:
self.flux_transpose_weight = kwargs["flux_transpose_weight"]
elif hasattr(self.config, "flux_transpose_weight"):
self.flux_transpose_weight = self.config.flux_transpose_weight
return wrapper
class RowParallelLinearPatch(torch.nn.Module):
"""Linear layer with row parallelism.
......