Unverified Commit f71e63b0 authored by Xuanlei Zhao, committed by GitHub

[moe] support optimizer checkpoint (#5015)

* Refactor MoE Manager setup method

* unshard optim ckpt

* optim io

* update transformer version

* update requirements

* update ckpt

* update ckpt

* update ckpt

* fix engine

* fix engine
parent 67f53317
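
In effect, this commit lets an optimizer wrapped by MoeHybridParallelPlugin be saved and restored through the Booster API, in sharded and unsharded form, as the updated test_moe_checkpoint.py below exercises. A condensed sketch of that flow follows; the build_moe_model() helper is hypothetical, a distributed launch via colossalai.launch is assumed, and the calls themselves mirror ones that appear in this diff (the test additionally passes custom_policy=OpenMoeForCausalLMPolicy() to the plugin):

import torch
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.moe.manager import MOE_MANAGER

MOE_MANAGER.setup(parallel="EP")  # setup no longer takes a seed after this PR

plugin = MoeHybridParallelPlugin(precision="bf16", tp_size=1, pp_size=1, zero_stage=2)
booster = Booster(plugin=plugin)

model = build_moe_model()  # hypothetical helper: any MoE model the plugin supports
optim = torch.optim.Adam(model.parameters())
model, optim, _, _, _ = booster.boost(model=model, optimizer=optim)

# ... run a forward/backward pass and optim.step() ...

booster.save_optimizer(optim, "./optim_ckpt", shard=True, size_per_shard=1)  # sharded checkpoint
booster.save_optimizer(optim, "./optim_ckpt.pth")                            # single-file checkpoint
booster.load_optimizer(optim, "./optim_ckpt")                                # restore from the sharded copy
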
@@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
 )
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.moe import MoeCheckpintIO
+from colossalai.moe import MoECheckpintIO
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig
@@ -322,8 +322,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             **_kwargs,
         )

-    def get_checkpoint_io(self) -> MoeCheckpintIO:
-        self.checkpoint_io = MoeCheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+    def get_checkpoint_io(self) -> MoECheckpintIO:
+        self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
         return self.checkpoint_io

     def configure(
@@ -359,9 +359,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 max_norm=self.max_norm,
                 **self.amp_config,
             )
-            self.checkpoint_io.link_master_and_working_param(
-                optimizer.working_to_master_map, optimizer.master_to_working_map
-            )
         else:
             optimizer = HybridParallelNaiveOptimizer(
                 optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
...
@@ -79,13 +79,15 @@ class TPInferEngine:
         self.multi_query_group_num = model.config.num_attention_heads
         # default to attention_heads
-        self.multi_query_attention = model.config.multi_query_attention
+        if hasattr(model.config, "multi_query_attention"):
+            self.multi_query_attention = getattr(model.config, "multi_query_attention")
         if hasattr(model.config, "multi_query_group_num"):
-            self.multi_query_group_num = model.config.multi_query_group_num
+            self.multi_query_group_num = getattr(model.config, "multi_query_group_num")
         if hasattr(model.config, "num_key_value_heads"):
-            self.multi_query_group_num = model.config.num_key_value_heads
+            self.multi_query_group_num = getattr(model.config, "num_key_value_heads")

         self.tp_size = -1  # to be set with given shard config in self.prepare_shard_config
         self.cache_manager = None
@@ -108,7 +110,7 @@ class TPInferEngine:
         assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
         self.head_num //= self.tp_size  # update sharded number of heads

-        if self.multi_query_attention:
+        if hasattr(self, "multi_query_attention"):
            # NOTE the logic of MQA tensor parallelism should be specified.
             assert (
                 self.multi_query_group_num % self.tp_size == 0
...
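
The engine change above probes the HF config defensively: attributes like multi_query_attention and num_key_value_heads exist only for some architectures, so they are copied onto the engine only when present and later detected with hasattr(self, ...). A self-contained sketch of the same pattern, using a stand-in config object rather than a real transformers config:

from types import SimpleNamespace

def read_mqa_settings(config, default_heads: int):
    # Copy optional multi-query-attention fields from a model config, if they exist.
    settings = SimpleNamespace(multi_query_group_num=default_heads)
    if hasattr(config, "multi_query_attention"):
        settings.multi_query_attention = config.multi_query_attention
    if hasattr(config, "multi_query_group_num"):
        settings.multi_query_group_num = config.multi_query_group_num
    if hasattr(config, "num_key_value_heads"):  # GQA-style configs
        settings.multi_query_group_num = config.num_key_value_heads
    return settings

cfg = SimpleNamespace(num_attention_heads=32, num_key_value_heads=8)
s = read_mqa_settings(cfg, default_heads=cfg.num_attention_heads)
assert s.multi_query_group_num == 8 and not hasattr(s, "multi_query_attention")
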
-from .checkpoint import MoeCheckpintIO
+from .checkpoint import MoECheckpintIO
 from .experts import MLPExperts
 from .layers import SparseMLP
 from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
@@ -13,5 +13,5 @@ __all__ = [
     "NormalNoiseGenerator",
     "UniformNoiseGenerator",
     "SparseMLP",
-    "MoeCheckpintIO",
+    "MoECheckpintIO",
 ]
... (this file's diff is collapsed) ...
@@ -9,7 +9,7 @@ from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.utils import get_activation
 from colossalai.shardformer.layer.utils import Randomizer
-from colossalai.tensor.moe_tensor.api import get_ep_size, set_moe_tensor_info
+from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size, set_moe_tensor_info

 if HAS_TRITON:
     from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
@@ -53,7 +53,8 @@ class MLPExperts(nn.Module):
         # get expert parallel info
         if expert_parallel is not None:
             self.num_local_experts, self.moe_info = MOE_MANAGER.get_info(
-                num_experts, use_tp=True if expert_parallel == "TP" else False)
+                num_experts, use_tp=True if expert_parallel == "TP" else False
+            )
             # get settings for different parallel
             self.ep_size = get_ep_size(self)
             if expert_parallel == "TP":
@@ -87,7 +88,7 @@ class MLPExperts(nn.Module):
     def reset_parameters(self):
         # expert param should be different
         if self.expert_parallel is not None:
-            seed_ctx = Randomizer(MOE_MANAGER.seed).fork_rng(enable_cpu=True)
+            seed_ctx = Randomizer(get_ep_rank(self)).fork_rng(enable_cpu=True)
         else:
             seed_ctx = Randomizer(42).fork_rng(enable_cpu=True)
         with seed_ctx:
@@ -129,7 +130,7 @@ class MLPExperts(nn.Module):
             mask = torch.sum(mask, dim=-1)
             x_list = []
             for i in range(e):
-                x_list.append(x[i, :mask[i]])
+                x_list.append(x[i, : mask[i]])
             x = x_list
             if self.gated:
...
@@ -8,14 +8,13 @@ from colossalai.tensor.moe_tensor.api import get_moe_info
 from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo


-class MoeManager(metaclass=SingletonMeta):
+class MoEManager(metaclass=SingletonMeta):
     """MoE manager. This class manages different
     parallel groups in MoE context and MoE loss in training.
     """

     def __init__(self):
         self.parallel = None
-        self.seed = None
         self.mode = None
         self.use_ep_inside = None
         self.world_size = None
@@ -48,7 +47,6 @@ class MoeManager(metaclass=SingletonMeta):
     def setup(
         self,
-        seed: int,
         parallel: str = None,
         mode: str = "dynamic",
         max_ep_size: int = 8,
@@ -73,10 +71,9 @@ class MoeManager(metaclass=SingletonMeta):
             fixed_pp_size (int, optional): Fixed pp size in fixed mode. Defaults to 0.
             use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if Fasle. Defaults to True.
         """
-        assert (not self.is_initialized), "MoE distributed context shouldn't be set up again"
+        assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
         assert torch.cuda.is_available(), "MoE requires to enable CUDA first"

-        self.seed = seed + dist.get_rank()
         self.parallel = parallel
         self.use_ep_inside = use_ep_inside
         self.world_size = dist.get_world_size()
@@ -87,10 +84,12 @@ class MoeManager(metaclass=SingletonMeta):
         if self.mode == "dynamic":
             self.max_ep_size = min(max_ep_size, self.world_size)
         else:
-            assert (fixed_dp_size > 0 and fixed_ep_size > 0
-                    and fixed_pp_size > 0), "dp_size, ep_size and pp_size should be greater than 0"
-            assert (isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int)
-                    and isinstance(fixed_pp_size, int)), "dp_size, ep_size and pp_size should be int"
+            assert (
+                fixed_dp_size > 0 and fixed_ep_size > 0 and fixed_pp_size > 0
+            ), "dp_size, ep_size and pp_size should be greater than 0"
+            assert (
+                isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int) and isinstance(fixed_pp_size, int)
+            ), "dp_size, ep_size and pp_size should be int"
             self.ep_size = fixed_ep_size
             self.dp_size = fixed_dp_size
             self.pp_size = fixed_pp_size
@@ -112,10 +111,12 @@ class MoeManager(metaclass=SingletonMeta):
         """
         if self.mode == "dynamic":
-            gt_flag = (num_experts % self.max_ep_size == 0)  # check whether num_experts is greater
-            lt_flag = (self.max_ep_size % num_experts == 0)  # check whether num_experts is less
-            assert gt_flag or lt_flag, ("Automatic experts placement dose not not support expert number"
-                                        " is not a multiple of ep size or vice versa.")
+            gt_flag = num_experts % self.max_ep_size == 0  # check whether num_experts is greater
+            lt_flag = self.max_ep_size % num_experts == 0  # check whether num_experts is less
+            assert gt_flag or lt_flag, (
+                "Automatic experts placement dose not not support expert number"
+                " is not a multiple of ep size or vice versa."
+            )
             dp_size = 1 if gt_flag else self.world_size // num_experts
             ep_size = min(self.world_size // dp_size, self.max_ep_size)
             dp_size = self.world_size // ep_size
@@ -159,4 +160,4 @@ class MoeManager(metaclass=SingletonMeta):
         return self.parallel


-MOE_MANAGER = MoeManager()
+MOE_MANAGER = MoEManager()
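
After this refactor, MOE_MANAGER.setup() no longer takes a seed: callers pass only parallelism options, and per-rank randomness for expert parameters instead comes from Randomizer(get_ep_rank(...)) in MLPExperts.reset_parameters(). A condensed sketch of the two setup modes used throughout the updated tests and examples (a prior colossalai.launch(...) distributed initialization is assumed; the values mirror the tests below):

from colossalai.moe.manager import MOE_MANAGER

# dynamic mode: EP size is derived per layer from the number of experts
MOE_MANAGER.setup(parallel="EP", max_ep_size=2, use_ep_inside=False)

# fixed mode: dp/ep/pp sizes are pinned explicitly (used together with pipeline parallelism)
MOE_MANAGER.__init__()  # reset the singleton, as the tests do between configurations
MOE_MANAGER.setup(
    parallel="EP",
    mode="fixed",
    fixed_dp_size=1,
    fixed_ep_size=2,
    fixed_pp_size=2,
)
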
@@ -72,6 +72,19 @@ def get_ep_size(tensor: torch.Tensor) -> int:
     return tensor.moe_info.ep_size


+def get_dp_size(tensor: torch.Tensor) -> int:
+    """
+    Get the data parallel size of the given tensor.
+
+    Args:
+        tensor (torch.Tensor): The tensor to be checked.
+
+    Returns:
+        int: The data parallel size of the given tensor.
+    """
+    return tensor.moe_info.dp_size
+
+
 def get_dp_group(tensor: torch.Tensor) -> ProcessGroup:
     """
     Get the data parallel group of the given tensor.
...
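
The new get_dp_size accessor complements get_ep_size and get_dp_group: it reads the MoeParallelInfo attached to an expert tensor, which checkpoint IO can use to tell how a parameter is replicated versus sharded. A small usage sketch (the helper name is made up; it assumes a SparseMLP built under an initialized MoE context, so its expert weights carry moe_info):

from colossalai.tensor.moe_tensor.api import get_dp_size, get_ep_size

def describe_expert_param(param):
    # hypothetical helper: report how an expert weight is split (EP) and replicated (DP)
    return f"ep_size={get_ep_size(param)}, dp_size={get_dp_size(param)}"

# e.g. print(describe_expert_param(layer.experts.wi)) for a SparseMLP `layer`
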
@@ -155,9 +155,7 @@ def main():
         "precision": "bf16",
         "zero_stage": args.zero_stage,
     }
-    mgr_dict = {
-        "seed": 42,
-    }
+    mgr_dict = {}
     if args.plugin == "ep":
         dp_size = dist.get_world_size()
         plugin = MoeHybridParallelPlugin(
...
@@ -41,7 +41,7 @@ def fsdp_main(rank, world_size, args):
     # initialize the process group
     dist.init_process_group("nccl")

-    MOE_MANAGER.setup(seed=42, parallel=None)
+    MOE_MANAGER.setup(parallel=None)
     dp_size = dist.get_world_size()
     dataset = RandomDataset(
...
 colossalai >= 0.3.3
 torch >= 1.8.1
-transformers >= 4.20.0
+transformers >= 4.20.0, <= 4.34.0
 sentencepiece
 datasets
@@ -213,9 +213,7 @@ def main():
         "precision": args.precision,
         "zero_stage": args.zero_stage,
     }
-    mgr_dict = {
-        "seed": 42,
-    }
+    mgr_dict = {}
     if args.plugin == "ep":
         dp_size = dist.get_world_size()
         plugin = MoeHybridParallelPlugin(
...
@@ -6,10 +6,9 @@ import torch.nn as nn
 import colossalai
 from colossalai.moe import SparseMLP
 from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import sync_moe_model_param
 from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device
-from tests.test_moe.moe_utils import MoeGradientHandler, assert_not_equal_in_group
+from tests.test_moe.moe_utils import MoeGradientHandler

 BATCH_SIZE = 4
 DIM = 16
@@ -25,7 +24,7 @@ def run_test(rank, world_size, port):
         backend="nccl",
     )

-    MOE_MANAGER.setup(42, parallel="EP")  # MOE initialization
+    MOE_MANAGER.setup(parallel="EP")  # MOE initialization
     num_experts_list = [1, 2, 4]
     layer_list = []
     for num_experts in num_experts_list:
@@ -41,15 +40,6 @@ def run_test(rank, world_size, port):
     model = nn.ModuleList(layer_list)
     model = model.to(get_current_device())
     dist_dict = MOE_MANAGER.parallel_info_dict
-
-    assert_not_equal_in_group(layer_list[0].experts.wi.data, dist_dict[1].dp_group)
-    assert_not_equal_in_group(layer_list[0].experts.wo.data, dist_dict[1].dp_group)
-    assert_not_equal_in_group(layer_list[1].experts.wi.data, dist_dict[2].dp_group)
-    assert_not_equal_in_group(layer_list[1].experts.wo.data, dist_dict[2].dp_group)
-    assert_not_equal_in_group(layer_list[2].experts.wi.data, dist_dict[4].dp_group)
-    assert_not_equal_in_group(layer_list[2].experts.wo.data, dist_dict[4].dp_group)
-
-    sync_moe_model_param(model)
     assert_equal_in_group(layer_list[0].experts.wi.data, dist_dict[1].dp_group)
     assert_equal_in_group(layer_list[0].experts.wo.data, dist_dict[1].dp_group)
     assert_equal_in_group(layer_list[1].experts.wi.data, dist_dict[2].dp_group)
...
@@ -20,21 +20,23 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     # Here we do not need TF32, since it brings absolute error on results
     torch.backends.cuda.matmul.allow_tf32 = False

-    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     local_rank = dist.get_rank()

-    MOE_MANAGER.setup(42, parallel="EP")  # MOE environment initialization
+    MOE_MANAGER.setup(parallel="EP")  # MOE environment initialization
     MOE_MANAGER.reset_loss()
     torch.manual_seed(rs + local_rank)  # set each process has different random seed

     # get randomized data
     tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)

-    layer = SparseMLP(hidden_size=hidden_size,
-                      intermediate_size=hidden_size * 2,
-                      num_experts=NUM_EXPERTS,
-                      router_top_k=topk,
-                      router_capacity_factor_train=1.0)
+    layer = SparseMLP(
+        hidden_size=hidden_size,
+        intermediate_size=hidden_size * 2,
+        num_experts=NUM_EXPERTS,
+        router_top_k=topk,
+        router_capacity_factor_train=1.0,
+    )
     layer = layer.to(get_current_device())
     if data_type == torch.float16:
         layer = layer.half()
@@ -90,5 +92,5 @@ def test_moe_kernel(rs, hidden_size, data_type, topk):
     spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, topk=topk)


-if __name__ == '__main__':
+if __name__ == "__main__":
     test_moe_kernel(2, 256, torch.float16, 2)
@@ -12,53 +12,112 @@ import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.moe.manager import MOE_MANAGER
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
+from colossalai.utils import get_current_device

-sys.path.append(os.path.join(
-    os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
-    "examples/language/openmoe",
-))
+sys.path.append(
+    os.path.join(
+        os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
+        "examples/language/openmoe",
+    )
+)

 OpenMoeForCausalLM = importlib.import_module("model.modeling_openmoe").OpenMoeForCausalLM
 set_openmoe_args = importlib.import_module("model.modeling_openmoe").set_openmoe_args
 OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenMoeForCausalLMPolicy


+def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20):
+    input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device())
+    attention_mask = torch.ones_like(input_ids)
+    return {
+        "input_ids": input_ids,
+        "attention_mask": attention_mask,
+        "labels": input_ids,
+    }
+
+
+def run_fwd_bwd(
+    model, data, label, criterion, optimizer, enable_autocast=False, pipeline=False, booster=None, plugin=None
+):
+    model.train()
+    if pipeline:
+        train_dataloader_iter = DummyDataloader(data_gen_fn, length=1)
+        is_pp_last_stage = booster.plugin.stage_manager.is_last_stage()
+        y = booster.execute_pipeline(
+            train_dataloader_iter,
+            model,
+            lambda x, y: x.loss,
+            optimizer,
+            return_loss=True,
+            return_outputs=True,
+        )
+        # Backward and optimize
+        if is_pp_last_stage:
+            loss = y["loss"]
+    else:
+        if criterion:
+            y = model(data).logits
+            loss = criterion(y)
+        else:
+            loss = model(data, label)
+        loss = loss.float()
+        if optimizer is not None:
+            optimizer.backward(loss)
+        else:
+            loss.backward()
+    return y
+
+
 def get_config():
     config = LlamaConfig(
         vocab_size=300,
         hidden_size=16,
         intermediate_size=32,
-        num_hidden_layers=4,
+        num_hidden_layers=2,
         num_attention_heads=2,
         head_dim=4,
         dropout_rate=0.0,
         hidden_act="swiglu",
     )
-    set_openmoe_args(config, num_experts=16, moe_layer_interval=1)
+    set_openmoe_args(config, num_experts=8, moe_layer_interval=1)
     return config

 def get_model(parallel):
     config = get_config()
     model = OpenMoeForCausalLM(config)
+    optim = torch.optim.Adam(model.parameters())
     if parallel == None:
         plugin = MoeHybridParallelPlugin(
+            precision="bf16",
             tp_size=1,
             pp_size=1,
-            zero_stage=0,
+            zero_stage=2,
             custom_policy=OpenMoeForCausalLMPolicy(),
         )
-    elif parallel == "zero_ep":
+    elif parallel == "ep":
         plugin = MoeHybridParallelPlugin(
+            precision="bf16",
             tp_size=1,
             pp_size=1,
             zero_stage=2,
             custom_policy=OpenMoeForCausalLMPolicy(),
         )
+    elif parallel == "ep_zero":
+        plugin = MoeHybridParallelPlugin(
+            precision="bf16",
+            tp_size=1,
+            pp_size=1,
+            zero_stage=2,
+            extra_dp_size=2,
+            custom_policy=OpenMoeForCausalLMPolicy(),
+        )
     elif parallel == "hybrid":
         plugin = MoeHybridParallelPlugin(
+            precision="bf16",
             tp_size=1,
             pp_size=2,
             zero_stage=1,
@@ -66,54 +125,77 @@ def get_model(parallel):
             custom_policy=OpenMoeForCausalLMPolicy(),
         )
     booster = Booster(plugin=plugin)
-    model, _, _, _, _ = booster.boost(model=model)
-    return model, booster
+    model, optim, _, _, _ = booster.boost(model=model, optimizer=optim)
+    return model, booster, optim

-def _test_moe_checkpoint(parallel, shard):
+def _test_moe_checkpoint(rank, parallel):
     if parallel == None:
         MOE_MANAGER.setup(
-            seed=42,
             parallel=None,
         )
-    elif parallel == "zero2_ep":
+    elif parallel == "ep":
+        MOE_MANAGER.setup(
+            parallel="EP",
+        )
+    elif parallel == "ep_zero":
         MOE_MANAGER.setup(
-            seed=42,
             parallel="EP",
+            max_ep_size=2,
         )
     elif parallel == "hybrid":
         MOE_MANAGER.setup(
-            seed=42,
             parallel="EP",
             mode="fixed",
             fixed_dp_size=1,
             fixed_ep_size=2,
             fixed_pp_size=2,
         )
-    model1, booster1 = get_model(parallel)
-    model2, booster2 = get_model(parallel)
-
-    if shard:
-        booster1.save_model(model1, "./tmp_ckpt", shard=True, size_per_shard=1)
-        booster2.load_model(model2, "./tmp_ckpt")
-    else:
-        booster1.save_model(model1, "tmp_ckpt.pth")
-        booster2.load_model(model2, "tmp_ckpt.pth")
-
-    state1 = model1.state_dict()
-    state2 = model2.state_dict()
-    for k, v in state1.items():
-        u = state2.get(k)
-        assert torch.equal(u.data, v.data)
+    model1, booster1, optim1 = get_model(parallel)
+    model2, booster2, optim2 = get_model(parallel)
+    model3, booster3, optim3 = get_model(parallel)
+
+    # param ckpt
+    # shard
+    booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1)
+    booster2.load_model(model2, "./tmp_ckpt1")
+    # unshard
+    booster1.save_model(model1, "./tmp_ckpt1.pth")
+    booster3.load_model(model3, "./tmp_ckpt1.pth")
+    # check
+    check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
+    check_state_dict_equal(model1.state_dict(), model3.state_dict(), False)
+
+    # optim ckpt
+    criterion = lambda x: x.mean()
+    data = torch.randint(0, 4, (2, 4)).cuda()
+    label = torch.randint(0, 4, (2,)).cuda()
+    if parallel == "hybrid":
+        kwargs = {"pipeline": True, "booster": booster1, "plugin": booster1.plugin}
+    else:
+        kwargs = {}
+    run_fwd_bwd(model1, data, label, criterion, optim1, **kwargs)
+    optim1.step()
+    optim1.zero_grad()
+    # shard
+    booster1.save_optimizer(optim1, "./tmp_ckpt2", shard=True, size_per_shard=1)
+    dist.barrier()
+    booster2.load_optimizer(optim2, "./tmp_ckpt2")
+    # unshard
+    booster1.save_optimizer(optim1, "./tmp_ckpt2.pth")
+    booster3.load_optimizer(optim3, "./tmp_ckpt2.pth")
+    # check
+    check_state_dict_equal(optim1.optim.state_dict(), optim2.optim.state_dict(), False)
+    check_state_dict_equal(optim1.optim.state_dict(), optim3.optim.state_dict(), False)

     if dist.get_rank() == 0:
-        if shard:
-            shutil.rmtree("./tmp_ckpt")
-        else:
-            os.remove("tmp_ckpt.pth")
+        shutil.rmtree("./tmp_ckpt1")
+        shutil.rmtree("./tmp_ckpt2")
+        os.remove("./tmp_ckpt1.pth")
+        os.remove("./tmp_ckpt2.pth")

-def _run_dist(rank, world_size, port, parallel, shard):
+def _run_dist(rank, world_size, port, parallel):
     colossalai.launch(
         config=dict(),
         rank=rank,
@@ -122,17 +204,16 @@ def _run_dist(rank, world_size, port, parallel, shard):
         port=port,
         backend="nccl",
     )
-    _test_moe_checkpoint(parallel, shard)
+    _test_moe_checkpoint(rank, parallel)


 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4])
-@pytest.mark.parametrize("parallel", [None, "zero_ep", "hybrid"])
-@pytest.mark.parametrize("shard", [True, False])
+@pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"])
 @rerun_if_address_is_in_use()
-def test_moe_checkpoint(world_size, parallel, shard):
-    spawn(_run_dist, world_size, parallel=parallel, shard=shard)
+def test_moe_checkpoint(world_size, parallel):
+    spawn(_run_dist, world_size, parallel=parallel)


 if __name__ == "__main__":
-    test_moe_checkpoint(world_size=4, parallel="hybrid", shard=True)
+    test_moe_checkpoint(world_size=4, parallel="hybrid")
...
@@ -14,16 +14,16 @@ from tests.test_moe.moe_utils import MoeGradientHandler, sync_local_from_ep, syn
 def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, seed: int):
     assert batch_size % world_size == 0

-    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

     MOE_MANAGER.__init__()
-    MOE_MANAGER.setup(seed, parallel=None)
+    MOE_MANAGER.setup(parallel=None)
     local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
     MOE_MANAGER.__init__()
-    MOE_MANAGER.setup(seed, parallel="EP")
+    MOE_MANAGER.setup(parallel="EP")
     ep_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
     MOE_MANAGER.__init__()
-    MOE_MANAGER.setup(seed, parallel="TP")
+    MOE_MANAGER.setup(parallel="TP")
     tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
     ep_model = ep_model.to(get_current_device())
     tp_model = tp_model.to(get_current_device())
@@ -44,7 +44,7 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
     torch.cuda.manual_seed(seed)
     tp_data = torch.randn(batch_size, dim, device=get_current_device())
     micro_batch_size = batch_size // world_size
-    ep_data = tp_data.detach()[micro_batch_size * rank:micro_batch_size * (rank + 1)]
+    ep_data = tp_data.detach()[micro_batch_size * rank : micro_batch_size * (rank + 1)]

     out_local = local_model(tp_data)
     MOE_MANAGER.reset_loss()
@@ -52,8 +52,8 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
     MOE_MANAGER.reset_loss()
     out_ep = ep_model(ep_data)
     MOE_MANAGER.reset_loss()
-    assert torch.allclose(out_ep, out_tp[micro_batch_size * rank:micro_batch_size * (rank + 1)])
-    assert torch.allclose(out_ep, out_local[micro_batch_size * rank:micro_batch_size * (rank + 1)])
+    assert torch.allclose(out_ep, out_tp[micro_batch_size * rank : micro_batch_size * (rank + 1)])
+    assert torch.allclose(out_ep, out_local[micro_batch_size * rank : micro_batch_size * (rank + 1)])

     out_local.mean().backward()
     out_tp.mean().backward()
@@ -77,5 +77,5 @@ def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, seed: int):
     spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, seed=seed)


-if __name__ == '__main__':
+if __name__ == "__main__":
     test_moe_ep_tp(num_experts=8, batch_size=8, dim=256, seed=42)
...
@@ -15,7 +15,7 @@ INTERMEDIATE_SIZE = 8

 def run_moe_init(expert_parallel):
     MOE_MANAGER.__init__()
-    MOE_MANAGER.setup(seed=42, parallel=expert_parallel)
+    MOE_MANAGER.setup(parallel=expert_parallel)
     expert_args = dict(
         hidden_size=HIDDEN_SIZE,
         intermediate_size=INTERMEDIATE_SIZE,
...
@@ -35,13 +35,13 @@ def run_zero_optim_test(local_rank, world_size, stage=1):
     label = torch.randint(0, 4, (16,)).cuda()

     MOE_MANAGER.__init__()
-    MOE_MANAGER.setup(seed=42, parallel=None)
+    MOE_MANAGER.setup(parallel=None)
     torch_model = MoeModel()
     torch_optimizer = torch.optim.Adam(torch_model.parameters())
     torch_model = torch_model.cuda()

     MOE_MANAGER.__init__()
-    MOE_MANAGER.setup(seed=42, max_ep_size=2, use_ep_inside=False, parallel="EP")
+    MOE_MANAGER.setup(max_ep_size=2, use_ep_inside=False, parallel="EP")
     zero_model = MoeModel()
     extra_dp_group = MOE_MANAGER.parallel_info_dict[2].dp_group
     ep_rank = dist.get_rank(MOE_MANAGER.parallel_info_dict[2].ep_group)
...
@@ -45,7 +45,6 @@ def run_zero_optim_test(local_rank, world_size, stage=1):
     MOE_MANAGER.__init__()
     MOE_MANAGER.setup(
-        seed=42,
         parallel="EP",
     )
     zero_model = MoeModel(enable_load_balance=True)
@@ -55,7 +54,7 @@ def run_zero_optim_test(local_rank, world_size, stage=1):
     zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)

     MOE_MANAGER.__init__()
-    MOE_MANAGER.setup(seed=42, parallel="EP")
+    MOE_MANAGER.setup(parallel="EP")
     torch_model = MoeModel()
     for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
         torch_param.data.copy_(zero_param.data)
@@ -94,7 +93,7 @@ def run_zero_optim_test(local_rank, world_size, stage=1):
     torch_out = run_fwd_bwd(torch_model, data, label, criterion, None)
     zero_optimizer.step()
     zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
-    assert torch.allclose(zero_out, torch_out), f"zero_out:{zero_out}\ntorch_out{torch_out}"
+    assert torch.allclose(zero_out, torch_out, atol=3e-5), f"zero_out:{zero_out}\ntorch_out{torch_out}"


 def run_hybrid_zero_optim_test(local_rank, world_size, stage=1):
@@ -103,14 +102,13 @@ def run_hybrid_zero_optim_test(local_rank, world_size, stage=1):
     label = torch.randint(0, 4, (16,)).cuda()

     MOE_MANAGER.__init__()
-    MOE_MANAGER.setup(seed=42, parallel=None)
+    MOE_MANAGER.setup(parallel=None)
     torch_model = MoeModel()
     torch_optimizer = torch.optim.Adam(torch_model.parameters())
     torch_model = torch_model.cuda()

     MOE_MANAGER.__init__()
     MOE_MANAGER.setup(
-        seed=42,
         max_ep_size=2,
         use_ep_inside=False,
         parallel="EP",
...
@@ -88,7 +88,7 @@ def run_zero_test(local_rank, world_size, stage=1):

 def run_dist(rank, world_size, port):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
-    MOE_MANAGER.setup(seed=42, parallel="EP")
+    MOE_MANAGER.setup(parallel="EP")
     seed_all(42 + rank)
     run_zero_test(rank, world_size, stage=1)
     run_zero_test(rank, world_size, stage=2)
...
@@ -76,7 +76,7 @@ def run_zero_optim_test(local_rank, world_size, stage=1):

 def run_dist(rank, world_size, port):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
-    MOE_MANAGER.setup(seed=42, parallel="EP")
+    MOE_MANAGER.setup(parallel="EP")
     run_zero_optim_test(rank, world_size, stage=1)
     run_zero_optim_test(rank, world_size, stage=2)
...