Commit 02eacdd3 authored by dongcl

merge megatron_v0.11.0

parent 1f376414
@@ -141,9 +141,9 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
                                     torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
                                     apply_wrapper=True)
-        # MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
-        #                             torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation": True}),
-        #                             apply_wrapper=True)
+        MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
+                                    torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation": True}),
+                                    apply_wrapper=True)
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute',
                                     torch.compile(mode='max-autotune-no-cudagraphs'),
                                     apply_wrapper=True)
......
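The registration uncommented above relies on two behaviors: torch.compile(options=...) called without a function returns a decorator, and apply_wrapper=True tells the adaptation registry to apply that decorator to the existing attribute rather than replace it outright. A minimal sketch of such a registry follows; this `register` is a simplified stand-in for illustration, not the dcu_megatron implementation:

import importlib

def register(target_path, replacement, apply_wrapper=False):
    # Resolve a dotted 'pkg.module.attr' path into the owning module
    # and the attribute name to patch.
    module_path, attr = target_path.rsplit('.', 1)
    module = importlib.import_module(module_path)
    original = getattr(module, attr)
    # With apply_wrapper=True, `replacement` acts as a decorator (e.g. the
    # partial returned by torch.compile(options=...)); otherwise it is a
    # drop-in substitute for the original attribute.
    patched = replacement(original) if apply_wrapper else replacement
    setattr(module, attr, patched)

Used as in the hunk above, register(..., torch.compile(options={...}), apply_wrapper=True) leaves a compiled version of switch_load_balancing_loss_func in place of the original.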
@@ -53,9 +53,6 @@ from megatron.core.tensor_parallel.layers import (
     linear_with_frozen_weight,
     linear_with_grad_accumulation_and_async_allreduce
 )
-from dcu_megatron.core.utils import is_flux_min_version
 _grad_accum_fusion_available = True
 try:
......
@@ -19,11 +19,6 @@ try:  # use fixed-length FA
 except ImportError:
     flash_attn_func = None
-try:
-    from einops import rearrange
-except ImportError:
-    rearrange = None
 def parallel_mlp_init_wrapper(fn):
     @wraps(fn)
......
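For context, the Triton options passed in the first hunk can be exercised standalone. A minimal sketch, assuming a CUDA-capable build of PyTorch 2.x; `scale_and_sum` is a hypothetical stand-in for the MoE utilities being compiled:

import torch

def scale_and_sum(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a small elementwise-plus-reduction kernel,
    # similar in shape to the MoE load-balancing loss.
    return (x * 2.0).sum()

compiled = torch.compile(
    scale_and_sum,
    options={"triton.cudagraphs": True, "triton.cudagraph_trees": False},
)

if torch.cuda.is_available():
    out = compiled(torch.randn(1024, device="cuda"))

The extra "triton.cudagraph_support_input_mutation" option set for switch_load_balancing_loss_func appears intended to let inductor keep CUDA-graph capture enabled for functions that mutate their inputs.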