Commit 595e428a authored by dongcl

use flux gemmrs

parent 0867fd90
......@@ -187,7 +187,7 @@ class CoreAdaptation(MegatronAdaptationABC):
apply_wrapper=True)
# flux
if os.getenv("USE_FLUX_OVERLAP", 0):
if int(os.getenv("USE_FLUX_OVERLAP", "0")):
from ..core.tensor_parallel import (
FluxColumnParallelLinear,
FluxRowParallelLinear
......
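The int(...) wrapper matters because os.getenv returns the variable's raw string whenever it is set, so exporting USE_FLUX_OVERLAP=0 previously still enabled the flux path (any non-empty string is truthy). A minimal sketch of the difference:

import os

os.environ["USE_FLUX_OVERLAP"] = "0"
# Old check: the string "0" is truthy, so the flux branch is taken anyway.
print(bool(os.getenv("USE_FLUX_OVERLAP", 0)))          # True
# New check: parse the value as an integer before testing it.
print(bool(int(os.getenv("USE_FLUX_OVERLAP", "0"))))   # False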
......@@ -16,6 +16,7 @@ from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.extensions.transformer_engine import TEColumnParallelLinear
from dcu_megatron.core.utils import tensor_slide
from dcu_megatron.core.transformer.mtp.multi_token_predictor import MultiTokenPredictor
......@@ -40,7 +41,7 @@ def gpt_model_init_wrapper(fn):
self.mtp_layers = torch.nn.ModuleList(
[
MultiTokenPredictor(
config,
self.config,
self.mtp_spec.submodules,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
......
import os
import socket
import warnings
from functools import wraps
from typing import Callable, List, Optional
......@@ -160,6 +161,19 @@ def vocab_parallel_embedding_forward(self, input_, weight=None):
return output
def get_tensor_model_parallel_node_size(group=None):
""" 获取节点数
"""
if group is None:
group = get_tensor_model_parallel_group()
hostname = socket.gethostname()
hostnames = [None] * get_tensor_model_parallel_world_size()
torch.distributed.all_gather_object(hostnames, hostname, group=group)
num_nodes = len(set(hostnames))
return num_nodes
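get_tensor_model_parallel_node_size gathers every rank's hostname across the tensor parallel group and counts the distinct hosts; the result replaces the hard-coded 1 previously passed to flux.AGKernel and flux.GemmRS as the node count. A minimal, non-distributed sketch of the counting step (the hostnames are hypothetical):

# Hypothetical hostnames gathered from an 8-way TP group spanning two hosts.
hostnames = ["node-a"] * 4 + ["node-b"] * 4
num_nodes = len(set(hostnames))  # deduplicate host names
print(num_nodes)                 # 2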
class AGLinear(torch.autograd.Function):
@staticmethod
@custom_fwd
......@@ -196,7 +210,7 @@ class AGLinear(torch.autograd.Function):
if fw_ag_gemm_op is None:
fw_ag_gemm_op = flux.AGKernel(
get_tensor_model_parallel_group(),
1, # torch.distributed.get_world_size() // torch.cuda.device_count(),
get_tensor_model_parallel_node_size(),
sequence_len * batch_size * world_size,
output_hidden_size,
input_hidden_size,
......@@ -265,34 +279,31 @@ class AGLinear(torch.autograd.Function):
if ctx.sequence_parallel:
sequence_len, batch_size, _ = grad_output.size()
# if bw_gemm_rs_op is None:
# input_hidden_size = weight.size(-1)
# bw_gemm_rs_op = flux.GemmRS(
# get_tensor_model_parallel_group(),
# 1, # world_size // torch.cuda.device_count(),
# sequence_len * batch_size,
# input_hidden_size,
# input.dtype,
# input.dtype,
# transpose_weight=transpose_weight,
# fuse_reduction=False
# )
# grad_input = bw_gemm_rs_op.forward(
# grad_output.view(sequence_len * batch_size, -1),
# weight if transpose_weight else weight.t().contiguous(),
# bias=None,
# input_scale=None,
# weight_scale=None,
# output_scale=None,
# fast_accum=False
# )
# torch.distributed.barrier()
# torch.cuda.current_stream().synchronize()
# grad_input = grad_input.view(sequence_len // world_size, batch_size, -1)
grad_input = grad_output.matmul(weight)
grad_input = _reduce_scatter_along_first_dim(grad_input)
if bw_gemm_rs_op is None:
input_hidden_size = weight.size(-1)
bw_gemm_rs_op = flux.GemmRS(
get_tensor_model_parallel_group(),
get_tensor_model_parallel_node_size(),
sequence_len * batch_size,
input_hidden_size,
input.dtype,
input.dtype,
transpose_weight=transpose_weight,
fuse_reduction=False
)
grad_input = bw_gemm_rs_op.forward(
grad_output.view(sequence_len * batch_size, -1),
weight if transpose_weight else weight.t().contiguous(),
bias=None,
input_scale=None,
weight_scale=None,
output_scale=None,
fast_accum=False
)
torch.cuda.current_stream().synchronize()
grad_input = grad_input.view(sequence_len // world_size, batch_size, -1)
else:
grad_input = grad_output.matmul(weight)
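The fused flux.GemmRS call stands in for the matmul followed by _reduce_scatter_along_first_dim that the removed lines performed explicitly: each rank's partial grad_output @ weight is summed across the TP group and split along the sequence dimension. A shape-only, single-process sketch of that unfused equivalent, with made-up sizes and a stack-and-sum standing in for the collective:

import torch

world_size, seq, batch, h_out, h_in = 4, 8, 2, 16, 32   # made-up sizes
# One grad_output / weight shard per TP rank (random stand-ins).
grad_outputs = [torch.randn(seq * batch, h_out) for _ in range(world_size)]
weights = [torch.randn(h_out, h_in) for _ in range(world_size)]

partials = [g.matmul(w) for g, w in zip(grad_outputs, weights)]  # [seq * batch, h_in] each
reduced = torch.stack(partials).sum(dim=0)                       # the "reduce" across ranks
shard = reduced.chunk(world_size, dim=0)[0]                      # the "scatter": rank 0's rows
grad_input = shard.view(seq // world_size, batch, h_in)
print(grad_input.shape)                                          # torch.Size([2, 2, 32])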
......@@ -514,7 +525,7 @@ class LinearRS(torch.autograd.Function):
if fw_gemm_rs_op is None:
fw_gemm_rs_op = flux.GemmRS(
get_tensor_model_parallel_group(),
1, #world_size // torch.cuda.device_count(),
get_tensor_model_parallel_node_size(),
sequence_len * batch_size,
output_hidden_size,
input.dtype,
......@@ -522,6 +533,7 @@ class LinearRS(torch.autograd.Function):
transpose_weight=transpose_weight,
fuse_reduction=False,
)
output = fw_gemm_rs_op.forward(
input.view(sequence_len * batch_size, -1),
weight.t().contiguous() if transpose_weight else weight,
......@@ -531,12 +543,8 @@ class LinearRS(torch.autograd.Function):
output_scale=None,
fast_accum=False,
)
torch.distributed.barrier()
torch.cuda.current_stream().synchronize()
output = output.view(sequence_len // world_size, batch_size, -1)
# output = torch.matmul(input, weight.t())
# output = _reduce_scatter_along_first_dim(output)
else:
output = torch.matmul(input, weight.t())
......@@ -586,7 +594,7 @@ class LinearRS(torch.autograd.Function):
if bw_ag_gemm_op is None:
bw_ag_gemm_op = flux.AGKernel(
get_tensor_model_parallel_group(),
1, #world_size // torch.cuda.device_count(),
get_tensor_model_parallel_node_size(),
sequence_len * batch_size * world_size,
input_hidden_size,
output_hidden_size,
......@@ -605,10 +613,8 @@ class LinearRS(torch.autograd.Function):
output_scale=None,
fast_accum=False,
)
torch.distributed.barrier()
torch.cuda.current_stream().synchronize()
grad_input = grad_input.contiguous().view(sequence_len * world_size, batch_size, -1)
grad_input = grad_input.view(sequence_len * world_size, batch_size, -1)
else:
grad_input = grad_output.matmul(weight)
......@@ -957,7 +963,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
):
self.fw_ag_gemm_op = flux.AGKernel(
get_tensor_model_parallel_group(),
1, # torch.distributed.get_world_size() // torch.cuda.device_count(),
get_tensor_model_parallel_node_size(),
sequence_len * batch_size * world_size,
output_hidden_size,
input_hidden_size,
......@@ -970,7 +976,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
self.bw_gemm_rs_op = flux.GemmRS(
get_tensor_model_parallel_group(),
1, # world_size // torch.cuda.device_count(),
get_tensor_model_parallel_node_size(),
sequence_len * batch_size * world_size,
input_hidden_size,
input_parallel.dtype,
......@@ -1011,6 +1017,14 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def __repr__(self):
tp = self.output_size // self.output_size_per_partition
use_bias = self.bias is not None
return (
f"{type(self).__name__}(in_features={self.input_size}, "
f"out_features={self.output_size_per_partition}, bias={use_bias}, TP={tp})"
)
class FluxRowParallelLinear(RowParallelLinear):
"""Linear layer with row parallelism.
......@@ -1131,7 +1145,7 @@ class FluxRowParallelLinear(RowParallelLinear):
):
self.fw_gemm_rs_op = flux.GemmRS(
get_tensor_model_parallel_group(),
1, # world_size // torch.cuda.device_count(),
get_tensor_model_parallel_node_size(),
sequence_len * batch_size,
output_hidden_size,
input_parallel.dtype,
......@@ -1142,7 +1156,7 @@ class FluxRowParallelLinear(RowParallelLinear):
self.bw_ag_gemm_op = flux.AGKernel(
get_tensor_model_parallel_group(),
1, # torch.distributed.get_world_size() // torch.cuda.device_count(),
get_tensor_model_parallel_node_size(),
sequence_len * batch_size,
input_hidden_size,
output_hidden_size,
......@@ -1184,3 +1198,11 @@ class FluxRowParallelLinear(RowParallelLinear):
output = output_
output_bias = self.bias
return output, output_bias
def __repr__(self):
tp = self.input_size // self.input_size_per_partition
use_bias = self.bias is not None
return (
f"{type(self).__name__}(in_features={self.input_size_per_partition}, "
f"out_features={self.output_size}, bias={use_bias}, TP={tp})"
)
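Both __repr__ overrides report per-partition feature sizes and recover the TP degree from the ratio of the full size to the per-partition size. For a hypothetical row-parallel layer with input_size=16384 split eight ways and output_size=4096, the string would read roughly:

# Hypothetical sizes: input_size_per_partition = input_size // tp_world_size.
input_size, input_size_per_partition, output_size = 16384, 2048, 4096
tp = input_size // input_size_per_partition
print(f"FluxRowParallelLinear(in_features={input_size_per_partition}, "
      f"out_features={output_size}, bias=False, TP={tp})")
# FluxRowParallelLinear(in_features=2048, out_features=4096, bias=False, TP=8)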
......@@ -11,6 +11,7 @@ from megatron.core.models.common.embeddings.language_model_embedding import Lang
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.module import MegatronModule
from megatron.core.extensions.transformer_engine import TEColumnParallelLinear
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
from megatron.core.transformer import ModuleSpec, TransformerConfig, build_module
......
......@@ -26,9 +26,6 @@ class ExtraTransformerConfig:
##################
# flux
##################
use_flux: bool = False
"""If set, flux will be used in ColumnParallelLinear and RowParallelLinear"""
flux_transpose_weight: bool = False
......
......@@ -182,9 +182,7 @@ def _add_mtp_args(parser):
def _add_flux_args(parser):
group = parser.add_argument_group(title='multi token prediction')
group.add_argument('--use-flux', action='store_true', default=False,
help='If set, flux will be used in ColumnParallelLinear and RowParallelLinear')
group = parser.add_argument_group(title='flux args')
group.add_argument('--flux-transpose-weight', action='store_true', default=False,
help='Whether to transpose weight when using flux kernel')
return parser
......@@ -61,7 +61,7 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model
"""
args = get_args()
use_te = args.transformer_impl == "transformer_engine" or bool(os.getenv("USE_FLUX_OVERLAP", 0))
use_te = args.transformer_impl == "transformer_engine" or bool(int(os.getenv("USE_FLUX_OVERLAP", "0")))
if args.record_memory_history:
torch.cuda.memory._record_memory_history(True,
......