Unverified Commit efef43b5 authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

Merge pull request #5372 from hpcaitech/exp/mixtral

parents 4c03347f 06db94fb
import torch
import torch.distributed as dist
import torch.nn.functional as F
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from colossalai.lazy import LazyInitContext
from colossalai.moe import MOE_MANAGER
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_info
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
def __init__(self, config):
super().__init__(config)
self.setup_ep()
def setup_ep(self):
_, moe_info = MOE_MANAGER.get_info(self.num_experts)
ep_group = moe_info.ep_group
self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
assert self.num_experts % self.ep_size == 0
self.ep_group = ep_group
self.num_experts_per_ep = self.num_experts // self.ep_size
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
set_tensors_to_none(self.experts, exclude=set(held_experts))
for p in self.experts.parameters():
set_moe_tensor_info(p, moe_info)
@staticmethod
def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
LazyInitContext.materialize(module)
module.__class__ = EPMixtralSparseMoeBlock
module.setup_ep()
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
selected_experts = selected_experts.t().reshape(-1)
selected_experts_idx = selected_experts.argsort()
dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
output_split_sizes = torch.zeros_like(input_split_sizes)
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
# compute expert output
output_states = MoeInGradScaler.apply(output_states, self.ep_size)
if output_states.size(0) > 0:
if self.num_experts_per_ep == 1:
# no need to split
expert = self.experts[self.expert_start_idx]
output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
output_states = expert.w2(output_states)
else:
output_states_splits = output_states.split(output_split_sizes.tolist())
output_states_list = []
for i, split_states in enumerate(output_states_splits):
if split_states.size(0) == 0:
continue
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
split_states = expert.w2(split_states)
output_states_list.append(split_states)
output_states = torch.cat(output_states_list)
output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
recover_experts_idx = torch.empty_like(selected_experts_idx)
recover_experts_idx[selected_experts_idx] = torch.arange(
selected_experts_idx.size(0), device=selected_experts_idx.device
)
dispatch_states = dispatch_states[recover_experts_idx]
k_hidden_states = dispatch_states.chunk(self.top_k)
output_states = k_hidden_states[0] * routing_weights[:, 0, None]
for i in range(1, self.top_k):
output_states += k_hidden_states[i] * routing_weights[:, i, None]
output_states = output_states.reshape(batch_size, sequence_length, hidden_dim)
return output_states, router_logits
This diff is collapsed.
import json
import os
from typing import Any, Dict, Tuple, Union
import torch
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
def move_to_cuda(batch, device):
return {k: v.to(device) for k, v in batch.items()}
def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
"""
Load file in JSON format
"""
with open(file=file_path, mode="r", encoding="utf-8") as fp:
return json.load(fp)
def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
"""
Save as JSON format
"""
with open(file=file_path, mode="w", encoding="utf-8") as fp:
json.dump(data, fp=fp, ensure_ascii=False, indent=4)
def save_checkpoint(
save_dir: Union[str, os.PathLike],
booster: Booster,
model: torch.nn.Module,
optimizer: Optimizer,
lr_scheduler: _LRScheduler,
epoch: int,
step: int,
batch_size: int,
coordinator: DistCoordinator,
) -> None:
"""
Save model checkpoint, optimizer, LR scheduler and intermedidate running states.
"""
save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)
booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
running_states = {
"epoch": epoch,
"step": step,
"sample_start_index": step * batch_size,
}
if coordinator.is_master():
save_json(running_states, os.path.join(save_dir, "running_states.json"))
def load_checkpoint(
load_dir: Union[str, os.PathLike],
booster: Booster,
model: torch.nn.Module,
optimizer: Optimizer,
lr_scheduler: _LRScheduler,
) -> Tuple[int, int, int]:
"""
Load model checkpoint, optimizer, LR scheduler and intermedidate running states.
"""
# Update booster params states.
booster.load_model(model, os.path.join(load_dir, "modeling"))
booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
return (
running_states["epoch"],
running_states["step"],
running_states["sample_start_index"],
)
import argparse
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
def parse_args():
# basic settings
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="mistralai/Mixtral-8x7B-v0.1",
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--plugin",
type=str,
default="ep",
choices=["ep"],
help="Parallel methos.",
)
parser.add_argument(
"--precision",
type=str,
default="bf16",
choices=["fp32", "bf16", "fp16"],
help="The mixed precision training.",
)
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
# kernel
parser.add_argument(
"--use_kernel",
action="store_true",
help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
)
parser.add_argument(
"--use_layernorm_kernel",
action="store_true",
help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
)
args = parser.parse_args()
return args
def main():
args = parse_args()
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=args.seed)
coordinator = DistCoordinator()
config = MixtralConfig.from_pretrained(args.model_name)
ep_size = min(dist.get_world_size(), config.num_local_experts)
# Set plugin
if args.plugin == "ep":
plugin = MoeHybridParallelPlugin(
tp_size=1,
pp_size=1,
ep_size=ep_size,
zero_stage=1,
precision=args.precision,
custom_policy=MixtralForCausalLMPolicy(),
checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
enable_fused_normalization=args.use_layernorm_kernel,
enable_jit_fused=args.use_kernel,
)
else:
raise ValueError(f"Invalid plugin {args.plugin}")
coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
# Build mixtral model
model = MixtralForCausalLM.from_pretrained(args.model_name)
coordinator.print_on_master(f"Finish load model")
# Prepare tokenizer and dataloader
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
# Set booster
booster = Booster(plugin=plugin)
model, _, _, _, _ = booster.boost(model=model)
coordinator.print_on_master(f"Finish init booster")
model.eval()
if coordinator.rank == 0:
text = ["Hello my name is"]
else:
text = ["What's the largest country in the world?", "How many people live in China?", "帮我续写这首诗:离离原上草"]
tokenizer.pad_token = tokenizer.unk_token
inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device())
with torch.no_grad():
outputs = model.module.generate(**inputs, max_new_tokens=20)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(f"[{coordinator.rank}] {outputs}")
if __name__ == "__main__":
main()
NUM_GPU=2
MODEL="mistralai/Mixtral-8x7B-v0.1"
# ep
torchrun --standalone --nproc_per_node $NUM_GPU infer.py \
--model_name $MODEL \
--plugin "ep" \
colossalai >= 0.3.3
torch >= 1.8.1
transformers == 4.36.0
sentencepiece
datasets
from setuptools import find_packages, setup
def fetch_requirements(path):
with open(path, "r") as fd:
return [r.strip() for r in fd.readlines()]
def fetch_readme():
with open("README.md", encoding="utf-8") as f:
return f.read()
def fetch_version():
with open("version.txt", "r") as f:
return f.read().strip()
setup(
name="colossal_moe",
version=fetch_version(),
packages=find_packages(
exclude=(
"tests",
"benchmarks",
"*.egg-info",
)
),
description="Colossal-AI MoE",
long_description=fetch_readme(),
long_description_content_type="text/markdown",
license="Apache Software License 2.0",
url="https://github.com/hpcaitech",
install_requires=fetch_requirements("requirements.txt"),
python_requires=">=3.6",
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
"Environment :: GPU :: NVIDIA CUDA",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: System :: Distributed Computing",
],
)
from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_layer import EPMixtralSparseMoeBlock
from torch.testing import assert_close
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import colossalai
from colossalai.moe import MOE_MANAGER
from colossalai.testing.utils import spawn
tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2
def check_mixtral_moe_layer():
torch.cuda.set_device(dist.get_rank())
MOE_MANAGER.setup(
parallel="EP", mode="fixed", fixed_dp_size=1, fixed_ep_size=dist.get_world_size(), fixed_pp_size=1
)
config = MixtralConfig(
hidden_size=hidden_size,
intermediate_size=hidden_size * 2,
num_local_experts=n_experts,
num_experts_per_tok=top_k,
)
torch.manual_seed(0)
orig_model = MixtralSparseMoeBlock(config).cuda()
x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
orig_output, orig_logits = orig_model(x)
model = deepcopy(orig_model)
model = EPMixtralSparseMoeBlock.from_native_module(model)
ep_output, ep_logits = model(x)
assert_close(orig_logits, ep_logits)
assert_close(orig_output, ep_output)
orig_loss = orig_output.mean()
orig_loss.backward()
ep_loss = ep_output.mean()
ep_loss.backward()
assert_close(orig_loss, ep_loss)
name_to_p = {n: p for n, p in orig_model.named_parameters()}
for n, ep_p in model.named_parameters():
p = name_to_p[n]
if ep_p.grad is not None:
assert_close(p.grad, ep_p.grad)
def run_dist(rank: int, world_size: int, port: int):
colossalai.launch({}, rank, world_size, "localhost", port)
check_mixtral_moe_layer()
@pytest.mark.parametrize("world_size", [2, 4])
def test_mixtral_moe_layer(world_size: int):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_mixtral_moe_layer(2)
from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from torch.optim import Adam
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing.utils import spawn
tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2
def check_model_equal(model1, model2):
assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
for p1, p2 in zip(model1.parameters(), model2.parameters()):
assert torch.equal(p1.half(), p2.half())
def get_optimizer_snapshot(optim):
state = {id(k): deepcopy(v) for k, v in optim.state.items()}
param_groups = []
for group in optim.param_groups:
params = [id(p) for p in group["params"]]
new_group = {"params": params}
for k, v in group.items():
if k != "params":
new_group[k] = v
param_groups.append(new_group)
return {
"state": state,
"param_groups": param_groups,
}
def check_optimizer_snapshot_equal(snapshot1, snapshot2):
# check param_groups
assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"])
for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]):
assert set(group1.keys()) == set(group2.keys())
for k in group1.keys():
assert group1[k] == group2[k]
# check state
assert set(snapshot1["state"].keys()) == set(
snapshot2["state"].keys()
), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}"
for pid in snapshot1["state"].keys():
state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid]
assert set(state1.keys()) == set(state2.keys())
for k in state1.keys():
if isinstance(state1[k], torch.Tensor):
assert torch.equal(state1[k], state2[k]), f"{k}, {state1[k]}, {state2[k]}"
else:
assert state1[k] == state2[k]
def check_mixtral_moe_layer():
torch.cuda.set_device(dist.get_rank())
config = MixtralConfig(
hidden_size=hidden_size,
intermediate_size=hidden_size * 2,
num_local_experts=n_experts,
num_experts_per_tok=top_k,
num_attention_heads=2,
num_key_value_heads=2,
)
torch.manual_seed(0)
input_ids = torch.randint(0, 100, (2, tokens)).cuda()
orig_model = MixtralForCausalLM(config).cuda()
model = deepcopy(orig_model)
optimizer = Adam(model.parameters(), lr=1e-3)
plugin = MoeHybridParallelPlugin(
tp_size=1,
pp_size=2,
ep_size=2,
custom_policy=MixtralForCausalLMPolicy(),
checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
microbatch_size=1,
zero_stage=1,
)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
# initialize grads
data_iter = iter(
[{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}]
)
booster.execute_pipeline(
data_iter,
model,
lambda outputs, inputs: outputs.loss,
optimizer,
)
# check save model
booster.save_model(model, "mixtral_model", shard=True)
dist.barrier()
if dist.get_rank() == 0:
saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda()
check_model_equal(orig_model, saved_model)
saved_model.save_pretrained("mixtral_hf_model")
dist.barrier()
# check load model
new_model = MixtralForCausalLM(config).cuda()
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
booster.load_model(new_model, "mixtral_hf_model")
check_model_equal(model, new_model)
# check save optimizer
optimizer.step()
for group in optimizer.param_groups:
group["lr"] = 0.1
snapshot = get_optimizer_snapshot(optimizer.unwrap())
booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
dist.barrier()
# reset optimizer state
for state in optimizer.unwrap().state.values():
for v in state.values():
if isinstance(v, torch.Tensor):
v.zero_()
booster.load_optimizer(optimizer, "mixtral_optim")
loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap())
check_optimizer_snapshot_equal(snapshot, loaded_snapshot)
def run_dist(rank: int, world_size: int, port: int):
colossalai.launch({}, rank, world_size, "localhost", port)
check_mixtral_moe_layer()
@pytest.mark.parametrize("world_size", [4])
def test_mixtral_moe_layer(world_size: int):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_mixtral_moe_layer(4)
import argparse
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from colossal_moe.utils import load_checkpoint, move_to_cuda, save_checkpoint
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralForCausalLM
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
@torch.no_grad()
def get_global_loss(loss, booster):
global_loss = loss.clone().detach()
dist.all_reduce(tensor=global_loss, op=dist.ReduceOp.SUM, group=booster.plugin.dp_group)
global_loss.div_(booster.plugin.dp_size)
return global_loss
class RandomDataset(Dataset):
def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 100, tokenizer=None):
self.num_samples = num_samples
self.max_length = max_length
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
self.attention_mask = torch.ones_like(self.input_ids)
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
return {
"input_ids": self.input_ids[idx],
"attention_mask": self.attention_mask[idx],
"labels": self.input_ids[idx],
}
def parse_args():
# basic settings
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="mistralai/Mixtral-8x7B-v0.1",
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
parser.add_argument(
"--plugin",
type=str,
default="hybrid",
choices=["hybrid"],
help="Parallel methods.",
)
parser.add_argument(
"--output_path",
type=str,
default="./outputs",
help="The path of your saved model after finetuning.",
)
parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.")
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="Batch size (per dp group) for the training dataloader.",
)
parser.add_argument(
"--save_interval",
type=int,
default=1000,
help=" The interval (steps) of saving checkpoints.",
)
parser.add_argument(
"--precision",
type=str,
default="bf16",
choices=["fp32", "bf16", "fp16"],
help="The mixed precision training.",
)
parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.")
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
# optim
parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
# lr scheduler
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
# zero stage for all plugins
parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
# hybrid plugin
parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin")
parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin")
parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin")
parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin")
# kernel
parser.add_argument(
"--use_kernel",
action="store_true",
help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
)
parser.add_argument(
"--use_layernorm_kernel",
action="store_true",
help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
)
# load balance
parser.add_argument(
"--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommend to enable."
)
parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.")
# communicate overlap
parser.add_argument(
"--comm_overlap",
action="store_true",
help="Use communication overlap for MoE. Recommended to enable for muiti-node training.",
)
# hierarchical all-to-all
parser.add_argument(
"--hierarchical_alltoall",
action="store_true",
help="Use hierarchical all-to-all for MoE. Recommended to enable for muiti-node training.",
)
args = parser.parse_args()
return args
def main():
args = parse_args()
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=args.seed)
coordinator = DistCoordinator()
# Set plugin
if args.plugin == "hybrid":
plugin = MoeHybridParallelPlugin(
tp_size=1,
pp_size=args.pp_size,
ep_size=args.ep_size,
microbatch_size=args.microbatch_size,
custom_policy=MixtralForCausalLMPolicy(),
enable_fused_normalization=args.use_layernorm_kernel,
enable_jit_fused=args.use_kernel,
precision=args.precision,
zero_stage=args.zero_stage,
checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
)
else:
raise ValueError(f"Invalid plugin {args.plugin}")
coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
# Build Mixtral model
model = MixtralForCausalLM.from_pretrained(args.model_name)
coordinator.print_on_master(f"Finish init model")
# Enable gradient checkpointing
model.gradient_checkpointing_enable()
# Prepare tokenizer and dataloader
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
dataset = RandomDataset(num_samples=100, tokenizer=tokenizer)
collate_fn = None
dataloader = plugin.prepare_dataloader(
dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
)
# Set optimizer
optimizer = HybridAdam(
model_params=model.parameters(),
lr=args.lr,
betas=(0.9, 0.95),
weight_decay=args.weight_decay,
adamw_mode=True,
)
# Set lr scheduler
lr_scheduler = CosineAnnealingWarmupLR(
optimizer=optimizer,
total_steps=args.num_epochs * len(dataloader),
warmup_steps=args.warmup_steps
if args.warmup_steps is not None
else int(args.num_epochs * len(dataloader) * 0.025),
eta_min=0.1 * args.lr,
)
# Set booster
booster = Booster(plugin=plugin)
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
dataloader=dataloader,
)
use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
coordinator.print_on_master(f"Finish init booster")
# Load ckpt
if args.load_checkpoint is not None:
load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler)
coordinator.print_on_master(f"Finish load optimizer")
# Start finetuning
coordinator.print_on_master(f"Start finetuning")
for epoch in range(args.num_epoch):
model.train()
train_dataloader_iter = iter(dataloader)
total_len = len(train_dataloader_iter)
with tqdm(
range(total_len),
desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
disable=not coordinator.is_master() if use_pipeline == False else not is_pp_last_stage,
) as pbar:
for step in pbar:
if use_pipeline:
# Forward pass
outputs = booster.execute_pipeline(
train_dataloader_iter,
model,
lambda x, y: x.loss,
optimizer,
return_loss=True,
return_outputs=True,
)
# Backward and optimize
if is_pp_last_stage:
loss = outputs["loss"]
global_loss = get_global_loss(loss, booster)
if coordinator._local_rank == "0":
pbar.set_postfix({"Loss": global_loss.item()})
else:
# Forward pass
data = next(train_dataloader_iter)
data = move_to_cuda(data, torch.cuda.current_device())
outputs = model(**data)
loss = outputs["loss"]
# Backward
booster.backward(loss, optimizer)
pbar.set_postfix({"loss": loss.item()})
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Apply load balance
# if (
# args.load_balance
# and args.load_balance_interval > 0
# and (step + 1) % args.load_balance_interval == 0
# ):
# coordinator.print_on_master(f"Apply load balance")
# apply_load_balance(model, optimizer)
# save ckeckpoint
if (step + 1) % args.save_interval == 0:
coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
save_checkpoint(
args.output_path,
booster,
model,
optimizer,
lr_scheduler,
epoch,
step,
args.batch_size,
coordinator,
)
# save checkpoint at the end of each epochs
booster.save_model(model, args.output_path, shard=True, size_per_shard=5120)
coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
# Finish training
coordinator.print_on_master(f"Finish training")
if __name__ == "__main__":
main()
NUM_GPU=8
MODEL="mistralai/Mixtral-8x7B-v0.1"
SEQ_LENGTH=2048
BATCH_SIZE=1
LR=0.00001
# hybrid
# torchrun --standalone --nproc_per_node $NUM_GPU \
colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile" \
train.py \
--num_epoch 1 \
--model_name $MODEL \
--plugin "hybrid" \
--batch_size $BATCH_SIZE \
--lr $LR \
--zero_stage 1 \
--pp_size 2 \
--dp_size 1 \
--ep_size 8 \
...@@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import ( ...@@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
) )
from colossalai.cluster import ProcessGroupMesh from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.moe import MoECheckpintIO from colossalai.moe import MOE_MANAGER, MoECheckpintIO
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig from colossalai.shardformer import ShardConfig
...@@ -150,6 +150,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): ...@@ -150,6 +150,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self, self,
tp_size: int, tp_size: int,
pp_size: int, pp_size: int,
ep_size: int,
extra_dp_size: int = 1, extra_dp_size: int = 1,
precision: str = "fp16", precision: str = "fp16",
zero_stage: int = 0, zero_stage: int = 0,
...@@ -181,6 +182,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): ...@@ -181,6 +182,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
overlap_communication: bool = True, overlap_communication: bool = True,
use_ep_inside: bool = True, use_ep_inside: bool = True,
custom_policy: Policy = None, custom_policy: Policy = None,
checkpoint_io: Optional[MoECheckpintIO] = None,
) -> None: ) -> None:
assert ( assert (
dist.get_world_size() % (tp_size * pp_size) == 0 dist.get_world_size() % (tp_size * pp_size) == 0
...@@ -188,10 +190,26 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): ...@@ -188,10 +190,26 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
if enable_sequence_parallelism: if enable_sequence_parallelism:
assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism" assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
assert (
dist.get_world_size() % (tp_size * pp_size * ep_size) == 0
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
self.real_dp_size = dist.get_world_size() // (tp_size * pp_size * ep_size)
MOE_MANAGER.setup(
parallel="EP",
mode="fixed",
fixed_dp_size=self.real_dp_size,
fixed_ep_size=ep_size,
fixed_pp_size=pp_size,
use_ep_inside=use_ep_inside,
)
self.tp_size = tp_size self.tp_size = tp_size
self.pp_size = pp_size self.pp_size = pp_size
self.dp_size = dist.get_world_size() // (tp_size * pp_size) self.dp_size = dist.get_world_size() // (tp_size * pp_size)
self.ep_size = ep_size
self.moe_info = MOE_MANAGER.get_info(0)[1]
self.precision = precision self.precision = precision
self.zero_stage = zero_stage self.zero_stage = zero_stage
self.cpu_offload = cpu_offload self.cpu_offload = cpu_offload
...@@ -200,6 +218,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): ...@@ -200,6 +218,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.enable_flash_attention = enable_flash_attention self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism self.enable_sequence_parallelism = enable_sequence_parallelism
self.checkpoint_io = checkpoint_io
# we change pg mesh to (pp, dp, tp) for better moe performance # we change pg mesh to (pp, dp, tp) for better moe performance
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size) self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size)
...@@ -323,7 +342,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): ...@@ -323,7 +342,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
) )
def get_checkpoint_io(self) -> MoECheckpintIO: def get_checkpoint_io(self) -> MoECheckpintIO:
self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) if self.checkpoint_io is None:
self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
else:
self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
return self.checkpoint_io return self.checkpoint_io
def configure( def configure(
......
...@@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler ...@@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from colossalai.interface import ModelWrapper from colossalai.interface import ModelWrapper
from .utils import has_index_file from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file
__all__ = ["CheckpointIO"] __all__ = ["CheckpointIO"]
...@@ -90,7 +90,15 @@ class CheckpointIO(ABC): ...@@ -90,7 +90,15 @@ class CheckpointIO(ABC):
if index_file_exists: if index_file_exists:
self.load_sharded_model(model, index_file_path, strict) self.load_sharded_model(model, index_file_path, strict)
else: else:
self.load_unsharded_model(model, checkpoint, strict) path = Path(checkpoint, SAFE_WEIGHTS_NAME)
if path.is_file():
self.load_unsharded_model(model, str(path), strict)
else:
path = Path(checkpoint, WEIGHTS_NAME)
if path.is_file():
self.load_unsharded_model(model, str(path), strict)
else:
self.load_unsharded_model(model, checkpoint, strict)
return origin_model return origin_model
......
from .checkpoint import MoECheckpintIO from .checkpoint import MoECheckpintIO
from .experts import MLPExperts from .experts import MLPExperts
from .layers import SparseMLP from .layers import SparseMLP, apply_load_balance
from .manager import MOE_MANAGER
from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
from .utils import NormalNoiseGenerator, UniformNoiseGenerator from .utils import NormalNoiseGenerator, UniformNoiseGenerator
...@@ -14,4 +15,6 @@ __all__ = [ ...@@ -14,4 +15,6 @@ __all__ = [
"UniformNoiseGenerator", "UniformNoiseGenerator",
"SparseMLP", "SparseMLP",
"MoECheckpintIO", "MoECheckpintIO",
"MOE_MANAGER",
"apply_load_balance",
] ]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment