Unverified commit efba0f44, authored by Hongxin Liu, committed by GitHub

Merge pull request #4612 from hpcaitech/feature/shardformer

[shardformer] update hybrid parallel plugin and fix bugs
parents ac178ca5 fae6c92e
@@ -208,7 +208,7 @@ jobs:
      - name: Execute Unit Testing
        run: |
-          CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest --testmon --testmon-cov=. --durations=10 tests/
+          CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-forceselect --testmon-cov=. --durations=10 tests/
        env:
          DATA: /data/scratch/cifar-10
          NCCL_SHM_DISABLE: 1
...
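For reference, the new -m "not largedist" filter only deselects tests that carry a largedist marker; markers like this have to be registered somewhere for pytest not to warn about unknown marks. A minimal, hypothetical conftest.py sketch (not part of this PR) registering the markers referenced in this changeset:

# conftest.py (hypothetical sketch): register the custom markers so selection
# expressions such as `pytest -m "not largedist"` run without warnings.
def pytest_configure(config):
    config.addinivalue_line("markers", "largedist: tests that need a large multi-GPU setup")
    config.addinivalue_line("markers", "cpu: tests that run on CPU only")
    config.addinivalue_line("markers", "gpu: tests that require a GPU")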
@@ -44,7 +44,7 @@ jobs:
    name: Test for PyTorch Compatibility
    needs: matrix_preparation
    if: github.repository == 'hpcaitech/ColossalAI'
-    runs-on: [self-hosted, gpu]
+    runs-on: [self-hosted, 8-gpu]
    strategy:
      fail-fast: false
      matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
...
@@ -35,7 +35,7 @@ jobs:
    name: Test for PyTorch Compatibility
    needs: matrix_preparation
    if: github.repository == 'hpcaitech/ColossalAI'
-    runs-on: [self-hosted, gpu]
+    runs-on: [self-hosted, 8-gpu]
    strategy:
      fail-fast: false
      matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
...
@@ -32,7 +32,7 @@ jobs:
    name: Test for PyTorch Compatibility
    needs: matrix_preparation
    if: github.repository == 'hpcaitech/ColossalAI'
-    runs-on: [self-hosted, gpu]
+    runs-on: [self-hosted, 8-gpu]
    strategy:
      fail-fast: false
      matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
...
@@ -14,29 +14,43 @@ from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer

SFT_DATASET = [
    {
-        "instruction": "Provide a list of the top 10 most popular mobile games in Asia",
-        "input": "",
-        "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
-        "id": 0
+        "instruction":
+            "Provide a list of the top 10 most popular mobile games in Asia",
+        "input":
+            "",
+        "output":
+            "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
+        "id":
+            0
    },
    {
-        "instruction": "Please provide an action plan for reducing carbon footprint on a corporate level",
-        "input": "",
-        "output": "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.",
-        "id": 1
+        "instruction":
+            "Please provide an action plan for reducing carbon footprint on a corporate level",
+        "input":
+            "",
+        "output":
+            "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.",
+        "id":
+            1
    },
    {
-        "instruction": "Write a persuasive email to your boss explaining why you should have a pay raise",
-        "input": "",
-        "output": "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]",
-        "id": 2
+        "instruction":
+            "Write a persuasive email to your boss explaining why you should have a pay raise",
+        "input":
+            "",
+        "output":
+            "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]",
+        "id":
+            2
    },
]

PROMPT_DATASET = [
    {
-        "instruction": "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"",
-        "id": 0
+        "instruction":
+            "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"",
+        "id":
+            0
    },
    {
        "instruction": "Write a descriptive paragraph about a memorable vacation you went on",
@@ -73,9 +87,7 @@ def make_tokenizer(model: str):
    return tokenizer

-def check_content(input_ids_stripped: torch.Tensor,
-                  tokenizer: PreTrainedTokenizer,
-                  model: str):
+def check_content(input_ids_stripped: torch.Tensor, tokenizer: PreTrainedTokenizer, model: str):
    if model == "opt":
        # NOTE: Contrary to GPT2, OPT adds the EOS token </s> to the beginning of every prompt.
        assert input_ids_stripped[0] == tokenizer.eos_token_id
@@ -98,13 +110,10 @@ def check_content(input_ids_stripped: torch.Tensor,
    assert input_ids_stripped != tokenizer.mask_token_id

-@pytest.mark.cpu
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
@pytest.mark.parametrize("max_length", [32, 1024])
@pytest.mark.parametrize("max_datasets_size", [2])
-def test_prompt_dataset(model: str,
-                        max_datasets_size: int,
-                        max_length: int):
+def test_prompt_dataset(model: str, max_datasets_size: int, max_length: int):
    with tempfile.TemporaryDirectory() as tmp_dir:
        dataset_name = "prompt_dataset.json"
        with open(os.path.join(tmp_dir, dataset_name), "w") as f:
@@ -127,19 +136,12 @@ def test_prompt_dataset(model: str,
        check_content(input_ids.masked_select(attention_mask), tokenizer, model)

-@pytest.mark.cpu
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
-@pytest.mark.parametrize(["dataset_path", "subset"], [
-    ("Anthropic/hh-rlhf", "harmless-base"),
-    ("Dahoas/rm-static", None)
-])
+@pytest.mark.parametrize(["dataset_path", "subset"], [("Anthropic/hh-rlhf", "harmless-base"),
+                                                      ("Dahoas/rm-static", None)])
@pytest.mark.parametrize("max_datasets_size", [32])
@pytest.mark.parametrize("max_length", [32, 1024])
-def test_reward_dataset(model: str,
-                        dataset_path: str,
-                        subset: Optional[str],
-                        max_datasets_size: int,
-                        max_length: int):
+def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], max_datasets_size: int, max_length: int):
    data = load_dataset(dataset_path, data_dir=subset)
    assert max_datasets_size <= len(data["train"]) \
        and max_datasets_size <= len(data["test"])
@@ -196,15 +198,12 @@ def test_reward_dataset(model: str,
    assert torch.all(r_mask)

-@pytest.mark.cpu
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama", "chatglm"])
@pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None])
@pytest.mark.parametrize("max_dataset_size", [2])
@pytest.mark.parametrize("max_length", [32, 1024])
-def test_sft_dataset(model: str,
-                     dataset_path: Optional[str],
-                     max_dataset_size: int,
-                     max_length: int):
+def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size: int, max_length: int):
    tokenizer = make_tokenizer(model)
    if dataset_path == "yizhongw/self_instruct":
        data = load_dataset(dataset_path, "super_natural_instructions")
@@ -253,10 +252,7 @@ def test_sft_dataset(model: str,

if __name__ == "__main__":
-    test_sft_dataset(model="bloom",
-                     dataset_path="yizhongw/self_instruct",
-                     max_dataset_size=2,
-                     max_length=256)
+    test_sft_dataset(model="bloom", dataset_path="yizhongw/self_instruct", max_dataset_size=2, max_length=256)

    test_reward_dataset(model="gpt2",
                        dataset_path="Anthropic/hh-rlhf",
@@ -266,4 +262,5 @@ if __name__ == "__main__":
    test_prompt_dataset(model="opt",
                        max_datasets_size=2,
                        max_length=128)
\ No newline at end of file
@@ -16,17 +16,19 @@ from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer

-@pytest.mark.gpu
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seq_len", [32])
-@pytest.mark.parametrize("actor_maker", [
-    lambda: BLOOMActor(),
-    lambda: GPTActor(),
+@pytest.mark.parametrize(
+    "actor_maker",
+    [
+        lambda: BLOOMActor(),
+        lambda: GPTActor(),
        # HACK: skip llama due to long execution time
        # lambda: LlamaActor(),
        lambda: OPTActor(),
        # lambda: ChatGLMActor(),
    ])
@pytest.mark.parametrize("generate_kwargs", [{
    "max_length": 64,
    "use_cache": True,
@@ -34,23 +36,15 @@ from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
    "temperature": 1.0,
    "top_k": 50,
}])
-def test_generation(actor_maker: Callable[[], Actor],
-                    batch_size: int,
-                    seq_len: int,
-                    generate_kwargs: Dict[str, Any]
-                    ):
+def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]):
    actor = actor_maker()
    input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda()
    sequences = generate(actor.cuda(), input_ids, **generate_kwargs)
    assert sequences.shape == (batch_size, generate_kwargs["max_length"])

-@pytest.mark.cpu
def test_utils():
-    fn_input = {
-        "tensor": torch.ones((10, )),
-        "mask": torch.randint(0, 2, (10, ))
-    }
+    fn_input = {"tensor": torch.ones((10,)), "mask": torch.randint(0, 2, (10,))}
    fn_output = masked_mean(dim=0, **fn_input)
    assert fn_output.dim() == 0
    assert torch.allclose(fn_output, torch.tensor(1.0))
@@ -58,14 +52,14 @@ def test_utils():
    batch_size = 4
    num_labels = 10
    fn_input = {
-        "r": torch.ones((batch_size, )),
+        "r": torch.ones((batch_size,)),
        "kl_coef": 1.0,
        "log_probs": torch.randn((batch_size, num_labels)),
        "log_probs_base": torch.randn((batch_size, num_labels)),
        "action_mask": torch.randint(0, 2, (batch_size, num_labels))
    }
    fn_output = compute_reward(**fn_input)
-    assert fn_output.shape == (batch_size, )
+    assert fn_output.shape == (batch_size,)

    batch_size = 4
    seq_len = 32
@@ -82,17 +76,11 @@ def test_utils():
    assert fn_output.shape == (batch_size, num_actions)

-@pytest.mark.cpu
@pytest.mark.parametrize("lora_rank", [4])
@pytest.mark.parametrize("num_dim", [32])
@pytest.mark.parametrize("num_layers", [4])
-def test_lora(lora_rank: int,
-              num_dim: int,
-              num_layers: int):
-    model = nn.ModuleList(
-        [nn.Linear(num_dim, num_dim)
-         for _ in range(num_layers)]
-    )
+def test_lora(lora_rank: int, num_dim: int, num_layers: int):
+    model = nn.ModuleList([nn.Linear(num_dim, num_dim) for _ in range(num_layers)])
    lora_model = convert_to_lora_module(model, lora_rank)
    assert isinstance(lora_model, nn.ModuleList)
    for i in range(num_layers):
@@ -105,8 +93,7 @@ def test_lora(lora_rank: int,
        assert isinstance(lora_model[i], LoraLinear)
        assert torch.allclose(old_model[i].weight, lora_model[i].weight)
        assert torch.allclose(old_model[i].bias, lora_model[i].bias)
-        assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A,
-                              lora_model[i].lora_B @ lora_model[i].lora_A)
+        assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A, lora_model[i].lora_B @ lora_model[i].lora_A)
    optimizer = torch.optim.Adam(lora_model.parameters())
    x = torch.randn(8, num_dim)
    for i in range(num_layers):
@@ -122,12 +109,13 @@ def test_lora(lora_rank: int,
                              lora_model[i].lora_B @ lora_model[i].lora_A)

-@pytest.mark.cpu
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [128])
-@pytest.mark.parametrize("models_maker", [
-    lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()),
-    lambda: (GPTActor(), GPTCritic(), GPTRM()),
+@pytest.mark.parametrize(
+    "models_maker",
+    [
+        lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()),
+        lambda: (GPTActor(), GPTCritic(), GPTRM()),
        # HACK: skip llama due to long execution time
        # lambda: (LlamaActor(), LlamaCritic(), LlamaRM()),
        lambda: (OPTActor(), OPTCritic(), OPTRM()),
@@ -178,13 +166,10 @@ def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]],
    assert rm_output.shape == (batch_size, )

-@pytest.mark.cpu
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("seq_len", [128])
@pytest.mark.parametrize("num_labels", [100])
-def test_loss(batch_size: int,
-              seq_len: int,
-              num_labels: int):
+def test_loss(batch_size: int, seq_len: int, num_labels: int):
    loss = GPTLMLoss()
    loss_input = {
        "logits": torch.randn(batch_size, seq_len, num_labels),
@@ -194,54 +179,43 @@ def test_loss(batch_size: int,
    loss = PolicyLoss()
    loss_input = {
-        "log_probs": torch.randn(batch_size, ),
-        "old_log_probs": torch.randn(batch_size, ),
-        "advantages": torch.randn(batch_size, )
+        "log_probs": torch.randn(batch_size,),
+        "old_log_probs": torch.randn(batch_size,),
+        "advantages": torch.randn(batch_size,)
    }
    loss_output = loss(**loss_input)

    loss = ValueLoss()
    loss_input = {
-        "values": torch.randn(batch_size, ),
-        "old_values": torch.randn(batch_size, ),
-        "reward": torch.randn(batch_size, )
+        "values": torch.randn(batch_size,),
+        "old_values": torch.randn(batch_size,),
+        "reward": torch.randn(batch_size,)
    }
    loss_output = loss(**loss_input)

    loss = LogSigLoss()
    loss_input = {
-        "chosen_reward": torch.randn(batch_size, ),
-        "reject_reward": torch.randn(batch_size, ),
+        "chosen_reward": torch.randn(batch_size,),
+        "reject_reward": torch.randn(batch_size,),
    }
    loss_output = loss(**loss_input)

    loss = LogExpLoss()
    loss_input = {
-        "chosen_reward": torch.randn(batch_size, ),
-        "reject_reward": torch.randn(batch_size, ),
+        "chosen_reward": torch.randn(batch_size,),
+        "reject_reward": torch.randn(batch_size,),
    }
    loss_output = loss(**loss_input)

if __name__ == "__main__":
-    generate_kwargs = dict(max_length=40,
-                           use_cache=True,
-                           do_sample=True,
-                           temperature=1.0,
-                           top_k=50)
-    test_generation(lambda: LlamaActor(),
-                    batch_size=4,
-                    seq_len=32,
-                    generate_kwargs=generate_kwargs)
+    generate_kwargs = dict(max_length=40, use_cache=True, do_sample=True, temperature=1.0, top_k=50)
+    test_generation(lambda: LlamaActor(), batch_size=4, seq_len=32, generate_kwargs=generate_kwargs)
    test_utils()
    test_lora(lora_rank=2, num_dim=8, num_layers=2)
-    test_models(models_maker=lambda: (BLOOMActor(),
-                                      BLOOMCritic(),
-                                      BLOOMRM()),
-                batch_size=8,
-                seq_len=128)
+    test_models(models_maker=lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), batch_size=8, seq_len=128)
    test_loss(batch_size=8, seq_len=128, num_labels=100)
\ No newline at end of file
@@ -15,6 +15,7 @@ from colossalai.checkpoint_io.utils import (
    get_model_base_filenames,
    get_optimizer_base_filenames,
    load_shard_state_dict,
+    save_config_file,
    save_state_dict,
    save_state_dict_shards,
)
@@ -107,6 +108,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
        if self.coordinator.is_master():
            index_file.append_meta_data("total_size", total_size)
            index_file.write_index_file(save_index_file)
+            save_config_file(model.module, checkpoint_path)
            logging.info(f"The model is split into checkpoint shards. "
                         f"You can find where each parameters has been saved in the "
                         f"index located at {save_index_file}.")
...
import random
from contextlib import nullcontext
-from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
+from functools import partial
+from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
-from torch.nn import Module
+from torch.nn import Module, SyncBatchNorm
+from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
-from colossalai.checkpoint_io import CheckpointIO
+from colossalai.checkpoint_io import CheckpointIO, HypridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
@@ -26,26 +29,52 @@ from .pp_plugin_base import PipelinePluginBase

DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
return x.to(dtype)
return x
class HybridParallelModule(ModelWrapper):

-    def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup) -> None:
+    def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool,
+                 ddp_config: dict) -> None:

        self.stage_manager = shard_config.pipeline_stage_manager
        self.dp_group = dp_group

        shardformer = ShardFormer(shard_config)
        module, self.shared_params = shardformer.optimize(module)
-        # TODO(ver217): add input type cast

+        # setting process groups for shared parameters
        self.shared_param_process_groups = []
        for shared_param in self.shared_params:
            if len(shared_param) > 0:
                self.shared_param_process_groups.append(
                    self.stage_manager.init_process_group_by_stages(list(shared_param.keys())))

+        # setting mixed_precision
+        self.mixed_precision = None
        if precision == 'fp16':
-            module = module.half().cuda()
+            self.mixed_precision = torch.float16
        elif precision == 'bf16':
-            module = module.to(dtype=torch.bfloat16).cuda()
-        else:
-            module = module.cuda()    # train without AMP
-        # TODO(ver217): support TP+DP
+            self.mixed_precision = torch.bfloat16
+        if self.mixed_precision is not None:
+            module = module.to(self.mixed_precision)
+        module = module.cuda()

+        # setting input type cast when using mixed precision
+        self.convert_fn = None
+        if self.mixed_precision is not None:
+            self.convert_fn = partial(_convert_floating_point, dtype=self.mixed_precision)

+        # setting ddp configs
+        if use_ddp:
+            # convert model to sync bn
+            module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group)
+            # wrap the model with PyTorch DDP
+            module = DDP(module, process_group=dp_group, **ddp_config)

        super().__init__(module)

    def sync_shared_params(self):
@@ -68,19 +97,62 @@ class HybridParallelModule(ModelWrapper):
                dist.all_reduce(p.grad, group=self.dp_group)
                p.grad.div_(self.dp_group.size())
def forward(self, *args, **kwargs):
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
kwargs = tree_map(self.convert_fn, kwargs)
return super().forward(*args, **kwargs)
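The forward() override above relies on tree_map to cast every floating-point tensor nested inside args/kwargs while leaving integer tensors untouched. A minimal standalone sketch of that behaviour, using only public PyTorch utilities (the variable names are illustrative, not taken from the diff):

import torch
from functools import partial
from torch.utils._pytree import tree_map

def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
    # cast only floating-point tensors; ints, bools and non-tensors pass through unchanged
    if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
        return x.to(dtype)
    return x

convert_fn = partial(_convert_floating_point, dtype=torch.float16)
args = (torch.randn(2, 4), {"attention_mask": torch.ones(2, 4, dtype=torch.long)})
args = tree_map(convert_fn, args)    # the hidden states become fp16, the long mask stays long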
def unwrap(self):
module = super().unwrap()
if isinstance(module, DDP):
module = module.module
return module
def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes:
# 1. A complete param_group, with params in the form of param_id
# 2. A mapping from param address (obtained using id(param)) to integer param_id
# 3. A mapping from integer param_id to param address.
# 4. A mapping from param_address (obtained using id(param)) to the original shape of parameter before sharding.
# When Zero is used, the params here are fp16/bf16 model params rather than fp32 master params in optimizer.
if optim is None:
return {}
param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}}
start_index = 0
for group in optim.param_groups:
packed_group = {k: v for k, v in group.items() if k != 'params'}
packed_group['params'] = []
for param_id, param in enumerate(group['params'], start_index):
original_shape = param.shape if isinstance(param, torch.Tensor) else None
packed_group['params'].append(param_id)
param_info['param2id'][id(param)] = param_id
param_info['id2param'][param_id] = id(param)
param_info['param2shape'][id(param)] = original_shape
param_info['param_groups'].append(packed_group)
start_index += len(group['params'])
return param_info
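As a rough illustration of the comment above, this is the kind of mapping get_param_info builds for a toy model (a hypothetical snippet assuming the function defined above is in scope; it is not part of the diff):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optim = torch.optim.SGD(model.parameters(), lr=0.1)
info = get_param_info(optim)
# info['param_groups'][0]['params'] == [0, 1]                  # integer ids replace the tensors
# info['param2id'][id(model.weight)] == 0
# info['id2param'][0] == id(model.weight)
# info['param2shape'][id(model.weight)] == torch.Size([2, 4])  # shape recorded before sharding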
def init_pipeline_optimizer(optim: Optimizer, model: Module):
-    params = set(model.parameters())
+    model_params = set(model.parameters())
    new_param_groups = []
    for group in optim.param_groups:
-        params = [p for p in group['params'] if p in params]
+        params = [p for p in group['params'] if p in model_params]
        new_param_groups.append({**group, 'params': params})
    optim.__setstate__({'param_groups': new_param_groups})

class HybridParallelNaiveOptimizer(OptimizerWrapper):

-    def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool):
+    def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict):
+        self.param_info = param_info
        if use_pipeline:
            init_pipeline_optimizer(optim, model)
        super().__init__(optim)
@@ -92,6 +164,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
                 optim: Optimizer,
                 model: Module,
                 use_pipeline: bool,
+                 param_info: OrderedDict,
                 precision: str = 'fp16',
                 initial_scale: float = 2**16,
                 min_scale: float = 1,
@@ -101,6 +174,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
                 hysteresis: int = 2,
                 max_scale: float = 2**32,
                 max_norm: float = 0):
+        self.param_info = param_info
        if use_pipeline:
            init_pipeline_optimizer(optim, model)
        super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
@@ -114,6 +188,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
                 optimizer: Optimizer,
                 model: Module,
                 use_pipeline: bool,
+                 param_info: OrderedDict,
                 initial_scale: int = 2**16,    # grad scaler config
                 min_scale: int = 1,
                 growth_factor: float = 2.,
@@ -131,6 +206,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
                 dp_process_group: Optional[ProcessGroup] = None,    # the dp pg for comm
                 tp_process_group: Optional[ProcessGroup] = None,    # if using tp
                 forced_dtype: Optional[torch.dtype] = None):
+        self.param_info = param_info
        if use_pipeline:
            init_pipeline_optimizer(optimizer, model)
        super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
@@ -140,34 +216,100 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):

class HybridParallelPlugin(PipelinePluginBase):
"""
Plugin for Hybrid Parallel Training.
    Tensor parallelism, pipeline parallelism and data parallelism (DDP/ZeRO) can be picked and combined in this plugin.
    The sizes of tp and pp should be passed in by the user; the size of dp is then computed automatically as dp_size = world_size / (tp_size * pp_size).
Example:
>>> from colossalai.booster import Booster
>>> from colossalai.booster.plugin import HybridParallelPlugin
>>> model, train_dataset, optimizer, criterion = ...
>>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
>>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
>>> booster = Booster(plugin=plugin)
>>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
Args:
tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
precision (str, optional): Specifies the precision of parameters during training.
            Automatic mixed precision will be used when this argument is set to 'fp16' or 'bf16'; otherwise the model is trained in 'fp32'.
Defaults to 'fp16'.
        zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be chosen from [0, 1, 2].
When set to 0, ZeRO will not be used. Defaults to 0.
enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
Currently all the optimization methods include fused normalization, flash attention and JIT.
Defaults to False.
enable_fused_normalization (bool, optional): Whether to switch on fused normalization. Defaults to False.
enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
        enable_jit_fused (bool, optional): Whether to switch on JIT fusion. Defaults to False.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True.
ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
        cpu_offload (bool, optional): Whether to enable CPU offloading when using ZeRO. Defaults to False.
communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
"""
-    def __init__(
-        self,
-        tp_size: int,
-        pp_size: int,
-        precision: str = 'fp16',
-        zero_stage: int = 0,
-        cpu_offload: bool = False,
-        enable_all_optimization: bool = False,
-        enable_fused_normalization: bool = False,
-        enable_flash_attention: bool = False,
-        enable_jit_fused: bool = False,
-        num_microbatches: Optional[int] = None,
-        initial_scale: float = 2**16,
-        min_scale: float = 1,
-        growth_factor: float = 2,
-        backoff_factor: float = 0.5,
-        growth_interval: int = 1000,
-        hysteresis: int = 2,
-        max_scale: float = 2**32,
-        max_norm: float = 0,
-    ) -> None:
+    def __init__(self,
+                 tp_size: int,
+                 pp_size: int,
+                 precision: str = 'fp16',
+                 zero_stage: int = 0,
+                 enable_all_optimization: bool = False,
+                 enable_fused_normalization: bool = False,
+                 enable_flash_attention: bool = False,
+                 enable_jit_fused: bool = False,
+                 enable_sequence_parallelism: bool = False,
+                 enable_sequence_overlap: bool = False,
+                 num_microbatches: Optional[int] = None,
+                 microbatch_size: Optional[int] = None,
+                 initial_scale: float = 2**16,
+                 min_scale: float = 1,
+                 growth_factor: float = 2,
+                 backoff_factor: float = 0.5,
+                 growth_interval: int = 1000,
+                 hysteresis: int = 2,
+                 max_scale: float = 2**32,
+                 max_norm: float = 0,
+                 broadcast_buffers: bool = True,
+                 ddp_bucket_cap_mb: int = 25,
+                 find_unused_parameters: bool = False,
+                 check_reduction: bool = False,
+                 gradient_as_bucket_view: bool = False,
+                 static_graph: bool = False,
+                 zero_bucket_size_in_m: int = 12,
+                 cpu_offload: bool = False,
+                 communication_dtype: Optional[torch.dtype] = None,
+                 overlap_communication: bool = True) -> None:
        super().__init__()
        assert dist.get_world_size() % (
            tp_size * pp_size
        ) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}'
-        # TODO(ver217): support zero
-        assert zero_stage == 0, 'zero is not support yet'
+
+        if enable_sequence_parallelism:
+            assert tp_size > 1, 'Sequence parallelism must be enabled when using tensor parallelism'
+
        self.tp_size = tp_size
        self.pp_size = pp_size
        self.dp_size = dist.get_world_size() // (tp_size * pp_size)
@@ -178,24 +320,30 @@ class HybridParallelPlugin(PipelinePluginBase):
        self.enable_fused_normalization = enable_fused_normalization
        self.enable_flash_attention = enable_flash_attention
        self.enable_jit_fused = enable_jit_fused
+        self.enable_sequence_parallelism = enable_sequence_parallelism
        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
        self.stage_manager = None
        self.schedule = None
        assert zero_stage in (0, 1, 2)
        if self.pp_size > 1:
-            assert num_microbatches is not None, 'num_microbatches must be specified when using pipeline parallelism'
+            assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism'
            assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism'
            self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
-            self.schedule = OneForwardOneBackwardSchedule(num_microbatches, self.stage_manager)
+            self.schedule = OneForwardOneBackwardSchedule(self.stage_manager,
+                                                          num_microbatches=num_microbatches,
+                                                          microbatch_size=microbatch_size)
        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
        self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
+        self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
        self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group,
                                        pipeline_stage_manager=self.stage_manager,
                                        enable_tensor_parallelism=self.tp_size > 1,
                                        enable_all_optimization=self.enable_all_optimization,
                                        enable_fused_normalization=self.enable_fused_normalization,
                                        enable_flash_attention=self.enable_flash_attention,
-                                        enable_jit_fused=self.enable_jit_fused)
+                                        enable_jit_fused=self.enable_jit_fused,
+                                        enable_sequence_parallelism=enable_sequence_parallelism,
+                                        enable_sequence_overlap=enable_sequence_overlap)
        self.amp_config = dict(
            initial_scale=initial_scale,
            growth_factor=growth_factor,
@@ -205,6 +353,20 @@ class HybridParallelPlugin(PipelinePluginBase):
            min_scale=min_scale,
            max_scale=max_scale,
        )
+        self.ddp_config = dict(broadcast_buffers=broadcast_buffers,
+                               bucket_cap_mb=ddp_bucket_cap_mb,
+                               find_unused_parameters=find_unused_parameters,
+                               check_reduction=check_reduction,
+                               gradient_as_bucket_view=gradient_as_bucket_view,
+                               static_graph=static_graph)
+        self.zero_config = dict(reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
+                                communication_dtype=communication_dtype,
+                                overlap_communication=overlap_communication,
+                                cpu_offload=cpu_offload,
+                                partition_grad=(self.zero_stage == 2))
        self.max_norm = max_norm
    @property
@@ -237,32 +399,44 @@ class HybridParallelPlugin(PipelinePluginBase):
        dataloader: Optional[DataLoader] = None,
        lr_scheduler: Optional[LRScheduler] = None,
    ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
+        param_info = get_param_info(optimizer)
        if not isinstance(model, ModelWrapper):
-            model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group)
+            use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
+            model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
+                                         self.ddp_config)
        if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
            if self.zero_stage == 0:
                if self.precision in ['fp16', 'bf16']:
                    optimizer = HybridParallelAMPOptimizer(optimizer,
                                                           model,
                                                           use_pipeline=self.enable_pipeline_parallelism,
+                                                           param_info=param_info,
                                                           precision=self.precision,
                                                           max_norm=self.max_norm,
                                                           **self.amp_config)
+                    self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map,
+                                                                     optimizer.master_to_working_map)
                else:
                    optimizer = HybridParallelNaiveOptimizer(optimizer,
                                                             model,
-                                                             use_pipeline=self.enable_pipeline_parallelism)
+                                                             use_pipeline=self.enable_pipeline_parallelism,
+                                                             param_info=param_info)
            else:
+                assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
+                assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO."
                optimizer = HybridParallelZeroOptimizer(optimizer,
                                                        model,
                                                        use_pipeline=self.enable_pipeline_parallelism,
-                                                        partition_grad=(self.zero_stage == 2),
-                                                        cpu_offload=self.cpu_offload,
+                                                        param_info=param_info,
                                                        dp_process_group=self.dp_group,
                                                        tp_process_group=self.tp_group,
                                                        verbose=True,
                                                        clip_grad_norm=self.max_norm,
+                                                        **self.zero_config,
                                                        **self.amp_config)
+                self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param,
+                                                                 optimizer._param_store.master_to_working_param)
        return model, optimizer, criterion, dataloader, lr_scheduler

    def execute_pipeline(self,
@@ -339,7 +513,8 @@ class HybridParallelPlugin(PipelinePluginBase):
                         **_kwargs)

    def get_checkpoint_io(self) -> CheckpointIO:
-        return None
+        self.checkpoint_io = HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+        return self.checkpoint_io

    def no_sync(self, model: Module) -> Iterator[None]:
        raise NotImplementedError
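For context, the checkpoint IO returned by get_checkpoint_io() is normally reached through the Booster rather than used directly. A hedged usage sketch, assuming the usual Booster.save_model signature and a multi-GPU colossalai launch (not taken from this diff):

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

model, optimizer, criterion, train_dataloader = ...    # built as usual
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, zero_stage=1, precision='fp16')
booster = Booster(plugin=plugin)
model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
# save_model routes to HypridParallelCheckpointIO via plugin.get_checkpoint_io()
booster.save_model(model, "./ckpt", shard=True, size_per_shard=1024)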
from .checkpoint_io_base import CheckpointIO
from .general_checkpoint_io import GeneralCheckpointIO
+from .hybrid_parallel_checkpoint_io import HypridParallelCheckpointIO
from .index_file import CheckpointIndexFile

-__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO']
+__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HybridParallelCheckpointIO']
@@ -23,6 +23,7 @@ from .utils import (
    load_state_dict,
    load_state_dict_into_model,
    load_states_into_optimizer,
+    save_config_file,
    save_param_groups,
    save_state_dict,
    save_state_dict_shards,
@@ -183,6 +184,7 @@ class GeneralCheckpointIO(CheckpointIO):
        index_file.append_meta_data("total_size", total_size)
        index_file.write_index_file(save_index_file)
+        save_config_file(model, checkpoint_path, is_master=True)
        logging.info(f"The model is going to be split to checkpoint shards. "
                     f"You can find where each parameters has been saved in the "
                     f"index located at {save_index_file}.")
...
import copy
import gc
import logging
import os
from pathlib import Path
from shutil import rmtree
from typing import Dict, Iterator, Optional, OrderedDict, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from colossalai.interface import OptimizerWrapper
from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
StateDictSharder,
gather_distributed_param,
get_model_base_filenames,
get_optimizer_base_filenames,
is_safetensors_available,
load_shard_state_dict,
load_state_dict_into_model,
load_states_into_optimizer,
save_config_file,
save_param_groups,
save_state_dict_shards,
search_tp_partition_dim,
sharded_optimizer_loading_epilogue,
)
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
except ImportError:
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
class HypridParallelCheckpointIO(GeneralCheckpointIO):
"""
CheckpointIO for Hybrid Parallel Training.
Args:
dp_group (ProcessGroup): Process group along data parallel dimension.
pp_group (ProcessGroup): Process group along pipeline parallel dimension.
tp_group (ProcessGroup): Process group along tensor parallel dimension.
zero_stage (int): The zero stage of plugin. Should be in [0, 1, 2].
        verbose (bool, optional): Whether to print a logging message when saving/loading has been successfully executed. Defaults to True.
"""
def __init__(self,
dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
zero_stage: int,
verbose: bool = True) -> None:
super().__init__()
self.dp_group = dp_group
self.pp_group = pp_group
self.tp_group = tp_group
self.dp_rank = dist.get_rank(self.dp_group)
self.tp_rank = dist.get_rank(self.tp_group)
self.pp_rank = dist.get_rank(self.pp_group)
self.dp_size = dist.get_world_size(dp_group)
self.pp_size = dist.get_world_size(pp_group)
self.tp_size = dist.get_world_size(tp_group)
self.use_zero = (zero_stage > 0)
self.verbose = verbose
self.working_to_master_map = None
self.master_to_working_map = None
@staticmethod
def _model_sharder(model: nn.Module,
prefix: str = '',
keep_vars: bool = False,
size_per_shard: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
        # An internal method that breaks the state_dict of the model into shards within a limited size.
state_dict_sharder = StateDictSharder(size_per_shard)
# Save parameters.
for name, param in model.named_parameters():
if param is None:
continue
# Gather tensor pieces when using tensor parallel.
param_ = gather_distributed_param(param, keep_vars=False)
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
if block is not None:
yield block, block_size
# Save buffers.
for name, buf in model.named_buffers():
if buf is not None and name not in model._non_persistent_buffers_set:
buffer = buf if keep_vars else buf.detach()
block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
if block is not None:
yield block, block_size
# Save extra states.
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(model.__class__, "get_extra_state",
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
extra_state = model.get_extra_state()
block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
if block is not None:
yield block, block_size
# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
@staticmethod
def _optimizer_sharder(optimizer: OptimizerWrapper,
use_zero: bool,
dp_group: ProcessGroup,
tp_group: ProcessGroup,
master_to_working_map: Optional[Dict[int, torch.Tensor]] = None,
size_per_shard: int = 1024):
        # An internal method that breaks the state_dict of the optimizer into shards within a limited size.
state_dict_sharder = StateDictSharder(size_per_shard)
param_info = optimizer.param_info
for param, state in optimizer.optim.state.items():
if param is None:
continue
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
param_id = param_info['param2id'][id(working_param)]
original_shape = param_info['param2shape'][id(working_param)]
state_ = HypridParallelCheckpointIO.gather_from_sharded_optimizer_state(state,
working_param,
original_shape=original_shape,
dp_group=dp_group,
tp_group=tp_group,
use_zero=use_zero,
inplace=False)
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
if block is not None:
yield block, block_size
# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
def save_sharded_model(self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False) -> None:
"""
Save sharded model checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
- Multiple files that store state tensors of models.
If pipeline parallelism is used, the filenames are in the form of "pytorch_model.<prefix>-stage-000XX-shard-000XX.bin".
        If pipeline parallelism is not used, they are of the form "pytorch_model.<prefix>-000XX.bin".
Args:
model (nn.Module): Model on local device to be saved.
checkpoint (str): Checkpointing path which should be a directory path.
gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
            prefix (str, optional): Prefix of the files to save. Defaults to None.
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
"""
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Devices along the same dp_group share the same copies of model.
# So only let the device with dp_rank == 0 save the model.
if self.dp_rank != 0:
return
# Then collect the sharded parameters & buffers along tp_group.
# Only devices with tp_rank == 0 are responsible for model saving.
state_dict_shard = HypridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)
control_saving = (self.tp_rank == 0)
if self.pp_size == 1:
# When pipeline is not used, save the model shards as in general checkpointIO
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose:
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
else:
# When pipeline is used, each stage produces its own shard files and index files.
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
final_index_file_path = copy.deepcopy(save_index_file)
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
# Manage filenames of sharded weights and index file for each pipeline stage.
weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
use_pp_format=True)
if control_saving:
assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
else:
return
dist.barrier(self.pp_group)
# The global master rank integrates the index files and cleans the folder.
if self.pp_rank == 0:
final_index_file = CheckpointIndexFile(checkpoint)
final_index_file.append_meta_data("total_size", 0)
for filename in os.listdir(tmp_index_file_folder):
stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
for weight, weight_filename in stage_index_file.weight_map.items():
final_index_file.append_weight_map(weight, weight_filename)
final_index_file.write_index_file(final_index_file_path)
save_config_file(model, checkpoint)
rmtree(tmp_index_file_folder)
if self.verbose:
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}.")
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
"""
Load sharded model with the given path to index file of checkpoint folder.
Args:
model (nn.Module): The model to be loaded.
checkpoint_index_file (str): Path to the index file of checkpointing folder.
strict (bool, optional): For name matching during loading state_dict. Defaults to False.
This argument should be manually set to False since params on the same device might be stored in different files.
"""
# Check whether the checkpoint uses safetensors.
use_safetensors = False
if "safetensors" in checkpoint_index_file.name:
use_safetensors = True
if use_safetensors and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
strict = False
# Load params & buffers to model.
# Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set()
def _load(name: str):
if name not in weight_map:
raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
filename = weight_map[name]
# If this param/buffer has been loaded before, directly return.
if filename in loaded_file:
return
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
missing_keys = []
load_state_dict_into_model(model,
state_dict,
missing_keys=missing_keys,
strict=strict,
load_sub_module=True)
loaded_file.add(filename)
# Load parameters.
for name, _ in model.named_parameters():
_load(name)
# Load buffers.
non_persistent_buffers = set()
for n, m in model.named_modules():
non_persistent_buffers |= set('.'.join((n, b)) for b in m._non_persistent_buffers_set)
for name, buf in model.named_buffers():
if buf is not None and name not in non_persistent_buffers:
_load(name)
# Load extra states.
extra_state_key = _EXTRA_STATE_KEY_SUFFIX
if getattr(model.__class__, "get_extra_state",
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
_load(extra_state_key)
# Update master params if mixed-precision training is enabled.
with torch.no_grad():
if self.working_to_master_map is not None:
for param in model.parameters():
if (param is None) or (id(param) not in self.working_to_master_map):
continue
master_param = self.working_to_master_map[id(param)]
if self.use_zero:
# master_param is sharded under Zero setting
padding_size = (self.dp_size - param.numel() % self.dp_size) % self.dp_size
if padding_size > 0:
padded_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
else:
padded_param = param.data.view(-1)
sharded_param = padded_param.split(padded_param.numel() // self.dp_size)[self.dp_rank]
master_param.data.copy_(sharded_param.data)
else:
master_param.data.copy_(param.data)
if self.verbose:
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
def save_sharded_optimizer(self,
optimizer: OptimizerWrapper,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024):
"""
Save sharded optimizer checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
- A group file (pytorch_optim_group.bin) recording information of param_groups
- Multiple files that store state tensors of optimizers.
If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.<prefix>-stage-000XX-shard-000XX.bin".
If pipeline parallelism is not used, the filenames are in the form of "pytorch_optim.<prefix>-000XX.bin".
Args:
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
checkpoint (str): Path to save optimizer state_dict
gather_dtensor (bool): Whether to gather_dtensor, not used
prefix (str): Prefix of file to save
size_per_shard (int): Max file size of each file shard that store state tensors
"""
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Devices along the same dp_group share the same copies of states when zero is not used.
# In this case, only let the device with dp_rank == 0 save the states.
if not self.use_zero and self.dp_rank != 0:
return
# Then collect the sharded states along dp_group(if using zero)/tp_group.
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
state_dict_shard = HypridParallelCheckpointIO._optimizer_sharder(
optimizer,
use_zero=self.use_zero,
dp_group=self.dp_group,
tp_group=self.tp_group,
master_to_working_map=self.master_to_working_map,
size_per_shard=size_per_shard)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)
control_saving = (self.dp_rank == 0 and self.tp_rank == 0)
if self.pp_size == 1:
# When pipeline is not used, save the optimizer shards as in general checkpointIO
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving)
if control_saving:
# Store param groups.
index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
save_param_groups(optimizer.param_info, group_file_path)
# Store index file.
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
if self.verbose:
logging.info(f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
else:
# When pipeline is used, each stage produces its own shard files and index files.
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
final_index_file_path = copy.deepcopy(save_index_file)
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
# Manage filenames of sharded weights and index file for each pipeline stage.
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True)
if control_saving:
assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
else:
return
dist.barrier(self.pp_group)
# The global master rank integrates the index files and cleans the folder.
if self.pp_rank == 0:
final_index_file = CheckpointIndexFile(checkpoint)
final_index_file.append_meta_data("total_size", 0)
for filename in os.listdir(tmp_index_file_folder):
stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
for param_id, state_filename in stage_index_file.weight_map.items():
final_index_file.append_weight_map(param_id, state_filename)
# Store param groups.
final_index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
save_param_groups(optimizer.param_info, group_file_path)
final_index_file.write_index_file(final_index_file_path)
rmtree(tmp_index_file_folder)
if self.verbose:
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}.")
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
"""
Load sharded optimizer with the given path to index file of checkpoint folder.
Args:
optimizer (OptimizerWrapper): The optimizer to be loaded.
checkpoint_index_file (str): Path to the index file of checkpointing folder.
prefix (str): Not used.
"""
def _get_param_id_from_optimizer_param(param: torch.Tensor,
master_to_working_map: Optional[Dict[int, torch.Tensor]] = None):
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
return optimizer.param_info['param2id'][id(working_param)]
# id_map is a mapping from the param ids kept by the current pipeline stage to their corresponding parameter objects.
# When Zero is used, the mapped parameter objects should be fp32 master parameters.
# IDs should be obtained through the param2id mapping saved earlier in optimizer.param_info.
id_map = {}
for pg in optimizer.optim.param_groups:
for param in pg['params']:
param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
id_map[param_id] = param
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
# Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \
Lacking param group file under current directory.')
saved_groups = torch.load(param_group_path)
updated_groups = []
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
# obtain updated param group
new_pg = copy.deepcopy(saved_pg)
new_pg['params'] = old_pg['params']    # The parameters in the same group shouldn't change.
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({'param_groups': updated_groups})
# Load saved states to optimizer.
# Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set()
for pg in optimizer.optim.param_groups:
for param in pg['params']:
if param is None:
continue
param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
if param_id not in weight_map:
continue
filename = weight_map[param_id]
# If the file storing this param's states has been loaded before, skip it.
if filename in loaded_file:
continue
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
loaded_file.add(filename)
# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
device = param.device
if self.master_to_working_map is not None:
working_param = self.master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info['param2shape'][id(working_param)]
sharded_state = self.shard_from_complete_optimizer_state(state,
current_shape=working_param.shape,
original_shape=original_shape,
device=device,
inplace=True)
optimizer.optim.state[param] = sharded_state
sharded_optimizer_loading_epilogue(optimizer.optim)
if self.verbose:
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
# TODO(Baizhou): support this feature after implementing complete state_dict collection
raise NotImplementedError
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
# TODO(Baizhou): support this feature after implementing complete state_dict collection
raise NotImplementedError
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
# TODO(Baizhou): support this feature after implementing complete state_dict collection
raise NotImplementedError
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
# TODO(Baizhou): support this feature after implementing complete state_dict collection
raise NotImplementedError
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
Save lr scheduler to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
def link_master_and_working_param(self, working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor],
master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor]):
"""
Create mappings between working params (for forward/backward) and master params (for optimizer update) with the passed-in mappings.
This mapping can only be created when mixed precision is used.
The created mappings should be mappings from integer parameter addresses to parameter objects.
Args:
working_to_master_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from working parameter objects/addresses to master parameter objects.
master_to_working_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from master parameter objects/addresses to working parameter objects.
"""
self.working_to_master_map = dict()
for k, v in working_to_master_map.items():
if isinstance(k, torch.Tensor):
self.working_to_master_map[id(k)] = v
elif isinstance(k, int):
self.working_to_master_map[k] = v
else:
raise ValueError(
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!")
self.master_to_working_map = dict()
for k, v in master_to_working_map.items():
if isinstance(k, torch.Tensor):
self.master_to_working_map[id(k)] = v
elif isinstance(k, int):
self.master_to_working_map[k] = v
else:
raise ValueError(
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!")
@staticmethod
def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, original_shape: torch.Size,
dp_group: ProcessGroup, tp_group: ProcessGroup, use_zero: bool,
inplace: bool) -> OrderedDict:
"""
With given parameter and its optimizer states, gather the complete optimizer state for saving.
Args:
state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero.
param (torch.Tensor): The given parameter. It should be working_param when using Zero.
original_shape (torch.Size): The size of parameter before sharding.
dp_group (ProcessGroup): The process group of data parallel.
tp_group (ProcessGroup): The process group of tensor parallel.
use_zero (bool): Whether Zero is used.
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
Returns:
OrderedDict: The complete optimizer state of given parameter.
"""
dp_size = dist.get_world_size(dp_group)
tp_size = dist.get_world_size(tp_group)
current_shape = param.shape
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != 'step':
# First gather Zero shards.
if use_zero:
v = v.cuda()
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
dist.all_gather(gather_tensor, v, group=dp_group)
v = torch.stack(gather_tensor).view(-1)[:param.numel()].reshape_as(param)
# Then gather TP shards.
partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
if partition_dim is not None:
gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
dist.all_gather(gather_tensor, v, group=tp_group)
v = torch.cat(gather_tensor, dim=partition_dim)
state_[k] = v.detach().clone().cpu()
return state_
def shard_from_complete_optimizer_state(self, state: OrderedDict, current_shape: torch.Size,
original_shape: torch.Size, device: torch.device,
inplace: bool) -> OrderedDict:
"""
With complete optimizer states of a specific parameter loaded from checkpoint,
slice out the sharded optimizer states kept by current device.
Args:
state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint.
current_shape (torch.Size): The size of parameter after sharding.
original_shape (torch.Size): The size of parameter before sharding.
device (torch.device): The destination device of loaded optimizer states.
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
Returns:
OrderedDict: The sharded optimizer state of the given parameter.
"""
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != 'step':
# Shard state along tensor parallel group.
partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
if partition_dim is not None:
slice_size = current_shape[partition_dim]
v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
# Shard state along data parallel group when using Zero.
if self.use_zero:
padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
slice_size = v.numel() // self.dp_size
v = v.split(slice_size, dim=0)[self.dp_rank]
state_[k] = v.detach().clone().to(device)
return state_
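# An illustrative round trip (assumed shapes, not from the original file): suppose a weight of
# original_shape (8, 4) is column-sharded by tp_size = 2 so that current_shape is (8, 2).
# search_tp_partition_dim((8, 2), (8, 4), 2) returns 1, so gather_from_sharded_optimizer_state
# all-gathers the 'exp_avg' / 'exp_avg_sq' tensors along dim 1 before saving, and
# shard_from_complete_optimizer_state slices dim 1 back out on load (and, with ZeRO, additionally
# pads, flattens and splits the slice across the dp group).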
# coding=utf-8 # coding=utf-8
import copy
import os import os
import re import re
from collections import abc as container_abcs from collections import abc as container_abcs
...@@ -10,10 +11,17 @@ from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple ...@@ -10,10 +11,17 @@ from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.optim import Optimizer from torch.optim import Optimizer
from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype
from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model
from colossalai.interface import OptimizerWrapper from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor.d_tensor import is_distributed_tensor from colossalai.tensor.d_tensor import (
is_customized_distributed_tensor,
is_distributed_tensor,
to_global,
to_global_for_customized_distributed_tensor,
)
SAFE_WEIGHTS_NAME = "model.safetensors" SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin" WEIGHTS_NAME = "pytorch_model.bin"
...@@ -88,8 +96,35 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: ...@@ -88,8 +96,35 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
return False return False
def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Size, tp_size: int) -> Optional[int]:
"""
Given the current shape of parameter and the shape of parameter before sharding,
return the dimension along which the parameter is sharded when using tensor parallel.
If tensor parallel is not used, return None.
Args:
current_shape (torch.Size): The current shape of parameter after sharding.
original_shape (torch.Size): The shape of parameter before sharding.
tp_size (int): The size of tp group.
Returns:
Optional[int]: The dimension along which parameter is partitioned.
"""
partition_dim = None
for dim, length in enumerate(original_shape):
if length > current_shape[dim]:
partition_dim = dim
break
if partition_dim is not None:
assert original_shape[partition_dim] == tp_size * current_shape[partition_dim], \
f"The parameter isn't evenly distributed among tensor parallel group: \
shape before sharding {original_shape}, shape after sharding {current_shape}"
return partition_dim
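# Illustrative examples (assumed shapes, not from the original file):
#   search_tp_partition_dim(torch.Size([4096, 1024]), torch.Size([4096, 4096]), 4) -> 1
#   search_tp_partition_dim(torch.Size([4096, 4096]), torch.Size([4096, 4096]), 4) -> None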
# ====================================== # ======================================
# Helper functions for saving shard file # Helper classes and functions for saving shard file
# ====================================== # ======================================
def unwrap_optimizer(optimizer: OptimizerWrapper): def unwrap_optimizer(optimizer: OptimizerWrapper):
''' '''
...@@ -104,12 +139,97 @@ def unwrap_optimizer(optimizer: OptimizerWrapper): ...@@ -104,12 +139,97 @@ def unwrap_optimizer(optimizer: OptimizerWrapper):
return unwrapped_optim return unwrapped_optim
class StateDictSharder:
def __init__(self, size_per_shard: int) -> None:
self.max_shard_size = size_per_shard
self.current_block = OrderedDict()
self.current_block_size = 0
def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
tensor_size = calculate_tensor_size(tensor)
ret_block = None
ret_block_size = 0
# before we return the current block and create a new block,
# we need to ensure that the current block is not empty
if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
ret_block = self.current_block
ret_block_size = self.current_block_size
self.current_block = OrderedDict()
self.current_block_size = 0
self.current_block[name] = tensor
self.current_block_size += tensor_size
return ret_block, ret_block_size
def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]:
# A state might contain more than one tensor.
# e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
state_size = 0
isDTensor = False
for state_tensor in state.values():
# When state_tensor is not of Tensor class,
# e.g., a SGD optimizer with momentum set to 0 can have None as state
# The calculation of tensor size should be skipped to avoid error.
if not isinstance(state_tensor, torch.Tensor):
continue
# If the states are stored as DTensors, mark isDTensor as true.
if is_distributed_tensor(state_tensor):
isDTensor = True
state_size += calculate_tensor_size(state_tensor)
ret_block = None
ret_block_size = 0
# directly return if state is stored as distributed tensor
if isDTensor:
return ret_block, ret_block_size
# before we return the current block and create a new block,
# we need to ensure that the current block is not empty
if self.current_block_size + state_size > self.max_shard_size and self.current_block_size > 0:
ret_block = self.current_block
ret_block_size = self.current_block_size
self.current_block = OrderedDict()
self.current_block_size = 0
self.current_block[param_id] = state
self.current_block_size += state_size
return ret_block, ret_block_size
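# A minimal sketch of how the sharder is meant to be driven (`state_dict` here is a placeholder;
# the real callers are the _model_sharder/_optimizer_sharder generators and the shard_* helpers
# below):
#
#   sharder = StateDictSharder(size_per_shard=1024)
#   for name, tensor in state_dict.items():
#       block, block_size = sharder.append_param(name, tensor)
#       if block is not None:
#           yield block, block_size                                 # a full shard is emitted
#   yield sharder.current_block, sharder.current_block_size         # flush the last partial shard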
def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> torch.Tensor:
"""
Gather the complete parameter for saving if the passed-in param is distributed under the tp setting.
Args:
param (torch.Tensor): A model parameter, might be d_tensor.
keep_vars (bool, optional): Whether to return the parameter in calculation graph. Defaults to False.
Returns:
torch.Tensor: the complete parameter
"""
param_ = param if keep_vars else param.detach()
if is_distributed_tensor(param_):
return to_global(param_)
elif is_customized_distributed_tensor(param_):
return to_global_for_customized_distributed_tensor(param_)
else:
return param_
def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]], def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
checkpoint: str, checkpoint: str,
index_file: "CheckpointIndexFile", index_file: "CheckpointIndexFile",
base_filename: str, base_filename: str,
is_master: bool, is_master: bool,
use_safetensors: bool = False) -> int: use_safetensors: bool = False,
use_pp_format: bool = False) -> int:
''' '''
Save sharded state dict only on master rank, this method can be used by both model and optimizer states. Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
Args: Args:
...@@ -117,18 +237,21 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]] ...@@ -117,18 +237,21 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
checkpoint (str): The path of checkpoint directory as string. checkpoint (str): The path of checkpoint directory as string.
index_file (CheckpointIndexFile): The index file object to be updated. index_file (CheckpointIndexFile): The index file object to be updated.
base_filename (str): Decides the prefix of filenames of shards. base_filename (str): Decides the prefix of filenames of shards.
is_master (bool): Whether current rank is master. is_master (bool): Whether current rank is main process.
use_safetensors (bool): Whether to use safetensors to save checkpoint. use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False.
use_pp_format (bool, optional): Whether to save the files in pipeline format, including stage information. Defaults to False.
Returns: Returns:
int: the total size of shards int: the total size of shards
''' '''
total_size = 0 total_size = 0
shard_filenames = []
for idx, shard_pair in enumerate(sharded_state_dict): for idx, shard_pair in enumerate(sharded_state_dict):
shard, current_size = shard_pair
if not is_master: if not is_master:
del shard
continue continue
shard, current_size = shard_pair
shard_file = get_shard_filename(base_filename, idx) shard_file = get_shard_filename(base_filename, idx)
total_size = total_size + current_size total_size = total_size + current_size
for key in shard.keys(): for key in shard.keys():
...@@ -137,6 +260,11 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]] ...@@ -137,6 +260,11 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
# Only save on master rank. # Only save on master rank.
save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors) save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors)
shard_filenames.append(shard_file)
del shard
# Clean folder, deleted unneeded files.
clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format)
return total_size return total_size
...@@ -146,28 +274,17 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) ...@@ -146,28 +274,17 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size. given size.
""" """
current_block = {} state_dict_sharder = StateDictSharder(max_shard_size)
current_block_size = 0
for key, weight in state_dict.items(): for key, weight in state_dict.items():
ret_block = None
ret_block_size = 0
if not is_distributed_tensor(weight): if not is_distributed_tensor(weight):
weight_size = calculate_tensor_size(weight) block, block_size = state_dict_sharder.append_param(key, weight)
# If this weight is going to tip up over the maximal size, we split.
if current_block_size + weight_size > max_shard_size and current_block_size > 0:
ret_block = current_block
ret_block_size = current_block_size
current_block = {}
current_block_size = 0
current_block[key] = weight
current_block_size += weight_size
if ret_block != None: if block != None:
yield ret_block, ret_block_size yield block, block_size
yield current_block, current_block_size # Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
...@@ -178,47 +295,207 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> ...@@ -178,47 +295,207 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
# Only split state_dict['state']; state_dict['param_group'] is not considered in this function. # Only split state_dict['state']; state_dict['param_group'] is not considered in this function.
states = state_dict['state'] states = state_dict['state']
state_dict_sharder = StateDictSharder(max_shard_size)
current_block = {}
current_block_size = 0
for param_id, state in states.items(): for param_id, state in states.items():
block, block_size = state_dict_sharder.append_optim_state(param_id, state)
if block != None:
yield block, block_size
ret_block = None # Return the last block in sharder.
ret_block_size = 0 yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
# A state might contain more than one tensors.
# e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
state_size = 0
isDTensor = False
for state_tensor in state.values():
# When state_tensor is not of Tensor class, # ======================================
# e.g., a SGD optimizer with momentum set to 0 can have None as state # Helper functions for saving state dict
# The calculation of tensor size should be skipped to avoid error. # ======================================
if not isinstance(state_tensor, torch.Tensor):
continue
# If the states are stored as DTensors, mark isDTensor as true.
if is_distributed_tensor(state_tensor):
isDTensor = True
state_size += calculate_tensor_size(state_tensor)
if not isDTensor: def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
"""
Save state dict to checkpoint.
Args:
state_dict (dict): state dict.
checkpoint_file_path (str): path to the checkpoint file.
use_safetensors (bool): whether to use safetensors to save the checkpoint.
"""
if use_safetensors:
assert is_safetensors_available(), "safetensors is not available."
assert checkpoint_file_path.endswith('.safetensors'), \
"safetensors only supports .safetensors suffix for checkpoint file."
from safetensors.torch import save_file as safe_save_file
safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
else:
torch.save(state_dict, checkpoint_file_path)
def save_param_groups(state_dict: dict, group_file_path: str) -> None:
"""
Save information of param_groups to given file path.
Args:
state_dict (dict): state dict.
group_file_path (str): path to the group file.
"""
param_groups = state_dict["param_groups"]
torch.save(param_groups, group_file_path)
def clean_folder(checkpoint_path: str,
weights_name: str,
shard_filenames: List[str],
is_master: bool = True,
use_pp_format: bool = False):
"""
Clean the unneeded files in checkpoint directory after shards of state_dict have been saved.
Args:
checkpoint_path (str): Path to the checkpoint directory.
weights_name (str): Decides the prefix of filenames of weight shards.
shard_filenames (List[str]): The list of saved shard filenames which should not be removed.
is_master (bool, optional): Whether current rank is main process. Defaults to True.
use_pp_format (bool, optional): Whether to save the files in pipeline format, including stage information. Defaults to False.
"""
if is_master:
for filename in os.listdir(checkpoint_path):
full_filename = os.path.join(checkpoint_path, filename)
weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
if not use_pp_format:
reg = re.compile(r"(.*?)-\d{5}")
else:
# When this checkpoint is created by pipeline parallel process, the pattern is a little different.
reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")
if (filename.startswith(weights_no_suffix) and os.path.isfile(full_filename)
and filename not in shard_filenames and reg.fullmatch(filename_no_suffix) is not None):
os.remove(full_filename)
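# For illustration (assumed filenames): with weights_name "pytorch_model.bin" and the default
# (non-pipeline) pattern r"(.*?)-\d{5}", a stale "pytorch_model-00003.bin" left over from an
# earlier save would be removed, while "pytorch_model.bin.index.json" and the freshly written
# shards listed in `shard_filenames` are kept.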
def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = True):
"""
Save config.json/generation_config.json if model is a Huggingface pretrained model.
This method can only be called when a model is saved in a sharded way.
Args:
model (nn.Module): The model whose config should be saved if it's a huggingface model.
checkpoint_path (str): Path to the checkpoint directory.
is_master (bool): Whether current rank is main process.
"""
if not isinstance(model, PreTrainedModel):
return
model = unwrap_huggingface_model(model)
# save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
dtype = get_parameter_dtype(model)
model.config.torch_dtype = str(dtype).split(".")[1]
# Attach architecture to the config
model.config.architectures = [model.__class__.__name__]
# Save the config
if is_master:
model.config.save_pretrained(checkpoint_path)
if model.can_generate():
model.generation_config.save_pretrained(checkpoint_path)
def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
"""
Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
only one tensor.
Args:
tensor (Tensor): tensor to be saved.
index_file (CheckpointIndexFile): path to the checkpoint file.
size_per_shard (int): size per shard in MB.
"""
root_path = index_file.root_path
output_root_path = root_path.joinpath('dtensor')
# create directory
output_root_path.mkdir(exist_ok=True)
# save tensor to this directory
# TODO(YuliangLiu): get index of the tensor shard
# e.g. index =
index = 0
# save tensor to file
ckpt_file_name = generate_dtensor_file_name(name, index, use_safetensors)
ckpt_file_path = output_root_path.joinpath(ckpt_file_name)
# dtensor ckpt file always contains only one tensor
state_dict = {name: tensor}
save_state_dict(state_dict, str(ckpt_file_path), use_safetensors)
# update the weight map
# * means all shards
ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors)
index_file.append_weight_map(name, ckpt_file_name_in_weight_map)
def get_checkpoint_file_suffix(use_safetensors: bool) -> str:
"""
Get checkpoint file suffix.
Args:
use_safetensors (bool): whether to use safetensors to save the checkpoint.
Returns:
str: checkpoint file suffix.
"""
if use_safetensors:
return '.safetensors'
else:
return '.bin'
def generate_checkpoint_shard_file_name(index: int,
total_number: int,
use_safetensors: bool,
prefix: str = None) -> str:
"""
Generate checkpoint shard file name.
Args:
index (int): index of the shard.
total_number (int): total number of shards.
use_safetensors (bool): whether to use safetensors to save the checkpoint.
prefix (str): prefix of the shard file name. Default: None.
if current_block_size + state_size > max_shard_size and current_block_size > 0: Returns:
ret_block = current_block str: checkpoint shard file name.
ret_block_size = current_block_size """
current_block = {} suffix = get_checkpoint_file_suffix(use_safetensors)
current_block_size = 0
current_block[param_id] = state if prefix is None:
current_block_size += state_size return f"{index:05d}-of-{total_number:05d}.{suffix}"
else:
return f"{prefix}-{index:05d}-of-{total_number:05d}.{suffix}"
if ret_block != None:
yield ret_block, ret_block_size
yield current_block, current_block_size def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: bool) -> str:
"""
Generate dtensor file name.
Args:
param_name (str): name of the distributed parameter.
index (int): index of the shard.
use_safetensors (bool): whether to use safetensors to save the checkpoint.
Returns:
str: dtensor file name.
"""
suffix = get_checkpoint_file_suffix(use_safetensors)
return f'{param_name}.{index}.{suffix}'
# ========================================
# Helper functions for loading state dict
# ========================================
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False): def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
...@@ -331,17 +608,21 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str ...@@ -331,17 +608,21 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
return id_map return id_map
def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict): def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict, strict: bool = False):
r"""Copies states from `state_dict` into an Optimizer object. r"""Copies states from `state_dict` into an Optimizer object.
Args: Args:
optimizer(Optimizer): An initialized Optimizer object to be loaded optimizer(Optimizer): An initialized Optimizer object to be loaded
state_dict(dict): a mapping from tensor index (an integer) state_dict(dict): A mapping from tensor index (an integer)
to its states to be loaded (a mapping from state name to a tensor). to its states to be loaded (a mapping from state name to a tensor).
id_map(dict): a mapping from tensor index (an integer) id_map(dict): A mapping from tensor index (an integer)
to its corresponding parameter (a tensor) whose states will be updated. to its corresponding parameter (a tensor) whose states will be updated.
strict(bool, optional): If set to True, only load the parameters with its id in id_map. Defaults to False.
""" """
# Ensure that the keys of state_dict are integers.
state_dict = {int(k): v for k, v in state_dict.items()}
def cast(param, value, key=None): def cast(param, value, key=None):
r"""Make a deep copy of value, casting all tensors to device of param.""" r"""Make a deep copy of value, casting all tensors to device of param."""
if isinstance(value, torch.Tensor): if isinstance(value, torch.Tensor):
...@@ -368,7 +649,7 @@ def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: d ...@@ -368,7 +649,7 @@ def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: d
if k in id_map: if k in id_map:
param = id_map[k] param = id_map[k]
new_states[param] = cast(param, v) new_states[param] = cast(param, v)
else: elif not strict:
new_states[k] = v new_states[k] = v
optimizer.state.update(new_states) optimizer.state.update(new_states)
...@@ -386,165 +667,6 @@ def sharded_optimizer_loading_epilogue(optimizer: Optimizer): ...@@ -386,165 +667,6 @@ def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
optimizer.defaults.setdefault('differentiable', False) optimizer.defaults.setdefault('differentiable', False)
# ======================================
# Helper functions for saving state dict
# ======================================
def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
"""
Save state dict to checkpoint.
Args:
state_dict (dict): state dict.
checkpoint_file_path (str): path to the checkpoint file.
use_safetensors (bool): whether to use safetensors to save the checkpoint.
"""
if use_safetensors:
assert is_safetensors_available(), "safetensors is not available."
assert checkpoint_file_path.endswith('.safetensors'), \
"safetensors only supports .safetensors suffix for checkpoint file."
from safetensors.torch import save_file as safe_save_file
safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
else:
torch.save(state_dict, checkpoint_file_path)
def save_param_groups(state_dict: dict, group_file_path: str) -> None:
"""
Save information of param_groups to given file path.
Args:
state_dict (dict): state dict.
group_file_path (str): path to the group file.
"""
param_groups = state_dict["param_groups"]
torch.save(param_groups, group_file_path)
def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
"""
Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
only one tensor.
Args:
tensor (Tensor): tensor to be saved.
index_file (CheckpointIndexFile): path to the checkpoint file.
size_per_shard (int): size per shard in MB.
"""
root_path = index_file.root_path
output_root_path = root_path.joinpath('dtensor')
# create directory
output_root_path.mkdir(exist_ok=True)
# save tensor to this directory
# TODO(YuliangLiu): get index of the tensor shard
# e.g. index =
index = 0
# save tensor to file
ckpt_file_name = generate_dtensor_file_name(name, index, use_safetensors)
ckpt_file_path = output_root_path.joinpath(ckpt_file_name)
# dtensor ckpt file always contains only one tensor
state_dict = {name: tensor}
save_state_dict(state_dict, str(ckpt_file_path), use_safetensors)
# update the weight map
# * means all shards
ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors)
index_file.append_weight_map(name, ckpt_file_name_in_weight_map)
def get_checkpoint_file_suffix(use_safetensors: bool) -> str:
"""
Get checkpoint file suffix.
Args:
use_safetensors (bool): whether to use safetensors to save the checkpoint.
Returns:
str: checkpoint file suffix.
"""
if use_safetensors:
return '.safetensors'
else:
return '.bin'
def generate_checkpoint_shard_file_name(index: int,
total_number: int,
use_safetensors: bool,
prefix: str = None) -> str:
"""
Generate checkpoint shard file name.
Args:
index (int): index of the shard.
total_number (int): total number of shards.
use_safetensors (bool): whether to use safetensors to save the checkpoint.
prefix (str): prefix of the shard file name. Default: None.
Returns:
str: checkpoint shard file name.
"""
suffix = get_checkpoint_file_suffix(use_safetensors)
if prefix is None:
return f"{index:05d}-of-{total_number:05d}.{suffix}"
else:
return f"{prefix}-{index:05d}-of-{total_number:05d}.{suffix}"
def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: bool) -> str:
"""
Generate dtensor file name.
Args:
param_name (str): name of the distributed parameter.
index (int): index of the shard.
use_safetensors (bool): whether to use safetensors to save the checkpoint.
Returns:
str: dtensor file name.
"""
suffix = get_checkpoint_file_suffix(use_safetensors)
return f'{param_name}.{index}.{suffix}'
def save_state_dict_as_shard(
state_dict: dict,
checkpoint_path: str,
index: int,
total_number: int,
use_safetensors: bool,
prefix: str = None,
) -> None:
"""
Save state dict as shard.
Args:
state_dict (dict): state dict.
checkpoint_path (str): path to the checkpoint file.
index (int): index of the shard.
total_number (int): total number of shards.
prefix (str): prefix of the shard file name.
use_safetensors (bool): whether to use safetensors to save the checkpoint.
"""
# generate the shard name
shard_file_name = generate_checkpoint_shard_file_name(index, total_number, use_safetensors, prefix)
shard_file_path = Path(checkpoint_path).joinpath(shard_file_name).absolute()
# save the shard
save_state_dict(state_dict, str(shard_file_path), use_safetensors)
# ========================================
# Helper functions for loading state dict
# ========================================
def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
""" """
Check whether the checkpoint has an index file. Check whether the checkpoint has an index file.
...@@ -654,5 +776,5 @@ def get_shard_filename(weights_name: str, idx: int): ...@@ -654,5 +776,5 @@ def get_shard_filename(weights_name: str, idx: int):
get shard file name get shard file name
""" """
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin") shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors") shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors")
return shard_file return shard_file
...@@ -94,17 +94,23 @@ class ProcessGroupMesh: ...@@ -94,17 +94,23 @@ class ProcessGroupMesh:
return np.unravel_index(rank, shape) return np.unravel_index(rank, shape)
@staticmethod @staticmethod
def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...]) -> int: def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = 'raise') -> int:
"""Convert a coordinate to a rank. """Convert a coordinate to a rank.
mode: ['raise', 'wrap', 'clip'], see https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html.
With 'wrap', an index that is out of range is wrapped around.
For instance, ravel((0, i, 0), (1, 2, 1), 'wrap') returns (i % 2)
Args: Args:
coords (Tuple[int, ...]): Coordinate to be converted. coords (Tuple[int, ...]): Coordinate to be converted.
shape (Tuple[int, ...]): Shape of the process group mesh. shape (Tuple[int, ...]): Shape of the process group mesh.
mode (str, optional): The mode for numpy.ravel_multi_index. Defaults to 'raise'.
Returns: Returns:
int: Rank of the coordinate. int: Rank of the coordinate.
""" """
return np.ravel_multi_index(coord, shape)
assert mode in ["raise", "wrap", "clip"]
return np.ravel_multi_index(coord, shape, mode)
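# Quick sanity check of the 'wrap' mode (illustrative; relies only on numpy's documented behaviour
# of ravel_multi_index):
#   np.ravel_multi_index((0, 3, 0), (1, 2, 1), mode='wrap') == 1   # 3 wraps to 3 % 2
#   np.ravel_multi_index((0, 3, 0), (1, 2, 1))                     # raises ValueError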
def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup: def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
"""Get the process group with the given ranks. It the process group doesn't exist, it will be created. """Get the process group with the given ranks. It the process group doesn't exist, it will be created.
......
...@@ -173,14 +173,10 @@ class PipelineP2PCommunication: ...@@ -173,14 +173,10 @@ class PipelineP2PCommunication:
Returns: Returns:
Any: The input tensor or input tensor list. Any: The input tensor or input tensor list.
""" """
if self.stage_manager.is_first_stage(): if prev_rank is None:
input_tensor = None prev_rank = self.stage_manager.get_prev_rank()
else: cur_rank = self.stage_manager.get_rank()
if prev_rank is None: input_tensor = _recv_object(prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank))
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
input_tensor = _recv_object(prev_rank, cur_rank,
self.stage_manager.get_p2p_process_group(prev_rank, cur_rank))
return input_tensor return input_tensor
...@@ -193,14 +189,11 @@ class PipelineP2PCommunication: ...@@ -193,14 +189,11 @@ class PipelineP2PCommunication:
Returns: Returns:
Any: The input gradient tensor or gradient tensor list. Any: The input gradient tensor or gradient tensor list.
""" """
if self.stage_manager.is_last_stage(): if next_rank is None:
output_tensor_grad = None next_rank = self.stage_manager.get_next_rank()
else: cur_rank = self.stage_manager.get_rank()
if next_rank is None: output_tensor_grad = _recv_object(next_rank, cur_rank,
next_rank = self.stage_manager.get_next_rank() self.stage_manager.get_p2p_process_group(next_rank, cur_rank))
cur_rank = self.stage_manager.get_rank()
output_tensor_grad = _recv_object(next_rank, cur_rank,
self.stage_manager.get_p2p_process_group(next_rank, cur_rank))
return output_tensor_grad return output_tensor_grad
...@@ -211,12 +204,10 @@ class PipelineP2PCommunication: ...@@ -211,12 +204,10 @@ class PipelineP2PCommunication:
output_object (Any): Object to be sent. output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor. next_rank (int, optional): The rank of the recipient of the tensor.
""" """
if not self.stage_manager.is_last_stage(): if next_rank is None:
if next_rank is None: next_rank = self.stage_manager.get_next_rank()
next_rank = self.stage_manager.get_next_rank() cur_rank = self.stage_manager.get_rank()
cur_rank = self.stage_manager.get_rank() _send_object(output_object, cur_rank, next_rank, self.stage_manager.get_p2p_process_group(cur_rank, next_rank))
_send_object(output_object, cur_rank, next_rank,
self.stage_manager.get_p2p_process_group(cur_rank, next_rank))
def send_backward(self, input_object: Any, prev_rank: int = None) -> None: def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline. """Sends the gradient tensor to the previous stage in pipeline.
...@@ -225,9 +216,7 @@ class PipelineP2PCommunication: ...@@ -225,9 +216,7 @@ class PipelineP2PCommunication:
input_object (Any): Object to be sent. input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor prev_rank (int, optional): The rank of the recipient of the tensor
""" """
if not self.stage_manager.is_first_stage(): if prev_rank is None:
if prev_rank is None: prev_rank = self.stage_manager.get_prev_rank()
prev_rank = self.stage_manager.get_prev_rank() cur_rank = self.stage_manager.get_rank()
cur_rank = self.stage_manager.get_rank() _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
_send_object(input_object, cur_rank, prev_rank,
self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
from typing import Any, List, Optional from collections import OrderedDict
from typing import Any, List, Optional, Tuple
import torch import torch
import torch.cuda import torch.cuda
from torch.nn import Module from torch.nn import Module
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten from torch.utils._pytree import (
SUPPORTED_NODES,
LeafSpec,
TreeSpec,
_is_leaf,
_register_pytree_node,
tree_flatten,
tree_map,
tree_unflatten,
)
# this registration is for torch versions at or below 1.13.1 and may be removed in the future
def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Any]:
return list(d.values()), list(d.keys())
def _odict_unflatten(values: List[Any], context: Any) -> 'OrderedDict[Any, Any]':
return OrderedDict((key, value) for key, value in zip(context, values))
_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
def tree_map_hf(fn: Any, pytree: Any):
flat_args, spec = tree_flatten_hf(pytree)
return tree_unflatten([fn(i) for i in flat_args], spec)
# use this flatten function to handle the ModelingOutput Class instance.
def tree_flatten_hf(pytree: Any) -> Tuple[List[Any], TreeSpec]:
"""Flattens a pytree into a list of values an a TreeSpec that can be used
to reconstruct the pytree.
"""
if isinstance(pytree, OrderedDict):
node_type = OrderedDict
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
child_pytrees, context = flatten_fn(pytree)
# Recursively flatten the children
result: List[Any] = []
children_specs: List['TreeSpec'] = []
for child in child_pytrees:
flat, child_spec = tree_flatten_hf(child)
result += flat
children_specs.append(child_spec)
return result, TreeSpec(node_type, context, children_specs)
else:
result, tree_spec = tree_flatten(pytree)
return result, tree_spec
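# Illustrative behaviour (assumed data; HuggingFace ModelOutput subclasses OrderedDict, which is
# why the OrderedDict branch above is enough to handle such outputs):
#   out = OrderedDict(logits=torch.ones(2, 4), loss=torch.tensor(0.5))
#   leaves, spec = tree_flatten_hf(out)        # leaves == [logits tensor, loss tensor]
#   restored = tree_unflatten(leaves, spec)    # an OrderedDict with the same keys and values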
def to_device(x: Any, device: Optional[torch.device] = None) -> Any: def to_device(x: Any, device: Optional[torch.device] = None) -> Any:
...@@ -104,7 +154,7 @@ def detach(x: Any) -> Any: ...@@ -104,7 +154,7 @@ def detach(x: Any) -> Any:
return x return x
def merge_batch(data: List[Any]) -> Any: def merge_batch(data: List[Any], batch_size_dim=0) -> Any:
"""Merge micro batches into a batch. """Merge micro batches into a batch.
Args: Args:
...@@ -118,12 +168,17 @@ def merge_batch(data: List[Any]) -> Any: ...@@ -118,12 +168,17 @@ def merge_batch(data: List[Any]) -> Any:
flattened_data = [] flattened_data = []
tree_spec = None tree_spec = None
for d in data: for d in data:
elems, tree_spec = tree_flatten(d) # elems should be an instance of OrderedDict
elems, tree_spec = tree_flatten_hf(d)
flattened_data.append(elems) flattened_data.append(elems)
merged_data = [] merged_data = []
for elem_batch in zip(*flattened_data): for elem_batch in zip(*flattened_data):
if isinstance(elem_batch[0], torch.Tensor): if isinstance(elem_batch[0], torch.Tensor):
merged_data.append(torch.cat(elem_batch, dim=0)) if len(elem_batch[0].shape) == 0: # set loss to None in pipeline outputs
merged_data.append(None)
else:
merged_data.append(torch.cat(elem_batch, dim=batch_size_dim))
else: else:
merged_data.append(list(elem_batch)) merged_data.append(list(elem_batch))
return tree_unflatten(merged_data, tree_spec) return tree_unflatten(merged_data, tree_spec)
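# Sketch of the merge (assumed tensors; mirrors the code above): two micro-batch outputs
#   OrderedDict(logits=torch.ones(2, 4)) and OrderedDict(logits=torch.ones(2, 4))
# are flattened with tree_flatten_hf, the logits are concatenated along batch_size_dim into a
# (4, 4) tensor, and zero-dim entries such as a scalar loss are replaced by None in the merged
# pipeline output.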
from functools import partial
from typing import Any, Callable, Iterable, List, Optional, Union
import torch
import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_map
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.cuda import get_current_device
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
from .base import PipelineSchedule
class InterleavedSchedule(PipelineSchedule):
def __init__(self, num_microbatches: int, num_model_chunks: int, stage_manager: PipelineStageManager) -> None:
self.num_model_chunks = num_model_chunks
assert num_microbatches % self.num_model_chunks == 0, \
"Number of microbatches should be an integer multiple of number of model chunks"
super().__init__(stage_manager)
self.comm = PipelineP2PCommunication(stage_manager)
self.num_microbatches = num_microbatches
self.batch: Optional[Any] = None
self.batch_size: Optional[int] = None
self.microbatch_offset: Optional[int] = None
self.microbatch_size: Optional[int] = None
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.
Args:
data_iter (Iterable): Data iterator.
device (Optional[torch.device], optional): Target device. Defaults to None.
"""
batch = next(data_iter)
if device is not None:
batch = tree_map(partial(to_device, device=device), batch)
self.batch = batch
self.batch_size = get_batch_size(batch)
self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
assert self.batch_size % self.num_microbatches == 0, \
"Batch size should divided by the number of microbatches"
self.microbatch_size = self.batch_size // self.num_microbatches
def load_micro_batch(self, model_chunk_id: int) -> Any:
"""Load a micro batch from the current batch.
Args:
model_chunk_id (int): The current model chunk idx.
Returns:
Any: Micro batch.
"""
micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
self.microbatch_offset[model_chunk_id] += self.microbatch_size
return tree_map(partial(to_device, device=get_current_device()), micro_batch)
def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int:
"""Helper method to get the model chunk ID given the iteration number.
Args:
microbatch_id (int): the current microbatch idx
forward (bool): if is the forward process
Returns:
int: The model chunk idx of the input microbatch_id
"""
microbatch_id_in_group = (microbatch_id) % (self.stage_manager.num_stages * self.num_model_chunks)
model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
if not forward:
model_chunk_id = (self.num_model_chunks - model_chunk_id - 1)
return model_chunk_id
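# Illustrative mapping (assumed pipeline of 4 stages with 2 model chunks): microbatch ids 0-3 run
# chunk 0 and ids 4-7 run chunk 1 in the forward direction; the order is reversed for backward,
# so e.g. get_model_chunk_id(2, forward=False) == 1.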
def is_first_stage(self, model_chunk_id: int) -> bool:
"""Is the current virtual stage the first stage
Args:
model_chunk_id (int): The current model chunk idx.
Returns:
bool: Whether the current virtual stage is the first stage.
"""
if self.stage_manager.is_first_stage() and model_chunk_id == 0:
return True
return False
def is_last_stage(self, model_chunk_id: int) -> bool:
"""Is the current virtual stage the last stage
Args:
model_chunk_id (int): The current model chunk idx.
Returns:
bool: Whether the current virtual stage is the last stage.
"""
if self.stage_manager.is_last_stage() and model_chunk_id == self.num_model_chunks - 1:
return True
return False
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
For interleaved 1F1B.
Args:
model_chunk_id (int): The current model chunk idx.
prev_rank (int, optional): The rank of the source of the tensor.
Returns:
Any: The input tensor or input tensor list.
"""
if self.is_first_stage(model_chunk_id):
input_tensor = None
else:
input_tensor = self.comm.recv_forward(prev_rank)
return input_tensor
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
For interleaved 1F1B.
Args:
model_chunk_id (int): The current model chunk idx.
next_rank (int, optional): The rank of the source of the tensor.
Returns:
Any: The input gradient tensor or gradient tensor list.
"""
if self.is_last_stage(model_chunk_id):
output_tensor_grad = None
else:
output_tensor_grad = self.comm.recv_backward(next_rank)
return output_tensor_grad
def send_forward(self, model_chunk_id, output_object: Any, next_rank: int = None) -> None:
"""Sends the input tensor to the next stage in pipeline.
For interleaved 1F1B.
Args:
model_chunk_id (int): The current model chunk idx.
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.is_last_stage(model_chunk_id):
self.comm.send_forward(output_object, next_rank)
def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
For interleaved 1F1B.
Args:
model_chunk_id (int): The current model chunk idx.
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor
"""
if not self.is_first_stage(model_chunk_id):
self.comm.send_backward(input_object, prev_rank)
def forward_step(self,
model_chunk: Module,
model_chunk_id: int,
input_obj: Optional[dict],
criterion: Callable,
accum_loss: Optional[torch.Tensor] = None,
outputs: Optional[List[Any]] = None) -> Union[torch.Tensor, dict]:
"""Forward one step of the pipeline
Args:
            model_chunk (Module): Model chunk to be run.
            model_chunk_id (int): The current model chunk idx.
input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None.
criterion (Callable): Criterion to calculate loss.
accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
Returns:
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
"""
micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
# for the first stage, input_obj is None
        # for the non-first stage, input_obj is the output of the previous stage and it must be a dict
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
if self.is_last_stage(model_chunk_id):
loss = criterion(output_obj, micro_batch) / self.num_microbatches
if accum_loss is not None:
accum_loss.add_(loss.detach())
if outputs is not None:
outputs.append(tree_map(detach, output_obj))
return loss
else:
return output_obj
def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict],
output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict]) -> Optional[dict]:
"""Backward one step of the pipeline
Args:
optimizer (OptimizerWrapper): Optimizer to update the model
input_obj (Optional[dict]): Output of the previous stage. If it is the first stage, the `input_obj` is None.
output_obj (Union[dict, torch.Tensor]): Output of the current stage. If it is the last stage, the output is the loss (Tensor).
output_obj_grad (dict): Gradient of the `output_obj`. If it is the last stage, the `output_obj_grad` is None.
Returns:
Optional[dict]: Gradient of the `input_obj`. If it is the first stage, the `input_obj_grad` is None.
"""
# Retain the grad on the input_obj.
tree_map(retain_grad, input_obj)
# Backward pass.
if output_obj_grad is None:
optimizer.backward(output_obj)
else:
if "backward_tensor_keys" not in output_obj:
for k, grad in output_obj_grad.items():
optimizer.backward_by_grad(output_obj[k], grad)
else:
for k, grad in output_obj_grad.items():
output_obj[k].grad = grad
for k in output_obj["backward_tensor_keys"]:
tensor_to_backward = output_obj[k]
optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
# Collect the grad of the input_obj.
input_obj_grad = None
if input_obj is not None:
input_obj_grad = {}
for k, v in input_obj.items():
if isinstance(v, torch.Tensor) and v.grad is not None:
input_obj_grad[k] = v.grad
return input_obj_grad
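    # Illustrative example (added note): if output_obj = {'hidden_states': h, 'backward_tensor_keys': ['hidden_states']}
    # and output_obj_grad = {'hidden_states': g}, backward_step attaches g to h as its .grad and then
    # starts the backward pass from h via optimizer.backward_by_grad(h, g).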
def forward_backward_step(self,
model_chunk: Module,
optimizer: OptimizerWrapper,
data_iter: Iterable,
criterion: Callable[..., Any],
return_loss: bool = False,
return_outputs: bool = False) -> dict:
"""Runs interleaved 1F1B schedule, with communication between pipeline stages.
Args:
model_chunk (List[Module]): Model Chunk to be trained.
optimizer (OptimizerWrapper): Optimizer to be used.
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
            return_loss (bool, optional): Whether to return loss. Defaults to False.
            return_outputs (bool, optional): Whether to return model outputs. Defaults to False.
Returns:
dict: A dict with keys: 'loss' and 'outputs'.
"""
forward_only = not torch.is_grad_enabled()
self.load_batch(data_iter)
num_model_chunks = len(model_chunk)
        # num_warmup_microbatches is the number of warmup steps during which not all pipeline stages are active yet
num_microbatches = self.num_microbatches * num_model_chunks
if forward_only:
num_warmup_microbatches = num_microbatches
else:
num_warmup_microbatches = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
num_warmup_microbatches += (num_model_chunks - 1) * self.stage_manager.num_stages
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
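        # Worked example (added note): with 4 pipeline stages, 2 model chunks and
        # self.num_microbatches = 8 (i.e. 16 virtual microbatches), stage 1 runs
        # (4 - 1 - 1) * 2 + (2 - 1) * 4 = 8 warmup microbatches and 16 - 8 = 8
        # microbatches in the steady 1F1B phase.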
# Input, output tensors only need to be saved when doing backward passes
input_objs = None
output_objs = None
if not forward_only:
input_objs = [[] for _ in range(num_model_chunks)]
output_objs = [[] for _ in range(num_model_chunks)]
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
if return_loss and self.stage_manager.is_last_stage():
accum_loss = torch.zeros(1, device=get_current_device())
else:
accum_loss = None
        # all stages except the first start by receiving the activation for model chunk 0
        # (recv_forward returns None on the first stage)
        input_obj = self.recv_forward(0)
        if not forward_only:
            input_objs[0].append(input_obj)
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
model_chunk_id = self.get_model_chunk_id(i, forward=True)
            # receive first on the first stage to avoid sending and receiving at the same time
if self.stage_manager.is_first_stage():
input_obj = self.recv_forward(model_chunk_id)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
self.send_forward(model_chunk_id, output_obj)
if not forward_only:
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
else:
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if not forward_only:
output_objs[model_chunk_id].append(output_obj)
self.send_forward(model_chunk_id, output_obj)
if num_microbatches_remaining == 0 and i + 1 == num_warmup_microbatches:
break
else:
model_chunk_id = self.get_model_chunk_id(i + 1, forward=True)
input_obj = self.recv_forward(model_chunk_id)
if not forward_only:
input_objs[model_chunk_id].append(input_obj)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True)
last_iteration = (i == (num_microbatches_remaining - 1))
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if forward_only:
self.send_forward(model_chunk_id, output_obj)
if not last_iteration:
input_obj = self.recv_forward(model_chunk_id)
else:
self.send_forward(model_chunk_id, output_obj)
# Add input_obj and output_obj to end of list.
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
model_chunk_id = self.get_model_chunk_id(i, forward=False)
output_obj_grad = self.recv_backward(model_chunk_id)
                # Pop input_obj and output_obj from the start of the list for
                # the backward pass.
input_obj = input_objs[model_chunk_id].pop(0)
output_obj = output_objs[model_chunk_id].pop(0)
# backward
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
if last_iteration:
input_obj = None
else:
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches + 1, forward=True)
input_obj = self.recv_forward(model_chunk_id)
model_chunk_id = self.get_model_chunk_id(i, forward=False)
self.send_backward(model_chunk_id, input_obj_grad)
# Run cooldown backward passes.
if not forward_only:
for i in range(num_microbatches_remaining, num_microbatches):
model_chunk_id = self.get_model_chunk_id(i, forward=False)
# print(f"{self.stage_manager.stage}/{model_chunk_id}: {len(input_objs[model_chunk_id])} {len(output_objs[model_chunk_id])} {i}")
input_obj = input_objs[model_chunk_id].pop(0)
output_obj = output_objs[model_chunk_id].pop(0)
output_obj_grad = self.recv_backward(model_chunk_id)
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(model_chunk_id, input_obj_grad)
if outputs is not None:
outputs = merge_batch(outputs)
return {'loss': accum_loss, 'outputs': outputs}
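# Minimal usage sketch (added note; `pg_mesh`, `chunks`, `optimizer`, `loader` and `criterion`
# are placeholder names, not part of this file):
#
#   stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0, is_virtual=True)
#   schedule = InterleavedSchedule(num_microbatches=8, num_model_chunks=2, stage_manager=stage_manager)
#   result = schedule.forward_backward_step(chunks, optimizer, iter(loader), criterion, return_loss=True)
#   loss = result['loss']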
...@@ -6,25 +6,47 @@ import torch.cuda ...@@ -6,25 +6,47 @@ import torch.cuda
from torch.nn import Module from torch.nn import Module
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
from colossalai.interface import OptimizerWrapper from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device from ._utils import (
detach,
get_batch_size,
get_micro_batch,
merge_batch,
model_forward,
retain_grad,
to_device,
tree_map_hf,
)
from .base import PipelineSchedule from .base import PipelineSchedule
class OneForwardOneBackwardSchedule(PipelineSchedule): class OneForwardOneBackwardSchedule(PipelineSchedule):
def __init__(self, num_microbatches: int, stage_manager: PipelineStageManager) -> None: def __init__(self,
stage_manager: PipelineStageManager,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None) -> None:
"""1F1B pipeline schedule.
Args:
stage_manager (PipelineStageManager): Pipeline stage manager
num_microbatches (Optional[int], optional): The number of microbatches. If not provided, it will be derived from microbatch size. Defaults to None.
microbatch_size (Optional[int], optional): Microbatch size. If num_microbatches is provided, this will be ignored. Defaults to None.
"""
super().__init__(stage_manager) super().__init__(stage_manager)
assert num_microbatches is not None or microbatch_size is not None, \
"Either num_microbatches or microbatch_size should be provided"
self.comm = PipelineP2PCommunication(stage_manager) self.comm = PipelineP2PCommunication(stage_manager)
self.num_microbatches = num_microbatches self.num_microbatches = num_microbatches
self.microbatch_size = microbatch_size
self.batch: Optional[Any] = None self.batch: Optional[Any] = None
self.batch_size: Optional[int] = None self.batch_size: Optional[int] = None
self.microbatch_offset: Optional[int] = None self.microbatch_offset: Optional[int] = None
self.microbatch_size: Optional[int] = None self._use_microbatch_size = num_microbatches is None
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator. """Load a batch from data iterator.
...@@ -39,9 +61,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -39,9 +61,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
self.batch = batch self.batch = batch
self.batch_size = get_batch_size(batch) self.batch_size = get_batch_size(batch)
self.microbatch_offset = 0 self.microbatch_offset = 0
assert self.batch_size % self.num_microbatches == 0, \ if not self._use_microbatch_size:
"Batch size should divided by the number of microbatches" assert self.batch_size % self.num_microbatches == 0, \
self.microbatch_size = self.batch_size // self.num_microbatches "Batch size should be divisible by the number of microbatches"
self.microbatch_size = self.batch_size // self.num_microbatches
else:
assert self.batch_size % self.microbatch_size == 0, \
"Batch size should divided by the microbatch size"
self.num_microbatches = self.batch_size // self.microbatch_size
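            # Example (illustrative note): with batch_size = 32 and microbatch_size = 4,
            # num_microbatches is derived as 32 // 4 = 8.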
def load_micro_batch(self) -> Any: def load_micro_batch(self) -> Any:
"""Load a micro batch from the current batch. """Load a micro batch from the current batch.
...@@ -53,6 +80,62 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -53,6 +80,62 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
self.microbatch_offset += self.microbatch_size self.microbatch_offset += self.microbatch_size
return tree_map(partial(to_device, device=get_current_device()), micro_batch) return tree_map(partial(to_device, device=get_current_device()), micro_batch)
def recv_forward(self, prev_rank: int = None) -> Any:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
For 1F1B.
Args:
prev_rank (int, optional): The rank of the source of the tensor.
Returns:
Any: The input tensor or input tensor list.
"""
if self.stage_manager.is_first_stage():
input_tensor = None
else:
input_tensor = self.comm.recv_forward(prev_rank)
return input_tensor
def recv_backward(self, next_rank: int = None) -> Any:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
For 1F1B.
Args:
next_rank (int, optional): The rank of the source of the tensor.
Returns:
Any: The input gradient tensor or gradient tensor list.
"""
if self.stage_manager.is_last_stage():
output_tensor_grad = None
else:
output_tensor_grad = self.comm.recv_backward(next_rank)
return output_tensor_grad
def send_forward(self, output_object: Any, next_rank: int = None) -> None:
"""Sends the input tensor to the next stage in pipeline.
For 1F1B.
Args:
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
self.comm.send_forward(output_object, next_rank)
def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
For 1F1B.
Args:
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor
"""
if not self.stage_manager.is_first_stage():
self.comm.send_backward(input_object, prev_rank)
def forward_step(self, def forward_step(self,
model: Module, model: Module,
input_obj: Optional[dict], input_obj: Optional[dict],
...@@ -72,16 +155,16 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -72,16 +155,16 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
""" """
micro_batch = self.load_micro_batch() micro_batch = self.load_micro_batch()
# for the first stage, input_obj is None # for the first stage, input_obj is None
        # for the non-first stage, input_obj is the output of the previous stage and it must be a dict # for the non-first stage, input_obj is the output of the previous stage and it must be a dict
output_obj = model_forward(model, micro_batch, input_obj) output_obj = model_forward(model, micro_batch, input_obj)
if self.stage_manager.is_last_stage(): if self.stage_manager.is_last_stage():
loss = criterion(output_obj, micro_batch) / self.num_microbatches loss = criterion(output_obj, micro_batch) / self.num_microbatches
if accum_loss is not None: if accum_loss is not None:
accum_loss.add_(loss.detach()) accum_loss.add_(loss.detach())
if outputs is not None: if outputs is not None:
outputs.append(tree_map(detach, output_obj)) outputs.append(tree_map_hf(detach, output_obj))
return loss return loss
else: else:
return output_obj return output_obj
...@@ -102,7 +185,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -102,7 +185,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
# Retain the grad on the input_obj. # Retain the grad on the input_obj.
tree_map(retain_grad, input_obj) tree_map(retain_grad, input_obj)
# Backward pass. # Backward pass.
if output_obj_grad is None: if output_obj_grad is None:
optimizer.backward(output_obj) optimizer.backward(output_obj)
...@@ -171,11 +253,11 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -171,11 +253,11 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
# Run warmup forward passes. # Run warmup forward passes.
for i in range(num_warmup_microbatches): for i in range(num_warmup_microbatches):
input_obj = self.comm.recv_forward() input_obj = self.recv_forward()
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
self.comm.send_forward(output_obj) self.send_forward(output_obj)
if not forward_only: if not forward_only:
input_objs.append(input_obj) input_objs.append(input_obj)
...@@ -185,7 +267,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -185,7 +267,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
# If all microbatches are run in warmup / cooldown phase, then no need to # If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here. # receive this tensor here.
if num_microbatches_remaining > 0: if num_microbatches_remaining > 0:
input_obj = self.comm.recv_forward() input_obj = self.recv_forward()
# Run 1F1B in steady state. # Run 1F1B in steady state.
for i in range(num_microbatches_remaining): for i in range(num_microbatches_remaining):
...@@ -193,15 +275,15 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -193,15 +275,15 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
if forward_only: if forward_only:
self.comm.send_forward(output_obj) self.send_forward(output_obj)
if not last_iteration: if not last_iteration:
input_obj = self.comm.recv_forward() input_obj = self.recv_forward()
else: else:
# TODO adjust here # TODO adjust here
self.comm.send_forward(output_obj) self.send_forward(output_obj)
output_obj_grad = self.comm.recv_backward() output_obj_grad = self.recv_backward()
# Add input_obj and output_obj to end of list. # Add input_obj and output_obj to end of list.
input_objs.append(input_obj) input_objs.append(input_obj)
...@@ -216,8 +298,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -216,8 +298,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
if last_iteration: if last_iteration:
input_obj = None input_obj = None
else: else:
input_obj = self.comm.recv_forward() input_obj = self.recv_forward()
self.comm.send_backward(input_obj_grad) self.send_backward(input_obj_grad)
# Run cooldown backward passes. # Run cooldown backward passes.
if not forward_only: if not forward_only:
...@@ -225,10 +307,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -225,10 +307,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
input_obj = input_objs.pop(0) input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0) output_obj = output_objs.pop(0)
output_obj_grad = self.comm.recv_backward() output_obj_grad = self.recv_backward()
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.comm.send_backward(input_obj_grad) self.send_backward(input_obj_grad)
if outputs is not None: if outputs is not None:
outputs = merge_batch(outputs) if isinstance(model, ModelWrapper):
model = model.unwrap()
outputs = merge_batch(outputs, getattr(model, 'batch_size_dim', 0))
return {'loss': accum_loss, 'outputs': outputs} return {'loss': accum_loss, 'outputs': outputs}
...@@ -17,28 +17,24 @@ class PipelineStageManager: ...@@ -17,28 +17,24 @@ class PipelineStageManager:
Attributes: Attributes:
num_stages (int): Number of stages in the pipeline. num_stages (int): Number of stages in the pipeline.
stage (int): The current stage. stage (int): The current stage.
num_virtual_stages (int): Number of virtual stages in the pipeline.
virtual_stage (int): The current virtual stage.
""" """
def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int) -> None: def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bool = False) -> None:
self.pg_mesh = pg_mesh self.pg_mesh = pg_mesh
self.pipeline_axis = pipeline_axis self.pipeline_axis = pipeline_axis
self.num_virtual_stages: Optional[int] = None
self.virtual_stage: Optional[int] = None
self.prev_rank: Optional[Tuple[int, ...]] = None self.prev_rank: Optional[Tuple[int, ...]] = None
self.next_rank: Optional[Tuple[int, ...]] = None self.next_rank: Optional[Tuple[int, ...]] = None
self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {} self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {}
# init prev and next coord # init prev and next coord
coord = self.pg_mesh.coordinate() coord = self.pg_mesh.coordinate()
if self.stage > 0: # the prev rank of rank0 is the last rank
prev_coord = coord[: self.pipeline_axis] + \ prev_coord = coord[: self.pipeline_axis] + \
(coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1:] (coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1:]
self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape) self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape, mode='wrap')
if self.stage < self.num_stages - 1: # the next rank of the last rank is rank0
next_coord = coord[: self.pipeline_axis] + \ next_coord = coord[: self.pipeline_axis] + \
(coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1:] (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1:]
self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape) self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode='wrap')
# init p2p process groups # init p2p process groups
stages = list(range(self.num_stages)) stages = list(range(self.num_stages))
...@@ -48,32 +44,28 @@ class PipelineStageManager: ...@@ -48,32 +44,28 @@ class PipelineStageManager:
ranks_in_group = self.pg_mesh.get_ranks_in_group(group) ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
self.p2p_groups[tuple(ranks_in_group)] = group self.p2p_groups[tuple(ranks_in_group)] = group
def is_first_stage(self, virtual: bool = False) -> bool: if is_virtual:
"""Is the current stage the first stage. # add the process group of the first rank and the last rank
# only used in interleaved pipeline for now
group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [stages[0], stages[-1]])
if self.stage in [stages[0], stages[-1]]:
ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
self.p2p_groups[tuple(ranks_in_group)] = group
Args: def is_first_stage(self) -> bool:
virtual (bool, optional): Whether to consider virtual stages. Defaults to False. """Is the current stage the first stage.
Returns: Returns:
bool: Whether the current stage is the first stage. bool: Whether the current stage is the first stage.
""" """
if virtual:
assert self.num_virtual_stages is not None
return self.virtual_stage == 0
return self.stage == 0 return self.stage == 0
def is_last_stage(self, virtual: bool = False) -> bool: def is_last_stage(self) -> bool:
"""Is the current stage the last stage. """Is the current stage the last stage.
Args:
virtual (bool, optional): Whether to consider virtual stages. Defaults to False.
Returns: Returns:
bool: Whether the current stage is the last stage. bool: Whether the current stage is the last stage.
""" """
if virtual:
assert self.num_virtual_stages is not None
return self.virtual_stage == self.num_virtual_stages - 1
return self.stage == self.num_stages - 1 return self.stage == self.num_stages - 1
@property @property
...@@ -108,7 +100,6 @@ class PipelineStageManager: ...@@ -108,7 +100,6 @@ class PipelineStageManager:
Returns: Returns:
int: Rank of the previous stage. int: Rank of the previous stage.
""" """
assert not self.is_first_stage(), "Cannot get previous rank in the first stage."
return self.prev_rank return self.prev_rank
def get_next_rank(self) -> int: def get_next_rank(self) -> int:
...@@ -117,39 +108,8 @@ class PipelineStageManager: ...@@ -117,39 +108,8 @@ class PipelineStageManager:
Returns: Returns:
int: Rank of the next stage. int: Rank of the next stage.
""" """
assert not self.is_last_stage(), "Cannot get next rank in the last stage."
return self.next_rank return self.next_rank
def set_num_virtual_stages(self, num_virtual_stages: int) -> None:
"""Set the number of virtual stages.
Args:
num_virtual_stages (int): Number of virtual stages.
"""
self.num_virtual_stages = num_virtual_stages
def set_virtual_stage(self, virtual_stage: int) -> None:
"""Set the virtual stage.
Args:
virtual_stage (int): Virtual stage.
"""
self.virtual_stage = virtual_stage
@contextmanager
def switch_virtual_stage(self, virtual_stage: int) -> None:
"""A context manager to switch virtual stage.
Args:
virtual_stage (int): Target virtual stage.
"""
old_stage = self.virtual_stage
try:
self.set_virtual_stage(virtual_stage)
yield
finally:
self.set_virtual_stage(old_stage)
def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup: def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup:
"""Get the p2p process group between two ranks. The order of the two ranks does not matter. """Get the p2p process group between two ranks. The order of the two ranks does not matter.
......
...@@ -429,12 +429,13 @@ As shown in the figures above, when the sequence length is around 1000 or greate ...@@ -429,12 +429,13 @@ As shown in the figures above, when the sequence length is around 1000 or greate
### Convergence ### Convergence
To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](./examples/convergence_benchmark.py) using both shardformer and non-shardformer approaches. We compared the accuracy, loss, F1 score of the training results. To validate that training the model with Shardformer does not impact its convergence, we [fine-tuned the BERT model](../../examples/language/bert/finetune.py) using both the Shardformer and non-Shardformer approaches. The Shardformer example uses Pipeline Parallelism and Data Parallelism (ZeRO-1) at the same time. We then compared the accuracy, loss, and F1 score of the training results.
| accuracy | f1 | loss | GPU number | model shard |
| accuracy | f1 | loss | GPU number | model sharded |
| :------: | :-----: | :-----: | :--------: | :---------: | | :------: | :-----: | :-----: | :--------: | :---------: |
| 0.82594 | 0.87441 | 0.09913 | 4 | True | | 0.84589 | 0.88613 | 0.43414 | 4 | True |
| 0.81884 | 0.87299 | 0.10120 | 2 | True | | 0.83594 | 0.88064 | 0.43298 | 1 | False |
| 0.81855 | 0.87124 | 0.10357 | 1 | False |
Overall, the results demonstrate that using shardformers during model training does not affect the convergence. Overall, the results demonstrate that using shardformers during model training does not affect the convergence.
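
As a rough sketch of the setup used above (the plugin parameters and variable names here are assumptions for illustration, not taken from the benchmark script), a hybrid-parallel fine-tuning run can be wired up along these lines:

    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    # assumes colossalai.launch_from_torch(...) has already initialized the distributed environment
    plugin = HybridParallelPlugin(tp_size=1, pp_size=2, zero_stage=1, num_microbatches=4)
    booster = Booster(plugin=plugin)
    model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)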
from typing import Any
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn.functional as F import torch.nn.functional as F
...@@ -141,6 +143,240 @@ class LinearWithAsyncCommunication(torch.autograd.Function): ...@@ -141,6 +143,240 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
return grad_input, grad_weight, grad_bias, None, None, None return grad_input, grad_weight, grad_bias, None, None, None
class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
        overlap (`bool`): Whether to overlap the all-gather op with gradient computation in the backward pass.
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.overlap = overlap
input_parallel = _gather(input_, dim, process_group)
if bias is not None:
output = F.linear(input_parallel, weight, bias)
else:
output = F.linear(input_parallel, weight)
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight = ctx.saved_tensors
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap
if not overlap:
input_parallel = _gather(input_, dim, process_group)
total_input = input_parallel
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_parallel.dtype,
device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# reduce-scatter scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_reduce_scatter:
handle.wait()
else:
input_ = input_.contiguous()
world_size = dist.get_world_size(process_group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
            # do all-gather in an async way
gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
# calculate gradient and prepare data asynchronously with all-gather
# calculate
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
grad_bias = grad_output.sum(dim=0) if use_bias else None
# prepare data
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
# wait until all-gather finished
gather_handle.wait()
# do reduce-scatter in async way
reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
# calculate gradient
if len(input_parallel.shape) > 2:
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
grad_weight = grad_output.t().matmul(input_parallel)
# wait until reduce-scatter finished
reducescatter_handle.wait()
return output, grad_weight, grad_bias, None, None, None, None
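# Shape walk-through (added note): with a sequence-parallel input of shape
# (batch, seq_len / world_size, hidden) and dim = 1, the forward all-gather restores
# (batch, seq_len, hidden) before F.linear; in the backward pass the gradient w.r.t. the
# input is chunked along dim and reduce-scattered back to (batch, seq_len / world_size, hidden).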
class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
"""
@staticmethod
def forward(ctx, input_, process_group, dim):
ctx.dim = dim
ctx.process_group = process_group
# do reduce-scatter
new_shape = list(input_.shape)
assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
dist.reduce_scatter(output, input_list, group=process_group)
return output
@staticmethod
def backward(ctx, grad_output):
dim = ctx.dim
process_group = ctx.process_group
return _gather(grad_output, dim, process_group), None, None
class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""
This class is designed for matmul operation with gather forward and reduce-scatter backward.
Args:
input_ (`torch.Tensor`): input matrix.
dim (int): the dimension to perform split and gather
process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.overlap = overlap
input_parallel = _gather(input_, dim, process_group)
output = torch.matmul(input_parallel, weight)
if bias is not None:
output = output + bias
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight = ctx.saved_tensors
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap
if not overlap:
input_parallel = _gather(input_, dim, process_group)
total_input = input_parallel
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_parallel.dtype,
device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# reduce-scatter scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_reduce_scatter:
handle.wait()
else:
world_size = dist.get_world_size(process_group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
            # do all-gather in an async way
gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
# calculate gradient and prepare data asynchronously with all-gather
# calculate
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
grad_bias = grad_output.sum(dim=0) if use_bias else None
# prepare data
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
# wait until all-gather finished
gather_handle.wait()
# do reduce-scatter in async way
reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
# calculate gradient
if len(input_parallel.shape) > 2:
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
grad_weight = input_parallel.t().matmul(grad_output)
# wait until reduce-scatter finished
reducescatter_handle.wait()
return output, grad_weight, grad_bias, None, None, None, None
class _SplitForwardGatherBackward(torch.autograd.Function): class _SplitForwardGatherBackward(torch.autograd.Function):
""" """
    Split the input and keep only the chunk corresponding to the rank. Split the input and keep only the chunk corresponding to the rank.
...@@ -200,6 +436,26 @@ class _ReduceBackward(torch.autograd.Function): ...@@ -200,6 +436,26 @@ class _ReduceBackward(torch.autograd.Function):
return _reduce(grad_output, ctx.process_group), None return _reduce(grad_output, ctx.process_group), None
class _GatherForwardSplitBackward(torch.autograd.Function):
"""Gather the input from model parallel region and concatenate.
Args:
        input_: input matrix.
        dim: dimension to gather along.
        process_group: process group used for collective communication.
"""
@staticmethod
def forward(ctx, input_, dim, process_group):
ctx.process_group = process_group
ctx.dim = dim
return _gather(input_, dim, process_group)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output, ctx.dim, ctx.process_group), None, None
def _reduce(input_, process_group): def _reduce(input_, process_group):
# skip if only one rank involved # skip if only one rank involved
if dist.get_world_size(process_group) == 1: if dist.get_world_size(process_group) == 1:
...@@ -235,9 +491,8 @@ def _gather(input_, dim=-1, process_group=None): ...@@ -235,9 +491,8 @@ def _gather(input_, dim=-1, process_group=None):
return input_ return input_
# all gather # all gather
rank = dist.get_rank(process_group) input_ = input_.contiguous()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)] tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=process_group) torch.distributed.all_gather(tensor_list, input_, group=process_group)
# concat # concat
...@@ -246,24 +501,27 @@ def _gather(input_, dim=-1, process_group=None): ...@@ -246,24 +501,27 @@ def _gather(input_, dim=-1, process_group=None):
return output return output
class _GatherForwardSplitBackward(torch.autograd.Function): def _reduce_scatter(input_, dim=1, process_group=None):
"""Gather the input from model parallel region and concatenate. """ Do reduce-scatter operation.
Args: Args:
input_: input matrix. input_ (`torch.Tensor`): The input tensor from sequence parallel region.
parallel_mode: parallel mode. dim (int): The dimension to perform reduce-scatter.
dim: dimension process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
""" """
world_size = dist.get_world_size(process_group)
if world_size == 1:
return input_
@staticmethod # reduce-scatter
def forward(ctx, input_, dim, process_group): new_shape = list(input_.shape)
ctx.process_group = process_group assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
ctx.dim = dim f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
return _gather(input_, dim, process_group) new_shape[dim] = new_shape[dim] // world_size
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
dist.reduce_scatter(output, input_, group=process_group)
@staticmethod return output
def backward(ctx, grad_output):
return _split(grad_output, ctx.dim, ctx.process_group), None, None
def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
...@@ -274,6 +532,22 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre ...@@ -274,6 +532,22 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre
return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim,
overlap):
return _LinearWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group,
async_grad_reduce_scatter, dim, overlap)
def linear_reducescatter_forward_gather_backward(input_, process_group, dim):
return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim)
def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim,
overlap):
return _MatmulWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group,
async_grad_reduce_scatter, dim, overlap)
def gather_forward_split_backward(input_, dim, process_group): def gather_forward_split_backward(input_, dim, process_group):
return _GatherForwardSplitBackward.apply(input_, dim, process_group) return _GatherForwardSplitBackward.apply(input_, dim, process_group)
......