Commit da39d21b authored by Hongxin Liu's avatar Hongxin Liu Committed by ver217

[moe] support mixtral (#5309)

* [moe] add mixtral block for single expert

* [moe] mixtral block fwd support uneven ep

* [moe] mixtral block bwd support uneven ep

* [moe] add mixtral moe layer

* [moe] simplify replace

* [moe] support save sharded mixtral

* [moe] support load sharded mixtral

* [moe] support save sharded optim

* [moe] integrate moe manager into plugin

* [moe] fix optimizer load

* [moe] fix mixtral layer
parent c904d2ae
-import torch
-import torch.nn as nn
-from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralSparseMoeBlock
-
-from colossalai.lazy import LazyInitContext
-from colossalai.moe import SparseMLP
-
-
-class MixtralSparseMLP:
-    r"""
-    This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
-    """
-
-    def __init__(self) -> None:
-        raise NotImplementedError(
-            "FusedLayerNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex."
-        )
-
-    @staticmethod
-    def from_native_module(module: MixtralSparseMoeBlock, enable_kernel: bool) -> nn.Module:
-        r"""
-        Convert a native pytorch layer norm module to FusedLayerNorm module provided by apex,
-        and optionally marking parameters for gradient aggregation.
-
-        Args:
-            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
-            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
-
-        Returns:
-            nn.Module: Union[FastLayerNorm, FusedLayerNorm].
-
-        Raises:
-            AssertionError: If the provided module is not an instance of nn.LayerNorm.
-        """
-        with torch.no_grad():
-            LazyInitContext.materialize(module)
-            # get the attributes of the module
-            moe_kwargs = dict(
-                num_experts=8,
-                hidden_size=module.hidden_dim,
-                intermediate_size=module.ffn_dim,
-                router_top_k=module.top_k,
-                router_norm=True,
-                router_loss=False,
-                # router_capacity_factor_train=
-                # router_capacity_factor_eval=
-                mlp_activation="silu",
-                mlp_gated=True,
-                # enable_load_balance=
-                # load_balance_tolerance=
-                # load_balance_beam_width=
-                # load_balance_group_swap_factor=
-                enable_kernel=enable_kernel,
-                # enable_comm_overlap=
-                # enable_hierarchical_comm=
-                return_gate_logits=True,
-            )
-            dtype = module.gate.weight.dtype
-            device = module.gate.weight.device
-            sparse_mlp = SparseMLP(**moe_kwargs).to(dtype).to(device)
-
-        return sparse_mlp
-
-
-def replace_moe_layer(model: nn.Module, enable_kernel: bool = False) -> nn.Module:
-    """
-    Reverse the replace layer operation
-
-    Args:
-        module (torch.nn.Module): The object of layer to shard
-    """
-    if isinstance(model, MixtralDecoderLayer):
-        model.block_sparse_moe = MixtralSparseMLP.from_native_module(
-            model.block_sparse_moe, enable_kernel=enable_kernel
-        )
-    else:
-        for _, child in model.named_children():
-            replace_moe_layer(child, enable_kernel)
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+
+from colossalai.lazy import LazyInitContext
+from colossalai.moe import MOE_MANAGER
+from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
+from colossalai.shardformer.shard.utils import set_tensors_to_none
+from colossalai.tensor.moe_tensor.api import set_moe_tensor_info
+
+
+class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
+    def __init__(self, config):
+        super().__init__(config)
+        self.setup_ep()
+
+    def setup_ep(self):
+        _, moe_info = MOE_MANAGER.get_info(self.num_experts)
+        ep_group = moe_info.ep_group
+        self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
+        self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
+        assert self.num_experts % self.ep_size == 0
+        self.ep_group = ep_group
+        self.num_experts_per_ep = self.num_experts // self.ep_size
+        self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
+        held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
+        set_tensors_to_none(self.experts, exclude=set(held_experts))
+        for p in self.experts.parameters():
+            set_moe_tensor_info(p, moe_info)
+
+    @staticmethod
+    def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
+        LazyInitContext.materialize(module)
+        module.__class__ = EPMixtralSparseMoeBlock
+        module.setup_ep()
+        return module
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        selected_experts = selected_experts.t().reshape(-1)
+        selected_experts_idx = selected_experts.argsort()
+        dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
+        input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
+        output_split_sizes = torch.zeros_like(input_split_sizes)
+        dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
+
+        input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
+        output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
+        output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
+        # compute expert output
+        output_states = MoeInGradScaler.apply(output_states, self.ep_size)
+        if output_states.size(0) > 0:
+            if self.num_experts_per_ep == 1:
+                # no need to split
+                expert = self.experts[self.expert_start_idx]
+                output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
+                output_states = expert.w2(output_states)
+            else:
+                output_states_splits = output_states.split(output_split_sizes.tolist())
+                output_states_list = []
+                for i, split_states in enumerate(output_states_splits):
+                    if split_states.size(0) == 0:
+                        continue
+                    expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
+                    split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
+                    split_states = expert.w2(split_states)
+                    output_states_list.append(split_states)
+                output_states = torch.cat(output_states_list)
+        output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
+        dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
+        recover_experts_idx = torch.empty_like(selected_experts_idx)
+        recover_experts_idx[selected_experts_idx] = torch.arange(
+            selected_experts_idx.size(0), device=selected_experts_idx.device
+        )
+        dispatch_states = dispatch_states[recover_experts_idx]
+        k_hidden_states = dispatch_states.chunk(self.top_k)
+        output_states = k_hidden_states[0] * routing_weights[:, 0, None]
+        for i in range(1, self.top_k):
+            output_states += k_hidden_states[i] * routing_weights[:, i, None]
+        output_states = output_states.reshape(batch_size, sequence_length, hidden_dim)
+        return output_states, router_logits
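The split-size bookkeeping in the forward pass above can be checked on the CPU without a process group. The numbers below are made up for illustration (ep_size=2, num_experts=4, so each EP rank holds two consecutive experts); only the view/sum reduction mirrors the code above.

# Hypothetical illustration of how per-expert token counts become per-EP-rank split sizes.
import torch

ep_size, num_experts_per_ep = 2, 2
# e.g. selected_experts.bincount(minlength=num_experts): tokens routed to experts 0..3 on this rank
input_split_sizes = torch.tensor([3, 1, 0, 2])
# EP rank 0 holds experts {0, 1}, EP rank 1 holds experts {2, 3}
input_split_list = input_split_sizes.view(ep_size, num_experts_per_ep).sum(dim=-1).tolist()
print(input_split_list)  # [4, 2]: send 4 rows to EP rank 0 and 2 rows to EP rank 1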
@@ -20,6 +20,8 @@ from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 from colossalai.shardformer.shard import ShardConfig
+from .mixtral_layer import EPMixtralSparseMoeBlock

 __all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]
@@ -51,6 +53,18 @@ class MixtralPolicy(Policy):
         if self.shard_config.enable_tensor_parallelism:
             raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
+        # expert parallel
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="block_sparse_moe",
+                    target_module=EPMixtralSparseMoeBlock,
+                )
+            ],
+            policy=policy,
+            target_key=MixtralDecoderLayer,
+        )
+
         # optimization configuration
         if self.shard_config.enable_fused_normalization:
             self.append_or_create_submodule_replacement(
...
@@ -3,7 +3,6 @@ import os
 from typing import Any, Dict, Tuple, Union

 import torch
-from huggingface_hub import snapshot_download
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.optim.optimizer import Optimizer
@@ -15,23 +14,6 @@ def move_to_cuda(batch, device):
     return {k: v.to(device) for k, v in batch.items()}

-@torch.no_grad()
-def load_model(ckpt_path: str, model, booster: Booster, optimizer=None):
-    # pytorch ckpt
-    if os.path.exists(os.path.join(ckpt_path, "model.safetensors.index.json")):
-        ckpt_path = os.path.join(ckpt_path, "model.safetensors.index.json")
-    # saved ckpt
-    elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
-        ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
-    # download
-    else:
-        ckpt_path = snapshot_download(ckpt_path)
-    booster.load_model(model, ckpt_path)
-    if optimizer is not None:
-        optimizer.sync_moe_master_param()
-        optimizer.update_master_params(model)
-
 def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
     """
     Load file in JSON format
@@ -90,7 +72,7 @@ def load_checkpoint(
     """
     # Update booster params states.
-    load_model(os.path.join(load_dir, "modeling"), model, booster, optimizer)
+    booster.load_model(model, os.path.join(load_dir, "modeling"))
     booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
     booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
...
@@ -2,10 +2,8 @@ import argparse
 import torch
 import torch.distributed as dist
-from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
-from colossal_moe.models.mixtral_layer import replace_moe_layer
+from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
 from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
-from colossal_moe.utils import load_model
 from transformers import AutoTokenizer
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
@@ -13,9 +11,6 @@ import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
-from colossalai.moe import MOE_MANAGER
-from colossalai.moe.utils import skip_init
-from colossalai.utils import get_current_device


 def parse_args():
@@ -30,16 +25,10 @@ def parse_args():
     parser.add_argument(
         "--plugin",
         type=str,
-        default="hybrid",
+        default="ep",
         choices=["ep"],
         help="Parallel methos.",
     )
-    parser.add_argument(
-        "--output_path",
-        type=str,
-        default="./outputs",
-        help="The path of your saved model after finetuning.",
-    )
     parser.add_argument(
         "--precision",
         type=str,
@@ -71,60 +60,38 @@ def main():
     colossalai.launch_from_torch(config={}, seed=args.seed)
     coordinator = DistCoordinator()

+    config = MixtralConfig.from_pretrained(args.model_name)
+    ep_size = min(dist.get_world_size(), config.num_local_experts)
     # Set plugin
-    booster_kwargs = {}
-    hybrid_dict = {
-        "tp_size": 1,
-        "custom_policy": MixtralForCausalLMPolicy(),
-        "enable_fused_normalization": args.use_layernorm_kernel,
-        "enable_jit_fused": args.use_kernel,
-        "precision": args.precision,
-        "checkpoint_io": MixtralMoECheckpointIO,
-        "zero_stage": 1,
-    }
-    mgr_dict = {}
     if args.plugin == "ep":
-        dp_size = dist.get_world_size()
         plugin = MoeHybridParallelPlugin(
-            tp_size=1,
             pp_size=1,
-            **hybrid_dict,
-        )
-        MOE_MANAGER.setup(
-            parallel="EP",
-            max_ep_size=dp_size,
-            **mgr_dict,
+            ep_size=ep_size,
+            zero_stage=1,
+            precision=args.precision,
+            custom_policy=MixtralForCausalLMPolicy(),
+            checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
+            enable_fused_normalization=args.use_layernorm_kernel,
+            enable_jit_fused=args.use_kernel,
         )
     else:
         raise ValueError(f"Invalid plugin {args.plugin}")
     coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")

     # Build mixtral model
-    config = MixtralConfig.from_pretrained(args.model_name)
-    config.num_local_experts = 1  # dont change this. it will not affect model
-    with skip_init():
-        model = MixtralForCausalLM(config)
-    model.num_experts = 8
-    model = model.to(torch.bfloat16) if args.precision == "bf16" else model.to(torch.float16)
-    model = model.to(get_current_device())
-    coordinator.print_on_master(f"Finish init model with config:\n{config}")
-
-    # Replace moe
-    with skip_init():
-        replace_moe_layer(model)
-    model.eval()
-    coordinator.print_on_master(f"Finish replace moe module")
+    model = MixtralForCausalLM.from_pretrained(args.model_name)
+    coordinator.print_on_master(f"Finish load model")

     # Prepare tokenizer and dataloader
     tokenizer = AutoTokenizer.from_pretrained(args.model_name)

     # Set booster
-    booster = Booster(plugin=plugin, **booster_kwargs)
+    booster = Booster(plugin=plugin)
     model, _, _, _, _ = booster.boost(model=model)
     coordinator.print_on_master(f"Finish init booster")

-    # load ckpt
-    load_model(args.model_name, model, booster)
-    coordinator.print_on_master(f"Finish load ckpt")
+    model.eval()

     if coordinator.rank == 0:
         text = ["Hello my name is"]
@@ -132,10 +99,13 @@ def main():
         text = ["What's the largest country in the world?", "How many people live in China?", "帮我续写这首诗:离离原上草"]
     tokenizer.pad_token = tokenizer.unk_token
     inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device())
-    outputs = model.module.generate(**inputs, max_new_tokens=20)
-    outputs = tokenizer.batch_decode(outputs)
+
+    with torch.no_grad():
+        outputs = model.module.generate(**inputs, max_new_tokens=20)
+    outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
     print(f"[{coordinator.rank}] {outputs}")


 if __name__ == "__main__":
     main()
from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_layer import EPMixtralSparseMoeBlock
from torch.testing import assert_close
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import colossalai
from colossalai.moe import MOE_MANAGER
from colossalai.testing.utils import spawn
tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2
def check_mixtral_moe_layer():
torch.cuda.set_device(dist.get_rank())
MOE_MANAGER.setup(
parallel="EP", mode="fixed", fixed_dp_size=1, fixed_ep_size=dist.get_world_size(), fixed_pp_size=1
)
config = MixtralConfig(
hidden_size=hidden_size,
intermediate_size=hidden_size * 2,
num_local_experts=n_experts,
num_experts_per_tok=top_k,
)
torch.manual_seed(0)
orig_model = MixtralSparseMoeBlock(config).cuda()
x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
orig_output, orig_logits = orig_model(x)
model = deepcopy(orig_model)
model = EPMixtralSparseMoeBlock.from_native_module(model)
ep_output, ep_logits = model(x)
assert_close(orig_logits, ep_logits)
assert_close(orig_output, ep_output)
orig_loss = orig_output.mean()
orig_loss.backward()
ep_loss = ep_output.mean()
ep_loss.backward()
assert_close(orig_loss, ep_loss)
name_to_p = {n: p for n, p in orig_model.named_parameters()}
for n, ep_p in model.named_parameters():
p = name_to_p[n]
if ep_p.grad is not None:
assert_close(p.grad, ep_p.grad)
def run_dist(rank: int, world_size: int, port: int):
colossalai.launch({}, rank, world_size, "localhost", port)
check_mixtral_moe_layer()
@pytest.mark.parametrize("world_size", [2, 4])
def test_mixtral_moe_layer(world_size: int):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_mixtral_moe_layer(2)
-import os
-import shutil
-
-import pytest
-import torch
-import torch.distributed as dist
-from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
-from colossal_moe.models.mixtral_layer import replace_moe_layer
-from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
-from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
-
-import colossalai
-from colossalai.booster import Booster
-from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.moe.manager import MOE_MANAGER
-from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
-
-
-def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20):
-    input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device())
-    attention_mask = torch.ones_like(input_ids)
-    return {
-        "input_ids": input_ids,
-        "attention_mask": attention_mask,
-        "labels": input_ids,
-    }
-
-
-def run_fwd_bwd(
-    model, data, label, criterion, optimizer, enable_autocast=False, pipeline=False, booster=None, plugin=None
-):
-    model.train()
-    if pipeline:
-        train_dataloader_iter = DummyDataloader(data_gen_fn, length=1)
-        is_pp_last_stage = booster.plugin.stage_manager.is_last_stage()
-        y = booster.execute_pipeline(
-            train_dataloader_iter,
-            model,
-            lambda x, y: x.loss,
-            optimizer,
-            return_loss=True,
-            return_outputs=True,
-        )
-        # Backward and optimize
-        if is_pp_last_stage:
-            loss = y["loss"]
-    else:
-        if criterion:
-            y = model(data).logits
-            loss = criterion(y)
-        else:
-            loss = model(data, label)
-        loss = loss.float()
-        if optimizer is not None:
-            optimizer.backward(loss)
-        else:
-            loss.backward()
-    return y
-
-
-def get_config():
-    config = MixtralConfig(
-        vocab_size=300,
-        hidden_size=32,
-        intermediate_size=16,
-        num_hidden_layers=2,
-        dropout_rate=0.0,
-    )
-    return config
-
-
-def get_model(parallel):
-    config = get_config()
-    model = MixtralForCausalLM(config).to(torch.bfloat16)
-    replace_moe_layer(model)
-    optim = torch.optim.Adam(model.parameters())
-    args = dict(
-        precision="bf16",
-        tp_size=1,
-        zero_stage=1,
-        custom_policy=MixtralForCausalLMPolicy(),
-        checkpoint_io=MixtralMoECheckpointIO,
-    )
-    if parallel == "ep":
-        plugin = MoeHybridParallelPlugin(
-            pp_size=1,
-            **args,
-        )
-    elif parallel == "hybrid":
-        plugin = MoeHybridParallelPlugin(
-            pp_size=2,
-            microbatch_size=1,
-            **args,
-        )
-    booster = Booster(plugin=plugin)
-    model, optim, _, _, _ = booster.boost(model=model, optimizer=optim)
-    return model, booster, optim
-
-
-def _test_moe_checkpoint(parallel):
-    if dist.get_rank() == 0:
-        if os.path.exists("./tmp_ckpt1"):
-            shutil.rmtree("./tmp_ckpt1")
-        if os.path.exists("./tmp_ckpt2"):
-            shutil.rmtree("./tmp_ckpt2")
-    dist.barrier()
-    if parallel == None:
-        MOE_MANAGER.setup(
-            parallel=None,
-        )
-    elif parallel == "ep":
-        MOE_MANAGER.setup(
-            parallel="EP",
-        )
-    elif parallel == "hybrid":
-        MOE_MANAGER.setup(
-            parallel="EP",
-            mode="fixed",
-            fixed_dp_size=1,
-            fixed_ep_size=2,
-            fixed_pp_size=2,
-        )
-    model1, booster1, optim1 = get_model(parallel)
-    model2, booster2, optim2 = get_model(parallel)
-    # param ckpt
-    # check not equal
-    try:
-        check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
-        raise AssertionError("state_dict should not be equal")
-    except:
-        pass
-    # shard
-    booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1)
-    booster2.load_model(model2, "./tmp_ckpt1")
-    # check
-    check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
-
-    # optim ckpt
-    criterion = lambda x: x.mean()
-    data = torch.randint(0, 4, (2, 4)).cuda()
-    label = torch.randint(0, 4, (2,)).cuda()
-    if parallel == "hybrid":
-        kwargs = {"pipeline": True, "booster": booster1, "plugin": booster1.plugin}
-    else:
-        kwargs = {}
-    run_fwd_bwd(model1, data, label, criterion, optim1, **kwargs)
-    optim1.step()
-    optim1.zero_grad()
-    # shard
-    booster1.save_optimizer(optim1, "./tmp_ckpt2", shard=True, size_per_shard=1)
-    dist.barrier()
-    booster2.load_optimizer(optim2, "./tmp_ckpt2")
-    # check
-    check_state_dict_equal(optim1.optim.state_dict(), optim2.optim.state_dict(), False)
-
-    if dist.get_rank() == 0:
-        shutil.rmtree("./tmp_ckpt1")
-        shutil.rmtree("./tmp_ckpt2")
-
-
-def _run_dist(rank, world_size, port, parallel):
-    colossalai.launch(
-        config=dict(),
-        rank=rank,
-        world_size=world_size,
-        host="localhost",
-        port=port,
-        backend="nccl",
-    )
-    _test_moe_checkpoint(parallel)
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [4])
-@pytest.mark.parametrize("parallel", ["ep", "hybrid"])
-@rerun_if_address_is_in_use()
-def test_moe_checkpoint(world_size, parallel):
-    spawn(_run_dist, world_size, parallel=parallel)
-
-
-if __name__ == "__main__":
-    test_moe_checkpoint(world_size=4, parallel="hybrid")
+from copy import deepcopy
+
+import pytest
+import torch
+import torch.distributed as dist
+from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
+from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
+from torch.optim import Adam
+from transformers.models.mixtral.configuration_mixtral import MixtralConfig
+from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.testing.utils import spawn
+
+tokens, n_experts = 7, 4
+hidden_size = 8
+top_k = 2
+
+
+def check_model_equal(model1, model2):
+    assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
+    for p1, p2 in zip(model1.parameters(), model2.parameters()):
+        assert torch.equal(p1.half(), p2.half())
+
+
+def get_optimizer_snapshot(optim):
+    state = {id(k): deepcopy(v) for k, v in optim.state.items()}
+    param_groups = []
+    for group in optim.param_groups:
+        params = [id(p) for p in group["params"]]
+        new_group = {"params": params}
+        for k, v in group.items():
+            if k != "params":
+                new_group[k] = v
+        param_groups.append(new_group)
+    return {
+        "state": state,
+        "param_groups": param_groups,
+    }
+
+
+def check_optimizer_snapshot_equal(snapshot1, snapshot2):
+    # check param_groups
+    assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"])
+    for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]):
+        assert set(group1.keys()) == set(group2.keys())
+        for k in group1.keys():
+            assert group1[k] == group2[k]
+    # check state
+    assert set(snapshot1["state"].keys()) == set(
+        snapshot2["state"].keys()
+    ), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}"
+    for pid in snapshot1["state"].keys():
+        state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid]
+        assert set(state1.keys()) == set(state2.keys())
+        for k in state1.keys():
+            if isinstance(state1[k], torch.Tensor):
+                assert torch.equal(state1[k], state2[k]), f"{k}, {state1[k]}, {state2[k]}"
+            else:
+                assert state1[k] == state2[k]
+
+
+def check_mixtral_moe_layer():
+    torch.cuda.set_device(dist.get_rank())
+    config = MixtralConfig(
+        hidden_size=hidden_size,
+        intermediate_size=hidden_size * 2,
+        num_local_experts=n_experts,
+        num_experts_per_tok=top_k,
+        num_attention_heads=2,
+        num_key_value_heads=2,
+    )
+    torch.manual_seed(0)
+    input_ids = torch.randint(0, 100, (2, tokens)).cuda()
+    orig_model = MixtralForCausalLM(config).cuda()
+    model = deepcopy(orig_model)
+    optimizer = Adam(model.parameters(), lr=1e-3)
+    plugin = MoeHybridParallelPlugin(
+        tp_size=1,
+        pp_size=2,
+        ep_size=2,
+        custom_policy=MixtralForCausalLMPolicy(),
+        checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
+        microbatch_size=1,
+        zero_stage=1,
+    )
+    booster = Booster(plugin=plugin)
+    model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
+    # initialize grads
+    data_iter = iter(
+        [{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}]
+    )
+    booster.execute_pipeline(
+        data_iter,
+        model,
+        lambda outputs, inputs: outputs.loss,
+        optimizer,
+    )
+
+    # check save model
+    booster.save_model(model, "mixtral_model", shard=True)
+    dist.barrier()
+    if dist.get_rank() == 0:
+        saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda()
+        check_model_equal(orig_model, saved_model)
+        saved_model.save_pretrained("mixtral_hf_model")
+    dist.barrier()
+
+    # check load model
+    new_model = MixtralForCausalLM(config).cuda()
+    new_optimizer = Adam(new_model.parameters(), lr=1e-3)
+    new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
+    booster.load_model(new_model, "mixtral_hf_model")
+    check_model_equal(model, new_model)
+
+    # check save optimizer
+    optimizer.step()
+    snapshot = get_optimizer_snapshot(optimizer.unwrap())
+    booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
+    dist.barrier()
+    # reset optimizer state
+    for state in optimizer.unwrap().state.values():
+        for v in state.values():
+            if isinstance(v, torch.Tensor):
+                v.zero_()
+    booster.load_optimizer(optimizer, "mixtral_optim")
+    loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap())
+    check_optimizer_snapshot_equal(snapshot, loaded_snapshot)
+
+
+def run_dist(rank: int, world_size: int, port: int):
+    colossalai.launch({}, rank, world_size, "localhost", port)
+    check_mixtral_moe_layer()
+
+
+@pytest.mark.parametrize("world_size", [4])
+def test_mixtral_moe_layer(world_size: int):
+    spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+    test_mixtral_moe_layer(4)
import copy
import torch
from colossal_moe.models.mixtral_layer import MixtralSparseMLP
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
class Config:
def __init__(self, hidden_size, intermediate_size, num_local_experts, num_experts_per_tok, hidden_act):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_local_experts = num_local_experts
self.num_experts_per_tok = num_experts_per_tok
self.hidden_act = hidden_act
def test_moe_layer():
config = Config(hidden_size=4, intermediate_size=8, num_local_experts=32, num_experts_per_tok=2, hidden_act="silu")
mistral_moe = MixtralSparseMoeBlock(config).cuda()
colossal_moe = MixtralSparseMLP.from_native_module(copy.deepcopy(mistral_moe)).cuda()
data = torch.randn(2, 8, 4).cuda()
mistral_output = mistral_moe(data)[0]
colossal_output = colossal_moe(data)[0]
assert torch.allclose(
mistral_output, colossal_output
), f"mistral_output: {mistral_output}\ncolossal_output: {colossal_output}"
if __name__ == "__main__":
test_moe_layer()
@@ -2,22 +2,18 @@ import argparse
 import torch
 import torch.distributed as dist
-from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
-from colossal_moe.models.mixtral_layer import replace_moe_layer
+from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
 from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
-from colossal_moe.utils import load_checkpoint, load_model, move_to_cuda, save_checkpoint
+from colossal_moe.utils import load_checkpoint, move_to_cuda, save_checkpoint
 from torch.utils.data import Dataset
 from tqdm import tqdm
 from transformers import AutoTokenizer
-from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
+from transformers.models.mixtral import MixtralForCausalLM

 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
-from colossalai.moe.layers import apply_load_balance
-from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe import MOE_MANAGER, apply_load_balance
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
@@ -153,45 +149,27 @@ def main():
     coordinator = DistCoordinator()

     # Set plugin
-    booster_kwargs = {}
-    hybrid_dict = {
-        "tp_size": 1,
-        "custom_policy": MixtralForCausalLMPolicy(),
-        "enable_fused_normalization": args.use_layernorm_kernel,
-        "enable_jit_fused": args.use_kernel,
-        "precision": args.precision,
-        "zero_stage": args.zero_stage,
-        "checkpoint_io": MixtralMoECheckpointIO,
-    }
-    mgr_dict = {}
     if args.plugin == "hybrid":
         plugin = MoeHybridParallelPlugin(
-            tp_size=1,
             pp_size=args.pp_size,
+            ep_size=args.ep_size,
             microbatch_size=args.microbatch_size,
-            **hybrid_dict,
-        )
-        MOE_MANAGER.setup(
-            parallel="EP",
-            mode="fixed",
-            fixed_dp_size=args.dp_size,
-            fixed_ep_size=args.ep_size,
-            fixed_pp_size=args.pp_size,
-            **mgr_dict,
+            custom_policy=MixtralForCausalLMPolicy(),
+            enable_fused_normalization=args.use_layernorm_kernel,
+            enable_jit_fused=args.use_kernel,
+            precision=args.precision,
+            zero_stage=args.zero_stage,
+            checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
         )
     else:
         raise ValueError(f"Invalid plugin {args.plugin}")
     coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")

     # Build Mixtral model
-    config = MixtralConfig.from_pretrained(args.model_name)
-    config.use_cache = False
-    config.num_local_experts = 1
-    model = MixtralForCausalLM(config)
-    model.num_experts = 8
-    model = model.to(torch.bfloat16) if args.precision == "bf16" else model.to(torch.float16)
-    model = model.to(get_current_device())
-    replace_moe_layer(model, enable_kernel=args.use_kernel)
-    coordinator.print_on_master(f"Finish init model with config:\n{config}")
+    model = MixtralForCausalLM.from_pretrained(args.model_name)
+    coordinator.print_on_master(f"Finish init model")

     # Enable gradient checkpointing
     model.gradient_checkpointing_enable()
@@ -224,7 +202,7 @@ def main():
     )

     # Set booster
-    booster = Booster(plugin=plugin, **booster_kwargs)
+    booster = Booster(plugin=plugin)
     model, optimizer, _, dataloader, lr_scheduler = booster.boost(
         model=model,
         optimizer=optimizer,
@@ -236,10 +214,7 @@ def main():
     coordinator.print_on_master(f"Finish init booster")

     # Load ckpt
-    if args.load_checkpoint is None:
-        load_model(args.model_name, model, booster, optimizer)
-        coordinator.print_on_master(f"Finish load checkpoint")
-    else:
+    if args.load_checkpoint is not None:
         load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler)
         coordinator.print_on_master(f"Finish load optimizer")
@@ -286,13 +261,13 @@ def main():
             optimizer.zero_grad()

             # Apply load balance
-            if (
-                args.load_balance
-                and args.load_balance_interval > 0
-                and (step + 1) % args.load_balance_interval == 0
-            ):
-                coordinator.print_on_master(f"Apply load balance")
-                apply_load_balance(model, optimizer)
+            # if (
+            #     args.load_balance
+            #     and args.load_balance_interval > 0
+            #     and (step + 1) % args.load_balance_interval == 0
+            # ):
+            #     coordinator.print_on_master(f"Apply load balance")
+            #     apply_load_balance(model, optimizer)
             # save ckeckpoint
             if (step + 1) % args.save_interval == 0:
                 coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
...
@@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
 )
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.moe import MoECheckpintIO
+from colossalai.moe import MOE_MANAGER, MoECheckpintIO
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig
@@ -150,6 +150,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self,
         tp_size: int,
         pp_size: int,
+        ep_size: int,
         extra_dp_size: int = 1,
         precision: str = "fp16",
         zero_stage: int = 0,
@@ -189,10 +190,26 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         if enable_sequence_parallelism:
             assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
-        assert (
-            dist.get_world_size() % (tp_size * pp_size) == 0
-        ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
+        assert (
+            dist.get_world_size() % (tp_size * pp_size * ep_size) == 0
+        ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
+        self.real_dp_size = dist.get_world_size() // (tp_size * pp_size * ep_size)
+        MOE_MANAGER.setup(
+            parallel="EP",
+            mode="fixed",
+            fixed_dp_size=self.real_dp_size,
+            fixed_ep_size=ep_size,
+            fixed_pp_size=pp_size,
+            use_ep_inside=use_ep_inside,
+        )
         self.tp_size = tp_size
         self.pp_size = pp_size
         self.dp_size = dist.get_world_size() // (tp_size * pp_size)
+        self.ep_size = ep_size
+        self.moe_info = MOE_MANAGER.get_info(0)[1]
         self.precision = precision
         self.zero_stage = zero_stage
         self.cpu_offload = cpu_offload
...
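The group-size arithmetic introduced here is easy to sanity-check in isolation; the sizes below are assumed for illustration and are not taken from this commit.

# Assumed sizes: 8 GPUs, no tensor parallelism, 2 pipeline stages, 2-way expert parallelism.
world_size, tp_size, pp_size, ep_size = 8, 1, 2, 2
assert world_size % (tp_size * pp_size * ep_size) == 0
real_dp_size = world_size // (tp_size * pp_size * ep_size)
print(real_dp_size)  # 2: each expert shard is replicated across 2 data-parallel ranks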
@@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from colossalai.interface import ModelWrapper

-from .utils import has_index_file
+from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file

 __all__ = ["CheckpointIO"]
@@ -90,7 +90,15 @@ class CheckpointIO(ABC):
         if index_file_exists:
             self.load_sharded_model(model, index_file_path, strict)
         else:
-            self.load_unsharded_model(model, checkpoint, strict)
+            path = Path(checkpoint, SAFE_WEIGHTS_NAME)
+            if path.is_file():
+                self.load_unsharded_model(model, str(path), strict)
+            else:
+                path = Path(checkpoint, WEIGHTS_NAME)
+                if path.is_file():
+                    self.load_unsharded_model(model, str(path), strict)
+                else:
+                    self.load_unsharded_model(model, checkpoint, strict)

         return origin_model
...
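For reference, the fallback order added above (index file, then a single safetensors file, then a single PyTorch bin file, then the raw path) can be sketched as a standalone helper; the concrete file names are an assumption based on the usual values of SAFE_WEIGHTS_NAME and WEIGHTS_NAME.

# Hypothetical helper mirroring the fallback order in CheckpointIO.load_model; file names are assumptions.
from pathlib import Path

def resolve_unsharded_checkpoint(checkpoint: str) -> str:
    for name in ("model.safetensors", "pytorch_model.bin"):
        candidate = Path(checkpoint, name)
        if candidate.is_file():
            return str(candidate)
    return checkpoint  # let load_unsharded_model handle the raw path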
-from typing import Any, Optional, Tuple
+from typing import Any, List, Optional, Tuple

 import torch
 import torch.distributed as dist
@@ -329,3 +329,68 @@ class MoeOutGradScaler(torch.autograd.Function):
         if ctx.ep_size != 1:
             grad = grad / ctx.ep_size
         return grad, None
def _all_to_all(
inputs: torch.Tensor,
input_split_sizes: Optional[List[int]] = None,
output_split_sizes: Optional[List[int]] = None,
group=None,
async_op: bool = False,
):
"""
Returns:
outputs: Tensor
handle: Optional[Work], if overlap is True
"""
outputs_shape = list(inputs.shape)
if output_split_sizes is not None:
outputs_shape[0] = sum(output_split_sizes)
outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
inputs = inputs.contiguous()
outputs = outputs.contiguous()
handle = dist.all_to_all_single(
outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
)
return outputs, handle
class AllToAllUneven(torch.autograd.Function):
@staticmethod
def forward(
ctx,
inputs,
input_split_sizes=None,
output_split_sizes=None,
group=None,
overlap: bool = False,
):
"""
Returns:
outputs: Tensor
handle: Optional[Work], if overlap is True
"""
ctx.input_split_sizes = input_split_sizes
ctx.output_split_sizes = output_split_sizes
ctx.group = group
return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap)
@staticmethod
def backward(ctx: Any, *grad_outputs):
return (
_all_to_all(grad_outputs[0], ctx.output_split_sizes, ctx.input_split_sizes, ctx.group, False)[0],
None,
None,
None,
None,
)
def all_to_all_uneven(
inputs: torch.Tensor,
input_split_sizes: Optional[List[int]] = None,
output_split_sizes: Optional[List[int]] = None,
group=None,
overlap: bool = False,
):
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
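A minimal usage sketch of all_to_all_uneven, assuming two GPUs launched with torchrun and an NCCL process group; the split sizes are made up, but rank i's output_split_sizes[j] must equal rank j's input_split_sizes[i].

# Minimal sketch (assumes torchrun --nproc_per_node=2 with 2 GPUs); not part of this commit.
import torch
import torch.distributed as dist
from colossalai.moe._operation import all_to_all_uneven

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)
input_split = [1, 3] if rank == 0 else [2, 2]   # rows sent to rank 0 / rank 1
output_split = [1, 2] if rank == 0 else [3, 2]  # rows received from rank 0 / rank 1
x = torch.full((sum(input_split), 4), float(rank), device="cuda")
out, _ = all_to_all_uneven(x, input_split, output_split)
print(rank, out.shape)  # rank 0: torch.Size([3, 4]), rank 1: torch.Size([5, 4])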
@@ -26,3 +26,5 @@ class MoeParallelInfo:
         self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group)
         self.dp_group = self.pg.get_group_along_axis(self.dp_axis)
         self.dp_group_ranks = self.pg.get_ranks_in_group(self.dp_group)
self.ep_rank = self.pg.coordinate(self.ep_axis)
self.dp_rank = self.pg.coordinate(self.dp_axis)
@@ -666,10 +666,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]

-    def sync_moe_master_param(self):
-        for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
-            master_moe_param.data = working_moe_param.data.clone().to(torch.float32).detach()
-
     def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
         r"""
         Compute and return the gradient norm for gradient clipping.
@@ -915,9 +911,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                     master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
                 else:
                     master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
+        for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+            master_moe_param.copy_(working_moe_param)

     def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
         return self._param_store.working_to_master_param

     def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
-        return self._param_store.master_to_working_param
+        return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}