Commit 595e428a authored by dongcl

use flux gemmrs

parent 0867fd90
@@ -187,7 +187,7 @@ class CoreAdaptation(MegatronAdaptationABC):
                               apply_wrapper=True)
         # flux
-        if os.getenv("USE_FLUX_OVERLAP", 0):
+        if int(os.getenv("USE_FLUX_OVERLAP", "0")):
             from ..core.tensor_parallel import (
                 FluxColumnParallelLinear,
                 FluxRowParallelLinear
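Note: the old check `os.getenv("USE_FLUX_OVERLAP", 0)` is truthy for any non-empty string, including `"0"`, because `os.getenv` returns a string whenever the variable is set; the new `int(...)` wrapper makes `USE_FLUX_OVERLAP=0` actually disable the path. A standalone sketch of the difference (not part of the patch):

```python
import os

os.environ["USE_FLUX_OVERLAP"] = "0"

# Old check: "0" is a non-empty string, so the flux branch would still run.
print(bool(os.getenv("USE_FLUX_OVERLAP", 0)))         # True
# New check: the value is parsed as an integer first, so 0 disables the branch.
print(bool(int(os.getenv("USE_FLUX_OVERLAP", "0"))))  # False
```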
@@ -16,6 +16,7 @@ from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.transformer.enums import ModelType
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_block import TransformerBlock
+from megatron.core.extensions.transformer_engine import TEColumnParallelLinear
 from dcu_megatron.core.utils import tensor_slide
 from dcu_megatron.core.transformer.mtp.multi_token_predictor import MultiTokenPredictor
@@ -40,7 +41,7 @@ def gpt_model_init_wrapper(fn):
             self.mtp_layers = torch.nn.ModuleList(
                 [
                     MultiTokenPredictor(
-                        config,
+                        self.config,
                         self.mtp_spec.submodules,
                         vocab_size=self.vocab_size,
                         max_sequence_length=self.max_sequence_length,

 import os
+import socket
 import warnings
 from functools import wraps
 from typing import Callable, List, Optional
@@ -160,6 +161,19 @@ def vocab_parallel_embedding_forward(self, input_, weight=None):
     return output
 
 
+def get_tensor_model_parallel_node_size(group=None):
+    """Get the number of distinct nodes (hosts) spanned by the tensor-parallel group."""
+    if group is None:
+        group = get_tensor_model_parallel_group()
+    hostname = socket.gethostname()
+    hostnames = [None] * get_tensor_model_parallel_world_size()
+    torch.distributed.all_gather_object(hostnames, hostname, group=group)
+    num_nodes = len(set(hostnames))
+    return num_nodes
+
+
 class AGLinear(torch.autograd.Function):
     @staticmethod
     @custom_fwd
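The new helper gathers each rank's hostname over the tensor-parallel group and counts the distinct hosts; throughout the rest of the patch its result replaces the node-count argument of `flux.AGKernel` and `flux.GemmRS`, which was previously hard-coded to `1`. A minimal sketch of the intended call pattern, reusing the variable names from the surrounding call sites (parameter meanings are inferred from those call sites, and an initialized distributed TP group plus the flux extension are assumed):

```python
# Sketch only; mirrors the flux.GemmRS call sites updated in this patch.
nnodes = get_tensor_model_parallel_node_size()  # distinct hosts in the TP group
gemm_rs_op = flux.GemmRS(
    get_tensor_model_parallel_group(),  # TP process group
    nnodes,                             # was hard-coded to 1 before this commit
    sequence_len * batch_size,          # rows of the GEMM output
    output_hidden_size,                 # columns of the GEMM output
    input.dtype,                        # input dtype
    input.dtype,                        # output dtype
    transpose_weight=transpose_weight,
    fuse_reduction=False,
)
```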
@@ -196,7 +210,7 @@ class AGLinear(torch.autograd.Function):
             if fw_ag_gemm_op is None:
                 fw_ag_gemm_op = flux.AGKernel(
                     get_tensor_model_parallel_group(),
-                    1, # torch.distributed.get_world_size() // torch.cuda.device_count(),
+                    get_tensor_model_parallel_node_size(),
                     sequence_len * batch_size * world_size,
                     output_hidden_size,
                     input_hidden_size,
@@ -265,34 +279,31 @@ class AGLinear(torch.autograd.Function):
         if ctx.sequence_parallel:
             sequence_len, batch_size, _ = grad_output.size()
-            # if bw_gemm_rs_op is None:
-            #     input_hidden_size = weight.size(-1)
-            #     bw_gemm_rs_op = flux.GemmRS(
-            #         get_tensor_model_parallel_group(),
-            #         1, # world_size // torch.cuda.device_count(),
-            #         sequence_len * batch_size,
-            #         input_hidden_size,
-            #         input.dtype,
-            #         input.dtype,
-            #         transpose_weight=transpose_weight,
-            #         fuse_reduction=False
-            #     )
-            # grad_input = bw_gemm_rs_op.forward(
-            #     grad_output.view(sequence_len * batch_size, -1),
-            #     weight if transpose_weight else weight.t().contiguous(),
-            #     bias=None,
-            #     input_scale=None,
-            #     weight_scale=None,
-            #     output_scale=None,
-            #     fast_accum=False
-            # )
-            # torch.distributed.barrier()
-            # torch.cuda.current_stream().synchronize()
-            # grad_input = grad_input.view(sequence_len // world_size, batch_size, -1)
-            grad_input = grad_output.matmul(weight)
-            grad_input = _reduce_scatter_along_first_dim(grad_input)
+            if bw_gemm_rs_op is None:
+                input_hidden_size = weight.size(-1)
+                bw_gemm_rs_op = flux.GemmRS(
+                    get_tensor_model_parallel_group(),
+                    get_tensor_model_parallel_node_size(),
+                    sequence_len * batch_size,
+                    input_hidden_size,
+                    input.dtype,
+                    input.dtype,
+                    transpose_weight=transpose_weight,
+                    fuse_reduction=False
+                )
+            grad_input = bw_gemm_rs_op.forward(
+                grad_output.view(sequence_len * batch_size, -1),
+                weight if transpose_weight else weight.t().contiguous(),
+                bias=None,
+                input_scale=None,
+                weight_scale=None,
+                output_scale=None,
+                fast_accum=False
+            )
+            torch.cuda.current_stream().synchronize()
+            grad_input = grad_input.view(sequence_len // world_size, batch_size, -1)
         else:
             grad_input = grad_output.matmul(weight)
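For reference, the two active lines removed above were the unfused fallback this hunk replaces: an ordinary GEMM followed by a reduce-scatter over the sequence dimension, which `flux.GemmRS` fuses into one overlapped kernel. A rough sketch of that equivalence, using the names from the surrounding code:

```python
# Unfused fallback removed by this hunk:
#   grad_input = grad_output.matmul(weight)                    # local GEMM on each TP rank
#   grad_input = _reduce_scatter_along_first_dim(grad_input)   # reduce-scatter over the sequence dim
#
# Fused path added by this hunk: flux.GemmRS performs the GEMM and the
# reduce-scatter together, and the result is reshaped back to the same
# (sequence_len // world_size, batch_size, hidden) layout the fallback produced:
#   grad_input = bw_gemm_rs_op.forward(grad_output.view(sequence_len * batch_size, -1), ...)
#   grad_input = grad_input.view(sequence_len // world_size, batch_size, -1)
```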
@@ -514,7 +525,7 @@ class LinearRS(torch.autograd.Function):
         if fw_gemm_rs_op is None:
             fw_gemm_rs_op = flux.GemmRS(
                 get_tensor_model_parallel_group(),
-                1, # world_size // torch.cuda.device_count(),
+                get_tensor_model_parallel_node_size(),
                 sequence_len * batch_size,
                 output_hidden_size,
                 input.dtype,
@@ -522,6 +533,7 @@ class LinearRS(torch.autograd.Function):
                 transpose_weight=transpose_weight,
                 fuse_reduction=False,
             )
+
         output = fw_gemm_rs_op.forward(
             input.view(sequence_len * batch_size, -1),
             weight.t().contiguous() if transpose_weight else weight,
@@ -531,12 +543,8 @@ class LinearRS(torch.autograd.Function):
             output_scale=None,
             fast_accum=False,
         )
-        torch.distributed.barrier()
         torch.cuda.current_stream().synchronize()
         output = output.view(sequence_len // world_size, batch_size, -1)
-        # output = torch.matmul(input, weight.t())
-        # output = _reduce_scatter_along_first_dim(output)
     else:
         output = torch.matmul(input, weight.t())
@@ -586,7 +594,7 @@ class LinearRS(torch.autograd.Function):
         if bw_ag_gemm_op is None:
             bw_ag_gemm_op = flux.AGKernel(
                 get_tensor_model_parallel_group(),
-                1, # world_size // torch.cuda.device_count(),
+                get_tensor_model_parallel_node_size(),
                 sequence_len * batch_size * world_size,
                 input_hidden_size,
                 output_hidden_size,
@@ -605,10 +613,8 @@ class LinearRS(torch.autograd.Function):
             output_scale=None,
             fast_accum=False,
         )
-        torch.distributed.barrier()
         torch.cuda.current_stream().synchronize()
-        grad_input = grad_input.contiguous().view(sequence_len * world_size, batch_size, -1)
+        grad_input = grad_input.view(sequence_len * world_size, batch_size, -1)
     else:
         grad_input = grad_output.matmul(weight)
@@ -957,7 +963,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
         ):
             self.fw_ag_gemm_op = flux.AGKernel(
                 get_tensor_model_parallel_group(),
-                1, # torch.distributed.get_world_size() // torch.cuda.device_count(),
+                get_tensor_model_parallel_node_size(),
                 sequence_len * batch_size * world_size,
                 output_hidden_size,
                 input_hidden_size,
@@ -970,7 +976,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
             self.bw_gemm_rs_op = flux.GemmRS(
                 get_tensor_model_parallel_group(),
-                1, # world_size // torch.cuda.device_count(),
+                get_tensor_model_parallel_node_size(),
                 sequence_len * batch_size * world_size,
                 input_hidden_size,
                 input_parallel.dtype,
@@ -1011,6 +1017,14 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
 
+    def __repr__(self):
+        tp = self.output_size // self.output_size_per_partition
+        use_bias = self.bias is not None
+        return (
+            f"{type(self).__name__}(in_features={self.input_size}, "
+            f"out_features={self.output_size_per_partition}, bias={use_bias}, TP={tp})"
+        )
+
 
 class FluxRowParallelLinear(RowParallelLinear):
     """Linear layer with row parallelism.
@@ -1131,7 +1145,7 @@ class FluxRowParallelLinear(RowParallelLinear):
         ):
             self.fw_gemm_rs_op = flux.GemmRS(
                 get_tensor_model_parallel_group(),
-                1, # world_size // torch.cuda.device_count(),
+                get_tensor_model_parallel_node_size(),
                 sequence_len * batch_size,
                 output_hidden_size,
                 input_parallel.dtype,
@@ -1142,7 +1156,7 @@ class FluxRowParallelLinear(RowParallelLinear):
             self.bw_ag_gemm_op = flux.AGKernel(
                 get_tensor_model_parallel_group(),
-                1, # torch.distributed.get_world_size() // torch.cuda.device_count(),
+                get_tensor_model_parallel_node_size(),
                 sequence_len * batch_size,
                 input_hidden_size,
                 output_hidden_size,
@@ -1184,3 +1198,11 @@ class FluxRowParallelLinear(RowParallelLinear):
             output = output_
         output_bias = self.bias
         return output, output_bias
+
+    def __repr__(self):
+        tp = self.input_size // self.input_size_per_partition
+        use_bias = self.bias is not None
+        return (
+            f"{type(self).__name__}(in_features={self.input_size_per_partition}, "
+            f"out_features={self.output_size}, bias={use_bias}, TP={tp})"
+        )
@@ -11,6 +11,7 @@ from megatron.core.models.common.embeddings.language_model_embedding import Lang
 from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.transformer.module import MegatronModule
+from megatron.core.extensions.transformer_engine import TEColumnParallelLinear
 from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
 from megatron.core.transformer import ModuleSpec, TransformerConfig, build_module
@@ -26,9 +26,6 @@ class ExtraTransformerConfig:
     ##################
     # flux
     ##################
-    use_flux: bool = False
-    """If set, flux will be used in ColumnParallelLinear and RowParallelLinear"""
-
     flux_transpose_weight: bool = False
@@ -182,9 +182,7 @@ def _add_mtp_args(parser):
 def _add_flux_args(parser):
-    group = parser.add_argument_group(title='multi token prediction')
-    group.add_argument('--use-flux', action='store_true', default=False,
-                       help='If set, flux will be used in ColumnParallelLinear and RowParallelLinear')
+    group = parser.add_argument_group(title='flux args')
     group.add_argument('--flux-transpose-weight', action='store_true', default=False,
                        help='Whether to transpose weight when using flux kernel')
     return parser
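With the `--use-flux` flag and the `use_flux` config field removed, the flux overlap path is gated solely by the `USE_FLUX_OVERLAP` environment variable checked in the adaptation layer and in `model_provider`. A small sketch of how a run would opt in under that assumption (the surrounding launch setup is illustrative, not part of the patch):

```python
import os

# Illustrative only: set in the launcher environment before the adaptation /
# model-provider code runs (e.g. `USE_FLUX_OVERLAP=1` in the job script).
os.environ["USE_FLUX_OVERLAP"] = "1"

# Both gates updated in this commit reduce to the same integer parse:
flux_enabled = bool(int(os.getenv("USE_FLUX_OVERLAP", "0")))
assert flux_enabled  # FluxColumnParallelLinear / FluxRowParallelLinear are imported and patched in
```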
@@ -61,7 +61,7 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
         Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model
     """
     args = get_args()
-    use_te = args.transformer_impl == "transformer_engine" or bool(os.getenv("USE_FLUX_OVERLAP", 0))
+    use_te = args.transformer_impl == "transformer_engine" or bool(int(os.getenv("USE_FLUX_OVERLAP", "0")))
     if args.record_memory_history:
         torch.cuda.memory._record_memory_history(True,