Commit 7d8e0338 authored by Xuanlei Zhao, committed by ver217

[moe] init mixtral impl

parent c53ddda8
import logging
import os
from pathlib import Path
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.checkpoint_io import CheckpointIndexFile
from colossalai.checkpoint_io.utils import is_safetensors_available, load_shard_state_dict, load_state_dict_into_model
from colossalai.moe import MoECheckpintIO
from colossalai.tensor.moe_tensor.api import get_dp_rank, get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
class MixtralMoECheckpointIO(MoECheckpintIO):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@torch.no_grad()
def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
"""
Preprocess the state_dict before loading: rename parameters to match the wrapped module and slice MoE expert tensors for the local expert-parallel rank.
"""
model_param_dict = dict(model.named_parameters())
for name, param in list(state_dict.items()):
if ".gate.weight" in name:
new_name = "module." + name.replace(".gate.weight", ".gate_weight")
state_dict[new_name] = state_dict.pop(name)
elif ".experts." in name:
# MoE tensor handling: in our MoE module the experts are concatenated
# into a single tensor, while Mixtral stores each expert separately,
# so the loaded expert is inserted into its slot of the fused tensor.
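# name mapping for reference: Mixtral's per-expert w1 -> fused wi_gate,
# w3 -> wi_up, w2 -> wo; each loaded Linear weight is transposed below,
# since the fused expert tensors appear to store each expert's weight
# in [in_features, out_features] layout.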
# get model param
str_idx = name.index(".experts.")
expert_idx = int(name.split(".")[-3])
if ".w1." in name:
model_param_name = name.replace(name[str_idx:], ".experts.wi_gate")
elif ".w2." in name:
model_param_name = name.replace(name[str_idx:], ".experts.wo")
elif ".w3." in name:
model_param_name = name.replace(name[str_idx:], ".experts.wi_up")
model_param_name = "module." + model_param_name
# skip params that belong to other pipeline stages
if model_param_name not in model_param_dict:
continue
model_param = model_param_dict[model_param_name]
assert is_moe_tensor(model_param)
# get expert range
ep_rank = get_ep_rank(model_param)
ep_size = get_ep_size(model_param)
expert_num = 8 // ep_size
expert_range = list(range(ep_rank * expert_num, (ep_rank + 1) * expert_num))
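# e.g. with the 8 Mixtral experts and ep_size = 4, each rank owns 2 experts:
# rank 0 -> experts [0, 1], rank 1 -> [2, 3], and so on.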
# insert new param
if expert_idx in expert_range:
new_param = model_param
new_param[expert_idx - ep_rank * expert_num] = param.transpose(0, 1)
state_dict[model_param_name] = new_param
state_dict.pop(name)
else:
new_name = "module." + name
state_dict[new_name] = state_dict.pop(name)
dist.barrier()
return state_dict
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
"""
Load sharded model with the given path to index file of checkpoint folder.
Args:
model (nn.Module): The model to be loaded.
checkpoint_index_file (str): Path to the index file of checkpointing folder.
strict (bool, optional): Whether to strictly enforce parameter-name matching when loading the state_dict. Defaults to False.
This argument should be manually set to False, since params on the same device might be stored in different files.
"""
# Check whether the checkpoint uses safetensors.
use_safetensors = False
if "safetensors" in checkpoint_index_file.name:
use_safetensors = True
if use_safetensors and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
strict = False
# Load params & buffers to model.
# Keep a record of loaded files so that a file will not be loaded repeatedly.
loaded_file = set()
def _load(name: str):
if name not in weight_map:
raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
filename = weight_map[name]
# If the file containing this param/buffer has already been loaded, return directly.
if filename in loaded_file:
return
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
state_dict = self.pre_load_model(model, state_dict)
missing_keys = []
load_state_dict_into_model(
model,
state_dict,
missing_keys=missing_keys,
strict=strict,
load_sub_module=True,
)
loaded_file.add(filename)
# Load parameters.
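# Note: the wrapped model uses fused parameter names (wi_gate / wi_up / wo),
# so they are mapped back to Mixtral's per-expert checkpoint names
# (experts.{i}.w1/w3/w2.weight) to locate the shard files to read.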
for name, _ in model.named_parameters():
name = name.replace("module.", "")
name = name.replace(".gate_weight", ".gate.weight")
if ".experts.wi_gate" in name:
for i in range(8):
new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
_load(new_name)
elif ".experts.wi_up" in name:
for i in range(8):
new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
_load(new_name)
elif ".experts.wo" in name:
for i in range(8):
new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
_load(new_name)
else:
_load(name)
if self.verbose:
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
@torch.no_grad()
def pre_save_model(self, model: nn.Module) -> dict:
torch.cuda.empty_cache()
state_dict = model.state_dict()
for name, param in list(model.named_parameters()):
if ".gate_weight" in name:
new_name = name.replace(".gate_weight", ".gate.weight")
state_dict[new_name] = state_dict.pop(name).cpu()
elif ".experts." in name:
ep_group = get_ep_group(param)
ep_rank = get_ep_rank(param)
ep_size = get_ep_size(param)
dp_rank = get_dp_rank(param)
if dp_rank == 0:
param = param.data.cuda()
all_param = [torch.zeros_like(param) for _ in range(ep_size)]
# gather param from every ep rank
dist.all_gather(all_param, param, group=ep_group)
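# after the gather, concatenating along dim 0 restores the full tensor of
# all 8 experts (each EP rank contributed 8 // ep_size of them), which is
# then split back into Mixtral's per-expert names with a transpose.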
if ep_rank == 0:
all_param = torch.cat(all_param, dim=0)
assert all_param.shape[0] == 8
for i in range(8):
if ".wi_gate" in name:
new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
elif ".wi_up" in name:
new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
elif ".wo" in name:
new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
new_name = new_name.replace("module.", "")
new_param = all_param[i].transpose(-1, -2)
state_dict[new_name] = new_param.cpu()
state_dict.pop(name)
else:
state_dict[name] = param.cpu()
for name, param in list(state_dict.items()):
new_name = name.replace("module.", "")
state_dict[new_name] = state_dict.pop(name)
torch.cuda.empty_cache()
if self.pp_size > 1:
if self.dp_rank == 0:
# gather state_dict from every pp rank
# because the checkpoint is large, we split it into chunks (at most 30)
# and gather them one by one
new_state_dict = {}
state_dict_keys = list(state_dict.keys())
gap_key_num = min(30, len(state_dict_keys))
gap_keys = (len(state_dict_keys) + gap_key_num - 1) // gap_key_num
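# gap_keys is the ceiling of keys-per-chunk, so at most gap_key_num (<= 30)
# all_gather_object calls are issued; gathering chunk by chunk keeps the
# gathered objects smaller than collecting the whole state_dict at once.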
for i in range(gap_key_num):
cur_keys = state_dict_keys[i * gap_keys : (i + 1) * gap_keys]
cur_state_dict = {}
for k in cur_keys:
cur_state_dict[k] = state_dict[k]
out = [None for _ in range(self.pp_size)]
dist.all_gather_object(out, cur_state_dict, group=self.pp_group)
if self.pp_rank == 0:
for o in out:
for k, v in o.items():
new_state_dict[k] = v.cpu()
state_dict = new_state_dict
dist.barrier()
return state_dict
import torch
import torch.nn as nn
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralSparseMoeBlock
from colossalai.lazy import LazyInitContext
from colossalai.moe import SparseMLP
class MixtralSparseMLP:
r"""
This is a wrapper that converts a HuggingFace MixtralSparseMoeBlock into ColossalAI's SparseMLP. It is meant to be used only with the from_native_module interface.
"""
def __init__(self) -> None:
raise NotImplementedError(
"FusedLayerNorm is not implemented as a physical class. "
"It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex."
)
@staticmethod
def from_native_module(module: MixtralSparseMoeBlock, enable_kernel: bool) -> nn.Module:
r"""
Convert a native HuggingFace MixtralSparseMoeBlock into a ColossalAI SparseMLP module
with the same configuration, optionally enabling the fused MoE kernel.
Args:
module (MixtralSparseMoeBlock): The native Mixtral MoE block to be converted.
enable_kernel (bool): Whether to enable the optimized MoE kernel.
Returns:
nn.Module: A SparseMLP configured to match the given block (hidden size, FFN size, top-k).
"""
with torch.no_grad():
LazyInitContext.materialize(module)
# get the attributes of the module
moe_kwargs = dict(
num_experts=8,
hidden_size=module.hidden_dim,
intermediate_size=module.ffn_dim,
router_top_k=module.top_k,
router_norm=True,
router_loss=False,
# router_capacity_factor_train=
# router_capacity_factor_eval=
mlp_activation="silu",
mlp_gated=True,
# enable_load_balance=
# load_balance_tolerance=
# load_balance_beam_width=
# load_balance_group_swap_factor=
enable_kernel=enable_kernel,
# enable_comm_overlap=
# enable_hierarchical_comm=
return_gate_logits=True,
)
dtype = module.gate.weight.dtype
device = module.gate.weight.device
sparse_mlp = SparseMLP(**moe_kwargs).to(dtype).to(device)
return sparse_mlp
def replace_moe_layer(model: nn.Module, enable_kernel: bool = False) -> nn.Module:
"""
Recursively replace every MixtralSparseMoeBlock in the model with a ColossalAI SparseMLP.
Args:
model (torch.nn.Module): The model (or submodule) whose MoE layers will be replaced.
enable_kernel (bool): Whether to enable the optimized MoE kernel. Defaults to False.
"""
if isinstance(model, MixtralDecoderLayer):
model.block_sparse_moe = MixtralSparseMLP.from_native_module(
model.block_sparse_moe, enable_kernel=enable_kernel
)
else:
for _, child in model.named_children():
replace_moe_layer(child, enable_kernel)
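# Usage sketch (as in infer.py / train.py below): build the HF model first,
# then swap its MoE blocks in place before boosting, e.g.
#   model = MixtralForCausalLM(config)
#   replace_moe_layer(model, enable_kernel=False)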
import json
import os
from typing import Any, Dict, Tuple, Union
import torch
from huggingface_hub import snapshot_download
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
def move_to_cuda(batch, device):
return {k: v.to(device) for k, v in batch.items()}
@torch.no_grad()
def load_model(ckpt_path: str, model, booster: Booster, optimizer=None):
# local checkpoint with a safetensors index
if os.path.exists(os.path.join(ckpt_path, "model.safetensors.index.json")):
ckpt_path = os.path.join(ckpt_path, "model.safetensors.index.json")
# local checkpoint saved as sharded pytorch bins
elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
# otherwise treat it as a model id and download from the HuggingFace Hub
else:
ckpt_path = snapshot_download(ckpt_path)
booster.load_model(model, ckpt_path)
if optimizer is not None:
optimizer.sync_moe_master_param()
optimizer.update_master_params(model)
def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
"""
Load file in JSON format
"""
with open(file=file_path, mode="r", encoding="utf-8") as fp:
return json.load(fp)
def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
"""
Save as JSON format
"""
with open(file=file_path, mode="w", encoding="utf-8") as fp:
json.dump(data, fp=fp, ensure_ascii=False, indent=4)
def save_checkpoint(
save_dir: Union[str, os.PathLike],
booster: Booster,
model: torch.nn.Module,
optimizer: Optimizer,
lr_scheduler: _LRScheduler,
epoch: int,
step: int,
batch_size: int,
coordinator: DistCoordinator,
) -> None:
"""
Save model checkpoint, optimizer, LR scheduler and intermediate running states.
"""
save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)
booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
running_states = {
"epoch": epoch,
"step": step,
"sample_start_index": step * batch_size,
}
if coordinator.is_master():
save_json(running_states, os.path.join(save_dir, "running_states.json"))
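# Resulting layout: save_dir/epoch-{epoch}_step-{step}/ containing "modeling/"
# (sharded model), "optimizer/" (sharded optimizer states), "lr_scheduler" and
# "running_states.json"; load_checkpoint below expects to be pointed at that
# epoch-step directory.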
def load_checkpoint(
load_dir: Union[str, os.PathLike],
booster: Booster,
model: torch.nn.Module,
optimizer: Optimizer,
lr_scheduler: _LRScheduler,
) -> Tuple[int, int, int]:
"""
Load model checkpoint, optimizer, LR scheduler and intermediate running states.
"""
# Update booster params states.
load_model(os.path.join(load_dir, "modeling"), model, booster, optimizer)
booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
return (
running_states["epoch"],
running_states["step"],
running_states["sample_start_index"],
)
import argparse
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
from colossal_moe.models.mixtral_layer import replace_moe_layer
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from colossal_moe.utils import load_model
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe import MOE_MANAGER
from colossalai.moe.utils import skip_init
from colossalai.utils import get_current_device
def parse_args():
# basic settings
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="mistralai/Mixtral-8x7B-v0.1",
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--plugin",
type=str,
default="hybrid",
choices=["ep"],
help="Parallel methos.",
)
parser.add_argument(
"--output_path",
type=str,
default="./outputs",
help="The path of your saved model after finetuning.",
)
parser.add_argument(
"--precision",
type=str,
default="bf16",
choices=["fp32", "bf16", "fp16"],
help="The mixed precision training.",
)
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
# kernel
parser.add_argument(
"--use_kernel",
action="store_true",
help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
)
parser.add_argument(
"--use_layernorm_kernel",
action="store_true",
help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
)
args = parser.parse_args()
return args
def main():
args = parse_args()
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=args.seed)
coordinator = DistCoordinator()
# Set plugin
booster_kwargs = {}
hybrid_dict = {
"tp_size": 1,
"custom_policy": MixtralForCausalLMPolicy(),
"enable_fused_normalization": args.use_layernorm_kernel,
"enable_jit_fused": args.use_kernel,
"precision": args.precision,
"checkpoint_io": MixtralMoECheckpointIO,
"zero_stage": 1,
}
mgr_dict = {}
if args.plugin == "ep":
dp_size = dist.get_world_size()
plugin = MoeHybridParallelPlugin(
pp_size=1,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
max_ep_size=dp_size,
**mgr_dict,
)
else:
raise ValueError(f"Invalid plugin {args.plugin}")
coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
# Build mixtral model
config = MixtralConfig.from_pretrained(args.model_name)
config.num_local_experts = 1  # don't change this; it does not affect the final model, since the MoE layers are replaced below
with skip_init():
model = MixtralForCausalLM(config)
model.num_experts = 8
model = model.to(torch.bfloat16) if args.precision == "bf16" else model.to(torch.float16)
model = model.to(get_current_device())
coordinator.print_on_master(f"Finish init model with config:\n{config}")
# Replace moe
with skip_init():
replace_moe_layer(model)
model.eval()
coordinator.print_on_master(f"Finish replace moe module")
# Prepare tokenizer and dataloader
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
model, _, _, _, _ = booster.boost(model=model)
coordinator.print_on_master(f"Finish init booster")
# load ckpt
load_model(args.model_name, model, booster)
coordinator.print_on_master(f"Finish load ckpt")
text = ["Hello my name is", "1+1=?"]
tokenizer.pad_token = tokenizer.unk_token
inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device())
outputs = model.module.generate(**inputs, max_new_tokens=20)
outputs = tokenizer.batch_decode(outputs)[0]
print(outputs)
if __name__ == "__main__":
main()
NUM_GPU=2
MODEL="mistralai/Mixtral-8x7B-v0.1"
# ep
torchrun --standalone --nproc_per_node $NUM_GPU infer.py \
--model_name $MODEL \
--plugin "ep" \
colossalai >= 0.3.3
torch >= 1.8.1
transformers == 4.36.0
sentencepiece
datasets
from setuptools import find_packages, setup
def fetch_requirements(path):
with open(path, "r") as fd:
return [r.strip() for r in fd.readlines()]
def fetch_readme():
with open("README.md", encoding="utf-8") as f:
return f.read()
def fetch_version():
with open("version.txt", "r") as f:
return f.read().strip()
setup(
name="colossal_moe",
version=fetch_version(),
packages=find_packages(
exclude=(
"tests",
"benchmarks",
"*.egg-info",
)
),
description="Colossal-AI MoE",
long_description=fetch_readme(),
long_description_content_type="text/markdown",
license="Apache Software License 2.0",
url="https://github.com/hpcaitech",
install_requires=fetch_requirements("requirements.txt"),
python_requires=">=3.6",
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
"Environment :: GPU :: NVIDIA CUDA",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: System :: Distributed Computing",
],
)
import os
import shutil
import pytest
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
from colossal_moe.models.mixtral_layer import replace_moe_layer
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20):
input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device())
attention_mask = torch.ones_like(input_ids)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": input_ids,
}
def run_fwd_bwd(
model, data, label, criterion, optimizer, enable_autocast=False, pipeline=False, booster=None, plugin=None
):
model.train()
if pipeline:
train_dataloader_iter = DummyDataloader(data_gen_fn, length=1)
is_pp_last_stage = booster.plugin.stage_manager.is_last_stage()
y = booster.execute_pipeline(
train_dataloader_iter,
model,
lambda x, y: x.loss,
optimizer,
return_loss=True,
return_outputs=True,
)
# Backward and optimize
if is_pp_last_stage:
loss = y["loss"]
else:
if criterion:
y = model(data).logits
loss = criterion(y)
else:
loss = model(data, label)
loss = loss.float()
if optimizer is not None:
optimizer.backward(loss)
else:
loss.backward()
return y
def get_config():
config = MixtralConfig(
vocab_size=300,
hidden_size=32,
intermediate_size=16,
num_hidden_layers=2,
dropout_rate=0.0,
)
return config
def get_model(parallel):
config = get_config()
model = MixtralForCausalLM(config).to(torch.bfloat16)
replace_moe_layer(model)
optim = torch.optim.Adam(model.parameters())
args = dict(
precision="bf16",
tp_size=1,
zero_stage=1,
custom_policy=MixtralForCausalLMPolicy(),
checkpoint_io=MixtralMoECheckpointIO,
)
if parallel == "ep":
plugin = MoeHybridParallelPlugin(
pp_size=1,
**args,
)
elif parallel == "hybrid":
plugin = MoeHybridParallelPlugin(
pp_size=2,
microbatch_size=1,
**args,
)
booster = Booster(plugin=plugin)
model, optim, _, _, _ = booster.boost(model=model, optimizer=optim)
return model, booster, optim
def _test_moe_checkpoint(parallel):
if dist.get_rank() == 0:
if os.path.exists("./tmp_ckpt1"):
shutil.rmtree("./tmp_ckpt1")
if os.path.exists("./tmp_ckpt2"):
shutil.rmtree("./tmp_ckpt2")
dist.barrier()
if parallel is None:
MOE_MANAGER.setup(
parallel=None,
)
elif parallel == "ep":
MOE_MANAGER.setup(
parallel="EP",
)
elif parallel == "hybrid":
MOE_MANAGER.setup(
parallel="EP",
mode="fixed",
fixed_dp_size=1,
fixed_ep_size=2,
fixed_pp_size=2,
)
model1, booster1, optim1 = get_model(parallel)
model2, booster2, optim2 = get_model(parallel)
# param ckpt
# the two independently initialized models should not have equal state dicts
try:
check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
except AssertionError:
pass
else:
raise AssertionError("state_dict should not be equal")
# shard
booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1)
booster2.load_model(model2, "./tmp_ckpt1")
# check
check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
# optim ckpt
criterion = lambda x: x.mean()
data = torch.randint(0, 4, (2, 4)).cuda()
label = torch.randint(0, 4, (2,)).cuda()
if parallel == "hybrid":
kwargs = {"pipeline": True, "booster": booster1, "plugin": booster1.plugin}
else:
kwargs = {}
run_fwd_bwd(model1, data, label, criterion, optim1, **kwargs)
optim1.step()
optim1.zero_grad()
# shard
booster1.save_optimizer(optim1, "./tmp_ckpt2", shard=True, size_per_shard=1)
dist.barrier()
booster2.load_optimizer(optim2, "./tmp_ckpt2")
# check
check_state_dict_equal(optim1.optim.state_dict(), optim2.optim.state_dict(), False)
if dist.get_rank() == 0:
shutil.rmtree("./tmp_ckpt1")
shutil.rmtree("./tmp_ckpt2")
def _run_dist(rank, world_size, port, parallel):
colossalai.launch(
config=dict(),
rank=rank,
world_size=world_size,
host="localhost",
port=port,
backend="nccl",
)
_test_moe_checkpoint(parallel)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("parallel", ["ep", "hybrid"])
@rerun_if_address_is_in_use()
def test_moe_checkpoint(world_size, parallel):
spawn(_run_dist, world_size, parallel=parallel)
if __name__ == "__main__":
test_moe_checkpoint(world_size=4, parallel="hybrid")
import copy
import torch
from colossal_moe.models.mixtral_layer import MixtralSparseMLP
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
class Config:
def __init__(self, hidden_size, intermediate_size, num_local_experts, num_experts_per_tok, hidden_act):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_local_experts = num_local_experts
self.num_experts_per_tok = num_experts_per_tok
self.hidden_act = hidden_act
def test_moe_layer():
config = Config(hidden_size=4, intermediate_size=8, num_local_experts=32, num_experts_per_tok=2, hidden_act="silu")
mistral_moe = MixtralSparseMoeBlock(config).cuda()
colossal_moe = MixtralSparseMLP.from_native_module(copy.deepcopy(mistral_moe), enable_kernel=False).cuda()
data = torch.randn(2, 8, 4).cuda()
mistral_output = mistral_moe(data)[0]
colossal_output = colossal_moe(data)[0]
assert torch.allclose(
mistral_output, colossal_output
), f"mistral_output: {mistral_output}\ncolossal_output: {colossal_output}"
if __name__ == "__main__":
test_moe_layer()
import argparse
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
from colossal_moe.models.mixtral_layer import replace_moe_layer
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from colossal_moe.utils import load_checkpoint, load_model, move_to_cuda, save_checkpoint
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe import MOE_MANAGER, apply_load_balance
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
@torch.no_grad()
def get_global_loss(loss, booster):
global_loss = loss.clone().detach()
dist.all_reduce(tensor=global_loss, op=dist.ReduceOp.SUM, group=booster.plugin.dp_group)
global_loss.div_(booster.plugin.dp_size)
return global_loss
class RandomDataset(Dataset):
def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 100, tokenizer=None):
self.num_samples = num_samples
self.max_length = max_length
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
self.attention_mask = torch.ones_like(self.input_ids)
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
return {
"input_ids": self.input_ids[idx],
"attention_mask": self.attention_mask[idx],
"labels": self.input_ids[idx],
}
def parse_args():
# basic settings
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="mistralai/Mixtral-8x7B-v0.1",
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
parser.add_argument(
"--plugin",
type=str,
default="hybrid",
choices=["hybrid"],
help="Parallel methods.",
)
parser.add_argument(
"--output_path",
type=str,
default="./outputs",
help="The path of your saved model after finetuning.",
)
parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.")
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="Batch size (per dp group) for the training dataloader.",
)
parser.add_argument(
"--save_interval",
type=int,
default=1000,
help=" The interval (steps) of saving checkpoints.",
)
parser.add_argument(
"--precision",
type=str,
default="bf16",
choices=["fp32", "bf16", "fp16"],
help="The mixed precision training.",
)
parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.")
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
# optim
parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
# lr scheduler
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
# zero stage for all plugins
parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
# hybrid plugin
parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin")
parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin")
parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin")
parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin")
# kernel
parser.add_argument(
"--use_kernel",
action="store_true",
help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
)
parser.add_argument(
"--use_layernorm_kernel",
action="store_true",
help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
)
# load balance
parser.add_argument(
"--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommend to enable."
)
parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.")
# communicate overlap
parser.add_argument(
"--comm_overlap",
action="store_true",
help="Use communication overlap for MoE. Recommended to enable for muiti-node training.",
)
# hierarchical all-to-all
parser.add_argument(
"--hierarchical_alltoall",
action="store_true",
help="Use hierarchical all-to-all for MoE. Recommended to enable for muiti-node training.",
)
args = parser.parse_args()
return args
def main():
args = parse_args()
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=args.seed)
coordinator = DistCoordinator()
# Set plugin
booster_kwargs = {}
hybrid_dict = {
"tp_size": 1,
"custom_policy": MixtralForCausalLMPolicy(),
"enable_fused_normalization": args.use_layernorm_kernel,
"enable_jit_fused": args.use_kernel,
"precision": args.precision,
"zero_stage": args.zero_stage,
"checkpoint_io": MixtralMoECheckpointIO,
}
mgr_dict = {}
if args.plugin == "hybrid":
plugin = MoeHybridParallelPlugin(
pp_size=args.pp_size,
microbatch_size=args.microbatch_size,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
mode="fixed",
fixed_dp_size=args.dp_size,
fixed_ep_size=args.ep_size,
fixed_pp_size=args.pp_size,
**mgr_dict,
)
else:
raise ValueError(f"Invalid plugin {args.plugin}")
coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
# Build Mixtral model
config = MixtralConfig.from_pretrained(args.model_name)
config.use_cache = False
config.num_local_experts = 1
model = MixtralForCausalLM(config)
model.num_experts = 8
model = model.to(torch.bfloat16) if args.precision == "bf16" else model.to(torch.float16)
model = model.to(get_current_device())
replace_moe_layer(model, enable_kernel=args.use_kernel)
coordinator.print_on_master(f"Finish init model with config:\n{config}")
# Enable gradient checkpointing
model.gradient_checkpointing_enable()
# Prepare tokenizer and dataloader
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
dataset = RandomDataset(num_samples=100, tokenizer=tokenizer)
collate_fn = None
dataloader = plugin.prepare_dataloader(
dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
)
# Set optimizer
optimizer = HybridAdam(
model_params=model.parameters(),
lr=args.lr,
betas=(0.9, 0.95),
weight_decay=args.weight_decay,
adamw_mode=True,
)
# Set lr scheduler
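# if --warmup_steps is not given, warm up for roughly 2.5% of the total
# training steps (num_epochs * steps_per_epoch * 0.025)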
lr_scheduler = CosineAnnealingWarmupLR(
optimizer=optimizer,
total_steps=args.num_epochs * len(dataloader),
warmup_steps=args.warmup_steps
if args.warmup_steps is not None
else int(args.num_epochs * len(dataloader) * 0.025),
eta_min=0.1 * args.lr,
)
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
dataloader=dataloader,
)
use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
coordinator.print_on_master(f"Finish init booster")
# Load ckpt
if args.load_checkpoint is None:
load_model(args.model_name, model, booster, optimizer)
coordinator.print_on_master(f"Finish load checkpoint")
else:
load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler)
coordinator.print_on_master(f"Finish load optimizer")
# Start finetuning
coordinator.print_on_master(f"Start finetuning")
for epoch in range(args.num_epoch):
model.train()
train_dataloader_iter = iter(dataloader)
total_len = len(train_dataloader_iter)
with tqdm(
range(total_len),
desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
disable=not is_pp_last_stage if use_pipeline else not coordinator.is_master(),
) as pbar:
for step in pbar:
if use_pipeline:
# Forward pass
outputs = booster.execute_pipeline(
train_dataloader_iter,
model,
lambda x, y: x.loss,
optimizer,
return_loss=True,
return_outputs=True,
)
# Backward and optimize
if is_pp_last_stage:
loss = outputs["loss"]
global_loss = get_global_loss(loss, booster)
if coordinator._local_rank == "0":
pbar.set_postfix({"Loss": global_loss.item()})
else:
# Forward pass
data = next(train_dataloader_iter)
data = move_to_cuda(data, torch.cuda.current_device())
outputs = model(**data)
loss = outputs["loss"]
# Backward
booster.backward(loss, optimizer)
pbar.set_postfix({"loss": loss.item()})
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Apply load balance
if (
args.load_balance
and args.load_balance_interval > 0
and (step + 1) % args.load_balance_interval == 0
):
coordinator.print_on_master(f"Apply load balance")
apply_load_balance(model, optimizer)
# save checkpoint
if (step + 1) % args.save_interval == 0:
coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
save_checkpoint(
args.output_path,
booster,
model,
optimizer,
lr_scheduler,
epoch,
step,
args.batch_size,
coordinator,
)
# save checkpoint at the end of each epoch
booster.save_model(model, args.output_path, shard=True, size_per_shard=5120)
coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
# Finish training
coordinator.print_on_master(f"Finish training")
if __name__ == "__main__":
main()
NUM_GPU=8
MODEL="mistralai/Mixtral-8x7B-v0.1"
SEQ_LENGTH=2048
BATCH_SIZE=1
LR=0.00001
# hybrid
# torchrun --standalone --nproc_per_node $NUM_GPU \
colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile" \
train.py \
--num_epoch 1 \
--model_name $MODEL \
--plugin "hybrid" \
--batch_size $BATCH_SIZE \
--lr $LR \
--zero_stage 1 \
--pp_size 2 \
--dp_size 1 \
--ep_size 8
@@ -181,6 +181,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
overlap_communication: bool = True,
use_ep_inside: bool = True,
custom_policy: Policy = None,
checkpoint_io: Optional[MoECheckpintIO] = None,
) -> None:
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
@@ -200,6 +201,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
self.checkpoint_io = checkpoint_io
# we change pg mesh to (pp, dp, tp) for better moe performance
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size)
@@ -323,7 +325,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
)
def get_checkpoint_io(self) -> MoECheckpintIO:
if self.checkpoint_io is None:
self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
else:
self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
return self.checkpoint_io
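# note: checkpoint_io is passed in as a class (e.g. MixtralMoECheckpointIO above)
# and only instantiated here with the plugin's process groups and zero stage.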
def configure(
......
from .checkpoint import MoECheckpintIO
from .experts import MLPExperts
from .layers import SparseMLP
from .layers import SparseMLP, apply_load_balance
from .manager import MOE_MANAGER
from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
from .utils import NormalNoiseGenerator, UniformNoiseGenerator
@@ -14,4 +15,6 @@ __all__ = [
"UniformNoiseGenerator",
"SparseMLP",
"MoECheckpintIO",
"MOE_MANAGER",
"apply_load_balance",
]
@@ -224,6 +224,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
"""
torch.cuda.empty_cache()
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
@@ -265,6 +266,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
f"index located at {save_index_file}."
)
dist.barrier()
torch.cuda.empty_cache()
# ========================================================
# Abstract methods for optimizer loading/saving implementation
@@ -332,10 +334,12 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
def _get_param_id_from_optimizer_param(
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, optimizer=None
):
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param
return optimizer.param_info["param2id"][id(working_param)]
@@ -347,7 +351,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
master_to_working_map = optimizer.get_master_to_working_map()
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
id_map[param_id] = param
# Read checkpoint index file.
@@ -371,14 +375,10 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change.
updated_groups.append(new_pg)
# ep extra group
if MOE_MANAGER.parallel == "EP":
# ep param group
if len(optimizer.optim.param_groups) > len(saved_groups):
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = optimizer.optim.param_groups[-1][
"params"
] # Only keep the parameters kept by current pipeline stage.
for param in new_pg["params"]:
param.data = param.data.to(torch.float32)
new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({"param_groups": updated_groups})
@@ -389,7 +389,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
for param in pg["params"]:
if param is None:
continue
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
if param_id not in weight_map:
continue
filename = weight_map[param_id]
@@ -400,26 +400,33 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
loaded_file.add(filename)
# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
device = param.device
for pid, state in list(state_dict.items()):
if pid in id_map:
param = id_map[pid]
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
elif (
hasattr(optimizer, "moe_master_to_working_map")
and id(param) in optimizer.moe_master_to_working_map
):
working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.pre_load_optim(
state,
param,
working_param,
current_shape=working_param.shape,
original_shape=original_shape,
device=device,
device="cpu",
inplace=True,
)
optimizer.optim.state[param] = sharded_state
state_dict[pid] = sharded_state
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
loaded_file.add(filename)
sharded_optimizer_loading_epilogue(optimizer.optim)
if self.verbose and self.coordinator.is_master():
@@ -576,6 +583,8 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param
@@ -618,6 +627,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
prefix (str): Prefix of file to save
size_per_shard (int): Max file size of each file shard that store state tensors
"""
torch.cuda.empty_cache()
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@@ -723,6 +733,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}."
)
torch.cuda.empty_cache()
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
"""
......