Commit 99a0c39e authored by xingjinliang

Sync latest code

parent 50fe58fa
Pipeline #2152 passed with stage
File mode changed from 100755 to 100644 (19 files)
......@@ -201,7 +201,6 @@ def validate_args(args, defaults={}):
assert args.encoder_tensor_model_parallel_size == args.tensor_model_parallel_size, "If non-MOE encoder shares first decoder pipeline rank it must have the same TP as the decoder."
if args.encoder_tensor_model_parallel_size > 0:
assert args.encoder_pipeline_model_parallel_size > 0, "encoder_pipeline_model_parallel_size must be defined."
assert args.num_attention_heads % args.encoder_tensor_model_parallel_size == 0
assert args.encoder_tensor_model_parallel_size <= args.tensor_model_parallel_size, "We do not support encoders with more TP than the decoder."
......@@ -401,6 +400,14 @@ def validate_args(args, defaults={}):
assert not args.use_dist_ckpt, \
'--overlap-param-gather-with-optimizer-step not supported with distributed checkpointing yet'
dtype_map = {
'fp32': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16, 'fp8': torch.uint8,
}
args.main_grads_dtype = dtype_map[args.main_grads_dtype]
args.main_params_dtype = dtype_map[args.main_params_dtype]
args.exp_avg_dtype = dtype_map[args.exp_avg_dtype]
args.exp_avg_sq_dtype = dtype_map[args.exp_avg_sq_dtype]
if args.fp8_param_gather:
assert args.use_distributed_optimizer, \
'--fp8-param-gather only supported with distributed optimizer'
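These four options arrive from argparse as strings and validate_args() swaps them in place for torch dtypes; note that 'fp8' maps to torch.uint8, presumably because the 8-bit optimizer state is carried in byte storage. A minimal, hypothetical illustration of that conversion (SimpleNamespace stands in for the real parsed args):

import torch
from types import SimpleNamespace

# Hypothetical stand-in for the parsed argument namespace.
args = SimpleNamespace(main_grads_dtype='bf16', main_params_dtype='fp16',
                       exp_avg_dtype='fp8', exp_avg_sq_dtype='fp8')

# Same mapping as in the hunk above.
dtype_map = {
    'fp32': torch.float32, 'bf16': torch.bfloat16,
    'fp16': torch.float16, 'fp8': torch.uint8,
}
for name in ('main_grads_dtype', 'main_params_dtype', 'exp_avg_dtype', 'exp_avg_sq_dtype'):
    setattr(args, name, dtype_map[getattr(args, name)])

print(args.exp_avg_dtype)  # torch.uint8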
......@@ -422,7 +429,11 @@ def validate_args(args, defaults={}):
args.params_dtype = torch.bfloat16
# bfloat16 requires gradient accumulation and all-reduce to
# be done in fp32.
if not args.accumulate_allreduce_grads_in_fp32:
if args.accumulate_allreduce_grads_in_fp32:
assert args.main_grads_dtype == torch.float32, \
"--main-grads-dtype can only be fp32 when --accumulate-allreduce-grads-in-fp32 is set"
if not args.accumulate_allreduce_grads_in_fp32 and args.main_grads_dtype == torch.float32:
args.accumulate_allreduce_grads_in_fp32 = True
if args.rank == 0:
print('accumulate and all-reduce gradients in fp32 for '
......@@ -2078,6 +2089,7 @@ def _add_vision_args(parser):
def _add_moe_args(parser):
group = parser.add_argument_group(title="moe")
# General arguments
group.add_argument('--expert-model-parallel-size', type=int, default=1,
help='Degree of expert model parallelism.')
group.add_argument('--expert-tensor-parallel-size', type=int, default=None,
......@@ -2103,16 +2115,23 @@ def _add_moe_args(parser):
help='Enable overlapping between shared expert computations and dispatcher communications. '
'Without this, the shared experts execute after the routed experts. '
'Only effective when moe-shared-expert-intermediate-size is set.')
group.add_argument('--moe-grouped-gemm', action='store_true',
help='When there are multiple experts per rank, launch multiple local GEMM kernels in multiple streams to improve the utilization and performance with GroupedLinear in TransformerEngine.')
# Router arguments
group.add_argument('--moe-router-load-balancing-type', type=str,
choices=['aux_loss', 'sinkhorn', 'none'],
choices=['aux_loss', 'seq_aux_loss', 'sinkhorn', 'none'],
default='aux_loss',
help='Determines the load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss used in GShard and SwitchTransformer, "sinkhorn" corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss".')
help='Determines the load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss used in GShard and SwitchTransformer; "seq_aux_loss" corresponds to the load balancing loss used in DeepSeekV2, which computes the loss for each individual sample; "sinkhorn" corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss".')
group.add_argument('--moe-router-topk', type=int, default=2,
help='Number of experts to route to for each token. The default is 2.')
group.add_argument('--moe-router-pre-softmax', action='store_true',
help='Enable pre-softmax routing for MoE, which means softmax is before the top-k selection. By default, softmax is done after top-k.')
group.add_argument('--moe-grouped-gemm', action='store_true',
help='When there are multiple experts per rank, launch multiple local GEMM kernels in multiple streams to improve the utilization and performance with GroupedLinear in TransformerEngine.')
group.add_argument('--moe-router-topk-limited-devices', type=int, default=None,
help='Number of expert parallel ranks to consider for each token during routing. Perform top-k routing on a subset of expert parallel ranks by first selecting N ranks for each token, then conducting top-k selection among experts on these devices. Default is None, which means no limited devices.')
group.add_argument('--moe-router-topk-scaling-factor', type=float, default=None,
help='Scaling factor for routing score in top-k selection, only works when --moe-router-pre-softmax enabled. Defaults to None, which means no scaling.')
group.add_argument('--moe-use-legacy-grouped-gemm', action='store_true',
help='Use legacy GroupedMLP rather than TEGroupedMLP. Note: The legacy one will be deprecated soon.')
group.add_argument('--moe-aux-loss-coeff', type=float, default=0.0,
help='Scaling coefficient for the aux loss: a starting value of 1e-2 is recommended.')
group.add_argument('--moe-z-loss-coeff', type=float, default=None,
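The --moe-router-pre-softmax and --moe-router-topk-scaling-factor options above control where the softmax sits relative to the top-k selection and whether the kept scores are rescaled. A rough sketch of that ordering (an illustration only, with my own function name and shapes, not the Megatron-Core router code):

import torch

def topk_router_scores(logits, topk, pre_softmax=False, scaling_factor=None):
    # logits: [num_tokens, num_experts] raw router outputs.
    if pre_softmax:
        # --moe-router-pre-softmax: softmax over all experts first, then keep the top-k probabilities.
        probs = torch.softmax(logits, dim=-1)
        scores, expert_idx = torch.topk(probs, k=topk, dim=-1)
        if scaling_factor is not None:
            # --moe-router-topk-scaling-factor only takes effect in this branch.
            scores = scores * scaling_factor
    else:
        # Default: select the top-k logits first, then softmax over just those k.
        top_logits, expert_idx = torch.topk(logits, k=topk, dim=-1)
        scores = torch.softmax(top_logits, dim=-1)
    return scores, expert_idx

scores, idx = topk_router_scores(torch.randn(4, 8), topk=2, pre_softmax=True, scaling_factor=1.5)

In the default post-softmax branch the k selected logits are renormalized among themselves, which is why the help text says the scaling factor only works together with --moe-router-pre-softmax.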
......@@ -2185,4 +2204,18 @@ def _add_experimental_args(parser):
'the overridden pattern')
group.add_argument('--yaml-cfg', type=str, default=None,
help = 'Config file to add additional arguments')
# Args of precision-aware optimizer
group.add_argument('--use-precision-aware-optimizer', action='store_true',
help='Use the precision-aware optimizer in TransformerEngine, which allows '
'setting the main params and optimizer states to lower precision, such as '
'fp16 and fp8.')
group.add_argument('--main-grads-dtype', default='fp32', choices=['fp32', 'bf16'],
help='Dtype of main grads when enabling precision-aware-optimizer')
group.add_argument('--main-params-dtype', default='fp32', choices=['fp32', 'fp16'],
help='Dtype of main params when enabling precision-aware-optimizer')
group.add_argument('--exp-avg-dtype', default='fp32', choices=['fp32', 'fp16', 'fp8'],
help='Dtype of exp_avg when enabling precision-aware-optimizer')
group.add_argument('--exp-avg-sq-dtype', default='fp32', choices=['fp32', 'fp16', 'fp8'],
help='Dtype of exp_avg_sq when enabling precision-aware-optimizer')
return parser
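Putting the new precision-aware optimizer options together, here is a self-contained parsing sketch (a toy parser with just these flags, not the full Megatron argument parser; the chosen values are only an example):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use-precision-aware-optimizer', action='store_true')
parser.add_argument('--main-grads-dtype', default='fp32', choices=['fp32', 'bf16'])
parser.add_argument('--main-params-dtype', default='fp32', choices=['fp32', 'fp16'])
parser.add_argument('--exp-avg-dtype', default='fp32', choices=['fp32', 'fp16', 'fp8'])
parser.add_argument('--exp-avg-sq-dtype', default='fp32', choices=['fp32', 'fp16', 'fp8'])

args = parser.parse_args([
    '--use-precision-aware-optimizer',
    '--main-grads-dtype', 'bf16',
    '--main-params-dtype', 'fp16',
    '--exp-avg-dtype', 'fp8',
    '--exp-avg-sq-dtype', 'fp8',
])
print(args.exp_avg_sq_dtype)  # 'fp8'  (validate_args() later converts these strings to torch dtypes)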