"src/diffusers/schedulers/scheduling_sde_ve_flax.py" did not exist on "5dda1735fda047f4242d28f91e6e457b9760d52d"
Commit c0d23a67 authored by dongcl

move flux kernels outside

parent 2d1ebf8f
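Summary of the change: construction of the flux communication kernels (flux.AGKernel for the all-gather GEMM in the forward pass, flux.GemmRS for the GEMM + reduce-scatter in the backward pass) is moved out of AGLinear.forward/backward and into the patched linear modules, which build the kernels once, cache them keyed on the activation shape and dtype, and pass them into ag_linear through the new fw_ag_gemm_op / bw_gemm_rs_op arguments. A minimal, self-contained sketch of that caching idea, with hypothetical names and no flux dependency:

# Sketch only: stand-in for the caching done in ColumnParallelLinearPatch, which now
# builds flux.AGKernel / flux.GemmRS once and reuses them across training steps.
import torch

class CachedKernelHolder:
    def __init__(self):
        self.previous_flux_params = None   # shape/dtype key of the cached kernels
        self.fw_ag_gemm_op = None          # forward all-gather GEMM kernel
        self.bw_gemm_rs_op = None          # backward GEMM reduce-scatter kernel

    def _build_kernels(self, key):
        # Stand-in for flux.AGKernel(...) / flux.GemmRS(...); constructing these is
        # expensive, which is why the commit hoists it out of AGLinear.forward().
        return ("ag_kernel", key), ("gemm_rs_kernel", key)

    def get_kernels(self, activation, weight):
        seq, batch, hidden = activation.size()
        key = (seq, batch, hidden, weight.size(0), activation.dtype)
        if self.fw_ag_gemm_op is None or key != self.previous_flux_params:
            self.fw_ag_gemm_op, self.bw_gemm_rs_op = self._build_kernels(key)
            self.previous_flux_params = key
        return self.fw_ag_gemm_op, self.bw_gemm_rs_op

holder = CachedKernelHolder()
x = torch.randn(8, 2, 16)   # [sequence, batch, hidden]
w = torch.randn(32, 16)     # [output_hidden, hidden]
fw_op, bw_op = holder.get_kernels(x, w)   # builds on first call
fw_op2, _ = holder.get_kernels(x, w)      # reuses the cached kernels
assert fw_op is fw_op2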
@@ -168,7 +168,12 @@ class CoreAdaptation(MegatronAdaptationABC):
     def patch_tensor_parallel(self):
         from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
         from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
-        from ..core.tensor_parallel import ColumnParallelLinearPatch, RowParallelLinearPatch, parallel_linear_init_wrapper
+        from ..core.tensor_parallel import (
+            ColumnParallelLinearPatch,
+            RowParallelLinearPatch,
+            column_parallel_linear_init_wrapper,
+            row_parallel_linear_init_wrapper
+        )

         # VocabParallelEmbedding
         MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
@@ -198,12 +203,12 @@ class CoreAdaptation(MegatronAdaptationABC):
         if HAS_FLUX:
             # MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__",
-            #                             parallel_linear_init_wrapper,
+            #                             column_parallel_linear_init_wrapper,
             #                             apply_wrapper=True)
             # MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward",
             #                             ColumnParallelLinearPatch.forward)
             MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.__init__",
-                                        parallel_linear_init_wrapper,
+                                        row_parallel_linear_init_wrapper,
                                         apply_wrapper=True)
             MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.forward",
                                         RowParallelLinearPatch.forward)
@@ -5,9 +5,9 @@ from typing import Callable, List, Optional
 try:
     import flux
+    HAS_FLUX = True
 except ImportError:
-    from megatron.training import print_rank_0
-    print_rank_0(f"flux is NOT installed")
+    HAS_FLUX = False

 import torch
 import torch.nn.functional as F
@@ -171,6 +171,8 @@ class AGLinear(torch.autograd.Function):
         grad_output_buffer,
         wgrad_deferral_limit,
         transpose_weight=False,
+        fw_ag_gemm_op=None,
+        bw_gemm_rs_op=None,
     ):
         """Forward."""
         ctx.save_for_backward(input, weight)
@@ -181,23 +183,18 @@ class AGLinear(torch.autograd.Function):
         ctx.wgrad_deferral_limit = wgrad_deferral_limit
         ctx.grad_output_buffer = grad_output_buffer
         ctx.transpose_weight = transpose_weight
+        ctx.bw_gemm_rs_op = bw_gemm_rs_op

+        if sequence_parallel:
             sequence_len, batch_size, input_hidden_size = input.size()
-        # input: 3D tensor whose order of dimension is [sequence, batch, hidden]
-        input = input.view(
-            sequence_len * batch_size, input_hidden_size
-        )
             output_hidden_size = weight.size(0)
+            world_size = get_tensor_model_parallel_world_size()

-        if transpose_weight:
-            weight = weight.t().contiguous()
-
-        if sequence_parallel:
-            sequence_len = sequence_len * get_tensor_model_parallel_world_size()
-            ag_gemm_kernel = flux.AGKernel(
+            if fw_ag_gemm_op is None:
+                fw_ag_gemm_op = flux.AGKernel(
                     get_tensor_model_parallel_group(),
                     1,  # torch.distributed.get_world_size() // torch.cuda.device_count(),
-                    sequence_len * batch_size,
+                    sequence_len * batch_size * world_size,
                     output_hidden_size,
                     input_hidden_size,
                     input.dtype,
@@ -206,36 +203,21 @@ class AGLinear(torch.autograd.Function):
                     local_copy=False,
                     ring_mode=flux.AgRingMode.Auto,
                 )
-            output = ag_gemm_kernel.forward(
-                input,
-                weight,
+
+            output = fw_ag_gemm_op.forward(
+                input.view(sequence_len * batch_size, -1),
+                weight.t().contiguous() if transpose_weight else weight,
                 bias=bias,
                 input_scale=None,
                 weight_scale=None,
                 output_scale=None,
                 fast_accum=False
             )
+            output = output.view(sequence_len * world_size, batch_size, -1)
         else:
-            output_buf = torch.empty([sequence_len * batch_size, output_hidden_size], dtype=input.dtype, device=input.device)
-            gemm_only_op = flux.GemmOnly(
-                input_dtype=input.dtype,
-                output_dtype=input.dtype,
-                transpose_weight=transpose_weight,
-                use_fp8_gemm=False,
-            )
-            output = gemm_only_op.forward(
-                input,
-                weight,
-                bias=bias,
-                output_buf=output_buf,
-                input_scale=None,
-                weight_scale=None,
-                output_scale=None,
-                fast_accum=False,
-            )
+            output = torch.matmul(input, weight.t())

         torch.cuda.current_stream().synchronize()
-        output = output.view(sequence_len, batch_size, -1)

         return output
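Note that the M dimension handed to flux.AGKernel is unchanged numerically: the old code first rescaled sequence_len by the tensor-parallel world size and then passed sequence_len * batch_size, while the new code keeps the local (sequence-parallel) sequence_len and passes sequence_len * batch_size * world_size directly, since the all-gather over sequence-parallel ranks multiplies the effective sequence length by the TP world size. A tiny sanity check of that arithmetic, with hypothetical sizes:

# Hypothetical sizes; just checks the two ways of computing the AGKernel M dimension agree.
local_seq_len, batch_size, world_size = 1024, 4, 8

# old: rescale sequence_len first, then multiply by batch
old_m = (local_seq_len * world_size) * batch_size
# new: keep the local length and fold world_size into the product
new_m = local_seq_len * batch_size * world_size

assert old_m == new_m == 32768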
@@ -248,6 +230,7 @@ class AGLinear(torch.autograd.Function):
         grad_output_buffer = ctx.grad_output_buffer
         wgrad_deferral_limit = ctx.wgrad_deferral_limit
         transpose_weight = ctx.transpose_weight
+        bw_gemm_rs_op = ctx.bw_gemm_rs_op

         wgrad_compute = True
         if grad_output_buffer is not None:
@@ -275,14 +258,11 @@ class AGLinear(torch.autograd.Function):
         total_input = input
         if ctx.sequence_parallel:
-            sequence_len, batch_size, output_hidden_size = grad_output.size()
-            input_hidden_size = weight.size(-1)
+            sequence_len, batch_size, _ = grad_output.size()

-            # input: 3D tensor whose order of dimension is [sequence, batch, hidden]
-            grad_output = grad_output.view(
-                sequence_len * batch_size, output_hidden_size
-            )
-            gemm_rs_op = flux.GemmRS(
+            if bw_gemm_rs_op is None:
+                input_hidden_size = weight.size(-1)
+                bw_gemm_rs_op = flux.GemmRS(
                     get_tensor_model_parallel_group(),
                     1,  # world_size // torch.cuda.device_count(),
                     sequence_len * batch_size,
@@ -292,8 +272,9 @@ class AGLinear(torch.autograd.Function):
                     transpose_weight=transpose_weight,
                     fuse_reduction=False
                 )
-            grad_input = gemm_rs_op.forward(
-                grad_output,
+
+            grad_input = bw_gemm_rs_op.forward(
+                grad_output.view(sequence_len * batch_size, -1),
                 weight if transpose_weight else weight.t().contiguous(),
                 bias=None,
                 input_scale=None,
@@ -310,9 +291,6 @@ class AGLinear(torch.autograd.Function):
         if ctx.sequence_parallel and wgrad_compute:
             handle.wait()

-        if ctx.sequence_parallel:
-            grad_output = grad_output.view(sequence_len, batch_size, output_hidden_size)
-
         if wgrad_compute:
             grad_output, total_input = prepare_input_tensors_for_wgrad_compute(
                 grad_output, total_input
@@ -323,8 +301,6 @@ class AGLinear(torch.autograd.Function):
             handle = torch.distributed.all_reduce(
                 grad_input, group=get_tensor_model_parallel_group(), async_op=True
             )
-            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
-            # all-reduce is scheduled before the weight gradient computation

         if ctx.gradient_accumulation_fusion:
             if wgrad_compute:
@@ -365,10 +341,10 @@ class AGLinear(torch.autograd.Function):
             grad_weight = grad_output.t().matmul(total_input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None

-        if ctx.allreduce_dgrad:
+        if not ctx.sequence_parallel and ctx.allreduce_dgrad:
             handle.wait()

-        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None, None
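The two extra trailing None values in backward correspond to the two new forward arguments: torch.autograd.Function.backward must return one gradient per forward input, and non-tensor inputs such as the pre-built kernels simply receive None. A minimal illustration of that arity rule (unrelated to flux, hypothetical ScaleByConst function):

import torch

class ScaleByConst(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale, unused_op=None):   # three forward inputs
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        # one return value per forward input; None for the non-tensor arguments
        return grad_out * ctx.scale, None, None

x = torch.randn(3, requires_grad=True)
y = ScaleByConst.apply(x, 2.0)
y.sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 2.0))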
@@ -381,6 +357,8 @@ def ag_linear(
     grad_output_buffer: Optional[List[torch.Tensor]] = None,
     wgrad_deferral_limit: Optional[int] = 0,
     transpose_weight: Optional[bool] = False,
+    fw_ag_gemm_op=None,
+    bw_gemm_rs_op=None
 ) -> torch.Tensor:
     """Linear layer execution with asynchronous communication and
     gradient accumulation fusion in backprop.
@@ -442,6 +420,11 @@ def ag_linear(
             deferred. Disable by setting this to 0. Defaults to 0.

         transpose_weight: transpose weight.
+
+        fw_ag_gemm_op: flux AGKernel for forward.
+
+        bw_gemm_rs_op: flux GemmRS for backward.
     """
     args = [
@@ -454,6 +437,8 @@ def ag_linear(
         grad_output_buffer,
         wgrad_deferral_limit,
         transpose_weight,
+        fw_ag_gemm_op,
+        bw_gemm_rs_op,
     ]

     if not ag_linear.warned:
@@ -533,8 +518,6 @@ class LinearRS(torch.autograd.Function):
                 fast_accum=False,
             )
             output = output.view(sequence_len // world_size, batch_size, -1)
-            # output = torch.matmul(input, weight.t())
-            # return _reduce_scatter_along_first_dim(output)
         else:
             output_buf = torch.empty(
                 [sequence_len * batch_size, output_hidden_size],
@@ -791,7 +774,7 @@ def linear_rs(
     linear_rs.warned = False
-def parallel_linear_init_wrapper(fn):
+def column_parallel_linear_init_wrapper(fn):
     @wraps(fn)
     def wrapper(self, *args, **kwargs):
         fn(self, *args, **kwargs)
@@ -809,6 +792,13 @@ def parallel_linear_init_wrapper(fn):
         elif hasattr(self.config, "flux_transpose_weight"):
             self.flux_transpose_weight = self.config.flux_transpose_weight

+        if self.sequence_parallel:
+            self.previous_flux_params = (None,) * 5
+            self.fw_ag_gemm_op = None
+            self.bw_gemm_rs_op = None
+        else:

     return wrapper
@@ -884,6 +874,50 @@ class ColumnParallelLinearPatch(torch.nn.Module):
         # Matrix multiply.
         if self.use_flux:
+            assert HAS_FLUX, "flux is NOT installed"
+            sequence_len, batch_size, input_hidden_size = input_parallel.size()
+            output_hidden_size = weight.size(0)
+            world_size = get_tensor_model_parallel_world_size()
+
+            if self.sequence_parallel:
+                current_flux_params = (
+                    sequence_len,
+                    batch_size,
+                    input_hidden_size,
+                    output_hidden_size,
+                    input_parallel.dtype
+                )
+                if (
+                    self.fw_ag_gemm_op is None
+                    or current_flux_params != self.previous_flux_params
+                ):
+                    self.fw_ag_gemm_op = flux.AGKernel(
+                        get_tensor_model_parallel_group(),
+                        1,  # torch.distributed.get_world_size() // torch.cuda.device_count(),
+                        sequence_len * batch_size * world_size,
+                        output_hidden_size,
+                        input_hidden_size,
+                        input_parallel.dtype,
+                        output_dtype=input_parallel.dtype,
+                        transpose_weight=self.flux_transpose_weight,
+                        local_copy=False,
+                        ring_mode=flux.AgRingMode.Auto,
+                    )
+                    self.bw_gemm_rs_op = flux.GemmRS(
+                        get_tensor_model_parallel_group(),
+                        1,  # world_size // torch.cuda.device_count(),
+                        sequence_len * batch_size,
+                        input_hidden_size,
+                        input_parallel.dtype,
+                        input_parallel.dtype,
+                        transpose_weight=self.flux_transpose_weight,
+                        fuse_reduction=False
+                    )
+                    self.previous_flux_params = current_flux_params
+
             self._forward_impl = ag_linear
         elif not weight.requires_grad:
             self._forward_impl = linear_with_frozen_weight
@@ -903,7 +937,11 @@ class ColumnParallelLinearPatch(torch.nn.Module):
             "wgrad_deferral_limit": self.config.wgrad_deferral_limit if self.config.defer_embedding_wgrad_compute else None,
         }
         if self.use_flux:
-            forward_params.update({"transpose_weight": self.flux_transpose_weight})
+            forward_params.update({
+                "transpose_weight": self.flux_transpose_weight,
+                "fw_ag_gemm_op": self.fw_ag_gemm_op,
+                "bw_gemm_rs_op": self.bw_gemm_rs_op,
+            })

         output_parallel = self._forward_impl(**forward_params)
@@ -922,6 +960,27 @@ class ColumnParallelLinearPatch(torch.nn.Module):
         return output, output_bias


+def row_parallel_linear_init_wrapper(fn):
+    @wraps(fn)
+    def wrapper(self, *args, **kwargs):
+        fn(self, *args, **kwargs)
+
+        # flux params
+        self.use_flux = False
+        if "use_flux" in kwargs:
+            self.use_flux = kwargs["use_flux"]
+        elif hasattr(self.config, "use_flux"):
+            self.use_flux = self.config.use_flux
+
+        self.flux_transpose_weight = False
+        if "flux_transpose_weight" in kwargs:
+            self.flux_transpose_weight = kwargs["flux_transpose_weight"]
+        elif hasattr(self.config, "flux_transpose_weight"):
+            self.flux_transpose_weight = self.config.flux_transpose_weight
+
+    return wrapper
+
+
 class RowParallelLinearPatch(torch.nn.Module):
     """Linear layer with row parallelism.