Unverified Commit efba0f44 authored by Hongxin Liu, committed by GitHub

Merge pull request #4612 from hpcaitech/feature/shardformer

[shardformer] update hybrid parallel plugin and fix bugs
parents ac178ca5 fae6c92e
@@ -208,7 +208,7 @@ jobs:
       - name: Execute Unit Testing
         run: |
-          CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest --testmon --testmon-cov=. --durations=10 tests/
+          CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-forceselect --testmon-cov=. --durations=10 tests/
         env:
           DATA: /data/scratch/cifar-10
           NCCL_SHM_DISABLE: 1
...
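The new invocation deselects tests tagged `largedist`, so the expensive multi-node tests are skipped in this job. A hedged sketch of how such a marker is typically declared and applied (the hook body and test name below are illustrative, not taken from the repository):

# conftest.py -- register the custom marker so `pytest -m "not largedist"` can filter on it
def pytest_configure(config):
    config.addinivalue_line("markers", "largedist: tests that require a large multi-GPU setup")

# in a test module
import pytest

@pytest.mark.largedist
def test_needs_many_gpus():
    ...

Running `pytest -m "not largedist" ...` then deselects every test carrying that marker.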
@@ -44,7 +44,7 @@ jobs:
     name: Test for PyTorch Compatibility
     needs: matrix_preparation
     if: github.repository == 'hpcaitech/ColossalAI'
-    runs-on: [self-hosted, gpu]
+    runs-on: [self-hosted, 8-gpu]
     strategy:
       fail-fast: false
       matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
...
@@ -35,7 +35,7 @@ jobs:
     name: Test for PyTorch Compatibility
     needs: matrix_preparation
     if: github.repository == 'hpcaitech/ColossalAI'
-    runs-on: [self-hosted, gpu]
+    runs-on: [self-hosted, 8-gpu]
     strategy:
       fail-fast: false
       matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
...
@@ -32,7 +32,7 @@ jobs:
     name: Test for PyTorch Compatibility
     needs: matrix_preparation
     if: github.repository == 'hpcaitech/ColossalAI'
-    runs-on: [self-hosted, gpu]
+    runs-on: [self-hosted, 8-gpu]
    strategy:
       fail-fast: false
       matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
...
...@@ -14,29 +14,43 @@ from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer ...@@ -14,29 +14,43 @@ from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
SFT_DATASET = [ SFT_DATASET = [
{ {
"instruction": "Provide a list of the top 10 most popular mobile games in Asia", "instruction":
"input": "", "Provide a list of the top 10 most popular mobile games in Asia",
"output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved", "input":
"id": 0 "",
"output":
"The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
"id":
0
}, },
{ {
"instruction": "Please provide an action plan for reducing carbon footprint on a corporate level", "instruction":
"input": "", "Please provide an action plan for reducing carbon footprint on a corporate level",
"output": "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.", "input":
"id": 1 "",
"output":
"An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.",
"id":
1
}, },
{ {
"instruction": "Write a persuasive email to your boss explaining why you should have a pay raise", "instruction":
"input": "", "Write a persuasive email to your boss explaining why you should have a pay raise",
"output": "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]", "input":
"id": 2 "",
"output":
"Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]",
"id":
2
}, },
] ]
PROMPT_DATASET = [ PROMPT_DATASET = [
{ {
"instruction": "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"", "instruction":
"id": 0 "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"",
"id":
0
}, },
{ {
"instruction": "Write a descriptive paragraph about a memorable vacation you went on", "instruction": "Write a descriptive paragraph about a memorable vacation you went on",
...@@ -73,9 +87,7 @@ def make_tokenizer(model: str): ...@@ -73,9 +87,7 @@ def make_tokenizer(model: str):
return tokenizer return tokenizer
def check_content(input_ids_stripped: torch.Tensor, def check_content(input_ids_stripped: torch.Tensor, tokenizer: PreTrainedTokenizer, model: str):
tokenizer: PreTrainedTokenizer,
model: str):
if model == "opt": if model == "opt":
# NOTE: Contrary to GPT2, OPT adds the EOS token </s> to the beginning of every prompt. # NOTE: Contrary to GPT2, OPT adds the EOS token </s> to the beginning of every prompt.
assert input_ids_stripped[0] == tokenizer.eos_token_id assert input_ids_stripped[0] == tokenizer.eos_token_id
...@@ -98,13 +110,10 @@ def check_content(input_ids_stripped: torch.Tensor, ...@@ -98,13 +110,10 @@ def check_content(input_ids_stripped: torch.Tensor,
assert input_ids_stripped != tokenizer.mask_token_id assert input_ids_stripped != tokenizer.mask_token_id
@pytest.mark.cpu
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"]) @pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
@pytest.mark.parametrize("max_length", [32, 1024]) @pytest.mark.parametrize("max_length", [32, 1024])
@pytest.mark.parametrize("max_datasets_size", [2]) @pytest.mark.parametrize("max_datasets_size", [2])
def test_prompt_dataset(model: str, def test_prompt_dataset(model: str, max_datasets_size: int, max_length: int):
max_datasets_size: int,
max_length: int):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
dataset_name = "prompt_dataset.json" dataset_name = "prompt_dataset.json"
with open(os.path.join(tmp_dir, dataset_name), "w") as f: with open(os.path.join(tmp_dir, dataset_name), "w") as f:
...@@ -127,19 +136,12 @@ def test_prompt_dataset(model: str, ...@@ -127,19 +136,12 @@ def test_prompt_dataset(model: str,
check_content(input_ids.masked_select(attention_mask), tokenizer, model) check_content(input_ids.masked_select(attention_mask), tokenizer, model)
@pytest.mark.cpu
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"]) @pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
@pytest.mark.parametrize(["dataset_path", "subset"], [ @pytest.mark.parametrize(["dataset_path", "subset"], [("Anthropic/hh-rlhf", "harmless-base"),
("Anthropic/hh-rlhf", "harmless-base"), ("Dahoas/rm-static", None)])
("Dahoas/rm-static", None)
])
@pytest.mark.parametrize("max_datasets_size", [32]) @pytest.mark.parametrize("max_datasets_size", [32])
@pytest.mark.parametrize("max_length", [32, 1024]) @pytest.mark.parametrize("max_length", [32, 1024])
def test_reward_dataset(model: str, def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], max_datasets_size: int, max_length: int):
dataset_path: str,
subset: Optional[str],
max_datasets_size: int,
max_length: int):
data = load_dataset(dataset_path, data_dir=subset) data = load_dataset(dataset_path, data_dir=subset)
assert max_datasets_size <= len(data["train"]) \ assert max_datasets_size <= len(data["train"]) \
and max_datasets_size <= len(data["test"]) and max_datasets_size <= len(data["test"])
...@@ -196,15 +198,12 @@ def test_reward_dataset(model: str, ...@@ -196,15 +198,12 @@ def test_reward_dataset(model: str,
assert torch.all(r_mask) assert torch.all(r_mask)
@pytest.mark.cpu
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama", "chatglm"]) @pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama", "chatglm"])
@pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None]) @pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None])
@pytest.mark.parametrize("max_dataset_size", [2]) @pytest.mark.parametrize("max_dataset_size", [2])
@pytest.mark.parametrize("max_length", [32, 1024]) @pytest.mark.parametrize("max_length", [32, 1024])
def test_sft_dataset(model: str, def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size: int, max_length: int):
dataset_path: Optional[str],
max_dataset_size: int,
max_length: int):
tokenizer = make_tokenizer(model) tokenizer = make_tokenizer(model)
if dataset_path == "yizhongw/self_instruct": if dataset_path == "yizhongw/self_instruct":
data = load_dataset(dataset_path, "super_natural_instructions") data = load_dataset(dataset_path, "super_natural_instructions")
...@@ -253,10 +252,7 @@ def test_sft_dataset(model: str, ...@@ -253,10 +252,7 @@ def test_sft_dataset(model: str,
if __name__ == "__main__": if __name__ == "__main__":
test_sft_dataset(model="bloom", test_sft_dataset(model="bloom", dataset_path="yizhongw/self_instruct", max_dataset_size=2, max_length=256)
dataset_path="yizhongw/self_instruct",
max_dataset_size=2,
max_length=256)
test_reward_dataset(model="gpt2", test_reward_dataset(model="gpt2",
dataset_path="Anthropic/hh-rlhf", dataset_path="Anthropic/hh-rlhf",
...@@ -267,3 +263,4 @@ if __name__ == "__main__": ...@@ -267,3 +263,4 @@ if __name__ == "__main__":
test_prompt_dataset(model="opt", test_prompt_dataset(model="opt",
max_datasets_size=2, max_datasets_size=2,
max_length=128) max_length=128)
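The `check_content` helper above relies on a quirk of the OPT tokenizer: unlike GPT-2, it prepends its `</s>` token to every prompt. A minimal sketch of that difference (model names and hub access assumed, not part of the test suite):

from transformers import AutoTokenizer

opt_tok = AutoTokenizer.from_pretrained("facebook/opt-125m")
gpt2_tok = AutoTokenizer.from_pretrained("gpt2")

opt_ids = opt_tok("hello world")["input_ids"]
gpt2_ids = gpt2_tok("hello world")["input_ids"]

# OPT inserts </s> (its BOS, which shares the EOS id) in front of the prompt; GPT-2 adds nothing.
assert opt_ids[0] == opt_tok.eos_token_id
assert gpt2_ids[0] != gpt2_tok.eos_token_id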
...@@ -16,10 +16,11 @@ from coati.models.opt import OPTRM, OPTActor, OPTCritic ...@@ -16,10 +16,11 @@ from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
@pytest.mark.gpu
@pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seq_len", [32]) @pytest.mark.parametrize("seq_len", [32])
@pytest.mark.parametrize("actor_maker", [ @pytest.mark.parametrize(
"actor_maker",
[
lambda: BLOOMActor(), lambda: BLOOMActor(),
lambda: GPTActor(), lambda: GPTActor(),
# HACK: skip llama due to long execution time # HACK: skip llama due to long execution time
...@@ -27,6 +28,7 @@ from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer ...@@ -27,6 +28,7 @@ from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
lambda: OPTActor(), lambda: OPTActor(),
# lambda: ChatGLMActor(), # lambda: ChatGLMActor(),
]) ])
@pytest.mark.parametrize("generate_kwargs", [{ @pytest.mark.parametrize("generate_kwargs", [{
"max_length": 64, "max_length": 64,
"use_cache": True, "use_cache": True,
...@@ -34,23 +36,15 @@ from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer ...@@ -34,23 +36,15 @@ from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
"temperature": 1.0, "temperature": 1.0,
"top_k": 50, "top_k": 50,
}]) }])
def test_generation(actor_maker: Callable[[], Actor], def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]):
batch_size: int,
seq_len: int,
generate_kwargs: Dict[str, Any]
):
actor = actor_maker() actor = actor_maker()
input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda() input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda()
sequences = generate(actor.cuda(), input_ids, **generate_kwargs) sequences = generate(actor.cuda(), input_ids, **generate_kwargs)
assert sequences.shape == (batch_size, generate_kwargs["max_length"]) assert sequences.shape == (batch_size, generate_kwargs["max_length"])
-@pytest.mark.cpu
 def test_utils():
-    fn_input = {
-        "tensor": torch.ones((10, )),
-        "mask": torch.randint(0, 2, (10, ))
-    }
+    fn_input = {"tensor": torch.ones((10,)), "mask": torch.randint(0, 2, (10,))}
     fn_output = masked_mean(dim=0, **fn_input)
     assert fn_output.dim() == 0
     assert torch.allclose(fn_output, torch.tensor(1.0))
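For reference, `masked_mean` averages only the unmasked entries, which is why a tensor of ones yields exactly 1.0 regardless of the random mask. A rough re-implementation of the expected semantics (helper name is hypothetical, not coati's code):

import torch

def masked_mean_ref(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 0) -> torch.Tensor:
    # sum of unmasked values divided by the number of unmasked values
    return (tensor * mask).sum(dim) / mask.sum(dim).clamp(min=1)

print(masked_mean_ref(torch.ones(10), torch.randint(0, 2, (10,))))  # tensor(1.) whenever something is unmasked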
...@@ -58,14 +52,14 @@ def test_utils(): ...@@ -58,14 +52,14 @@ def test_utils():
batch_size = 4 batch_size = 4
num_labels = 10 num_labels = 10
fn_input = { fn_input = {
"r": torch.ones((batch_size, )), "r": torch.ones((batch_size,)),
"kl_coef": 1.0, "kl_coef": 1.0,
"log_probs": torch.randn((batch_size, num_labels)), "log_probs": torch.randn((batch_size, num_labels)),
"log_probs_base": torch.randn((batch_size, num_labels)), "log_probs_base": torch.randn((batch_size, num_labels)),
"action_mask": torch.randint(0, 2, (batch_size, num_labels)) "action_mask": torch.randint(0, 2, (batch_size, num_labels))
} }
fn_output = compute_reward(**fn_input) fn_output = compute_reward(**fn_input)
assert fn_output.shape == (batch_size, ) assert fn_output.shape == (batch_size,)
batch_size = 4 batch_size = 4
seq_len = 32 seq_len = 32
...@@ -82,17 +76,11 @@ def test_utils(): ...@@ -82,17 +76,11 @@ def test_utils():
assert fn_output.shape == (batch_size, num_actions) assert fn_output.shape == (batch_size, num_actions)
@pytest.mark.cpu
@pytest.mark.parametrize("lora_rank", [4]) @pytest.mark.parametrize("lora_rank", [4])
@pytest.mark.parametrize("num_dim", [32]) @pytest.mark.parametrize("num_dim", [32])
@pytest.mark.parametrize("num_layers", [4]) @pytest.mark.parametrize("num_layers", [4])
def test_lora(lora_rank: int, def test_lora(lora_rank: int, num_dim: int, num_layers: int):
num_dim: int, model = nn.ModuleList([nn.Linear(num_dim, num_dim) for _ in range(num_layers)])
num_layers: int):
model = nn.ModuleList(
[nn.Linear(num_dim, num_dim)
for _ in range(num_layers)]
)
lora_model = convert_to_lora_module(model, lora_rank) lora_model = convert_to_lora_module(model, lora_rank)
assert isinstance(lora_model, nn.ModuleList) assert isinstance(lora_model, nn.ModuleList)
for i in range(num_layers): for i in range(num_layers):
...@@ -105,8 +93,7 @@ def test_lora(lora_rank: int, ...@@ -105,8 +93,7 @@ def test_lora(lora_rank: int,
assert isinstance(lora_model[i], LoraLinear) assert isinstance(lora_model[i], LoraLinear)
assert torch.allclose(old_model[i].weight, lora_model[i].weight) assert torch.allclose(old_model[i].weight, lora_model[i].weight)
assert torch.allclose(old_model[i].bias, lora_model[i].bias) assert torch.allclose(old_model[i].bias, lora_model[i].bias)
assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A, assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A, lora_model[i].lora_B @ lora_model[i].lora_A)
lora_model[i].lora_B @ lora_model[i].lora_A)
optimizer = torch.optim.Adam(lora_model.parameters()) optimizer = torch.optim.Adam(lora_model.parameters())
x = torch.randn(8, num_dim) x = torch.randn(8, num_dim)
for i in range(num_layers): for i in range(num_layers):
...@@ -122,10 +109,11 @@ def test_lora(lora_rank: int, ...@@ -122,10 +109,11 @@ def test_lora(lora_rank: int,
lora_model[i].lora_B @ lora_model[i].lora_A) lora_model[i].lora_B @ lora_model[i].lora_A)
@pytest.mark.cpu
@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [128]) @pytest.mark.parametrize("seq_len", [128])
@pytest.mark.parametrize("models_maker", [ @pytest.mark.parametrize(
"models_maker",
[
lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()),
lambda: (GPTActor(), GPTCritic(), GPTRM()), lambda: (GPTActor(), GPTCritic(), GPTRM()),
# HACK: skip llama due to long execution time # HACK: skip llama due to long execution time
...@@ -178,13 +166,10 @@ def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], ...@@ -178,13 +166,10 @@ def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]],
assert rm_output.shape == (batch_size, ) assert rm_output.shape == (batch_size, )
@pytest.mark.cpu
@pytest.mark.parametrize("batch_size", [16]) @pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("seq_len", [128]) @pytest.mark.parametrize("seq_len", [128])
@pytest.mark.parametrize("num_labels", [100]) @pytest.mark.parametrize("num_labels", [100])
def test_loss(batch_size: int, def test_loss(batch_size: int, seq_len: int, num_labels: int):
seq_len: int,
num_labels: int):
loss = GPTLMLoss() loss = GPTLMLoss()
loss_input = { loss_input = {
"logits": torch.randn(batch_size, seq_len, num_labels), "logits": torch.randn(batch_size, seq_len, num_labels),
...@@ -194,54 +179,43 @@ def test_loss(batch_size: int, ...@@ -194,54 +179,43 @@ def test_loss(batch_size: int,
loss = PolicyLoss() loss = PolicyLoss()
loss_input = { loss_input = {
"log_probs": torch.randn(batch_size, ), "log_probs": torch.randn(batch_size,),
"old_log_probs": torch.randn(batch_size, ), "old_log_probs": torch.randn(batch_size,),
"advantages": torch.randn(batch_size, ) "advantages": torch.randn(batch_size,)
} }
loss_output = loss(**loss_input) loss_output = loss(**loss_input)
loss = ValueLoss() loss = ValueLoss()
loss_input = { loss_input = {
"values": torch.randn(batch_size, ), "values": torch.randn(batch_size,),
"old_values": torch.randn(batch_size, ), "old_values": torch.randn(batch_size,),
"reward": torch.randn(batch_size, ) "reward": torch.randn(batch_size,)
} }
loss_output = loss(**loss_input) loss_output = loss(**loss_input)
loss = LogSigLoss() loss = LogSigLoss()
loss_input = { loss_input = {
"chosen_reward": torch.randn(batch_size, ), "chosen_reward": torch.randn(batch_size,),
"reject_reward": torch.randn(batch_size, ), "reject_reward": torch.randn(batch_size,),
} }
loss_output = loss(**loss_input) loss_output = loss(**loss_input)
loss = LogExpLoss() loss = LogExpLoss()
loss_input = { loss_input = {
"chosen_reward": torch.randn(batch_size, ), "chosen_reward": torch.randn(batch_size,),
"reject_reward": torch.randn(batch_size, ), "reject_reward": torch.randn(batch_size,),
} }
loss_output = loss(**loss_input) loss_output = loss(**loss_input)
if __name__ == "__main__": if __name__ == "__main__":
generate_kwargs = dict(max_length=40, generate_kwargs = dict(max_length=40, use_cache=True, do_sample=True, temperature=1.0, top_k=50)
use_cache=True, test_generation(lambda: LlamaActor(), batch_size=4, seq_len=32, generate_kwargs=generate_kwargs)
do_sample=True,
temperature=1.0,
top_k=50)
test_generation(lambda: LlamaActor(),
batch_size=4,
seq_len=32,
generate_kwargs=generate_kwargs)
test_utils() test_utils()
test_lora(lora_rank=2, num_dim=8, num_layers=2) test_lora(lora_rank=2, num_dim=8, num_layers=2)
test_models(models_maker=lambda: (BLOOMActor(), test_models(models_maker=lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), batch_size=8, seq_len=128)
BLOOMCritic(),
BLOOMRM()),
batch_size=8,
seq_len=128)
test_loss(batch_size=8, seq_len=128, num_labels=100) test_loss(batch_size=8, seq_len=128, num_labels=100)
\ No newline at end of file
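The LoRA test above checks that converting a linear layer keeps its weights intact and that the low-rank factors only drift after an optimizer step. A minimal sketch of the parametrization being exercised (the class name and the zero-initialization of lora_B are assumptions for illustration, not coati's exact implementation):

import torch
import torch.nn as nn

class TinyLoraLinear(nn.Module):
    """Illustrative only: y = x @ (W + B @ A)^T, with A and B of rank r."""

    def __init__(self, in_dim: int, out_dim: int, rank: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
        self.lora_A = nn.Parameter(torch.randn(rank, in_dim))
        self.lora_B = nn.Parameter(torch.zeros(out_dim, rank))  # zero init => W_eff == W at the start

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ (self.weight + self.lora_B @ self.lora_A).t()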
@@ -15,6 +15,7 @@ from colossalai.checkpoint_io.utils import (
     get_model_base_filenames,
     get_optimizer_base_filenames,
     load_shard_state_dict,
+    save_config_file,
     save_state_dict,
     save_state_dict_shards,
 )
@@ -107,6 +108,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
+            save_config_file(model.module, checkpoint_path)
             logging.info(f"The model is split into checkpoint shards. "
                          f"You can find where each parameter has been saved in the "
                          f"index located at {save_index_file}.")
...
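The index file referenced in the log message maps each parameter name to the shard that contains it. A hedged illustration of what such an index typically holds (shown as a Python dict; the file names and sizes are made up, and the layout is assumed to follow the usual Hugging Face sharded-checkpoint convention):

index_example = {
    "metadata": {"total_size": 6_438_256_640},   # total bytes across all shards
    "weight_map": {
        "model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin",
        "model.layers.0.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
        "lm_head.weight": "pytorch_model-00002-of-00002.bin",
    },
}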
 from .checkpoint_io_base import CheckpointIO
 from .general_checkpoint_io import GeneralCheckpointIO
+from .hybrid_parallel_checkpoint_io import HypridParallelCheckpointIO
 from .index_file import CheckpointIndexFile

-__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO']
+__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HybridParallelCheckpointIO']
@@ -23,6 +23,7 @@ from .utils import (
     load_state_dict,
     load_state_dict_into_model,
     load_states_into_optimizer,
+    save_config_file,
     save_param_groups,
     save_state_dict,
     save_state_dict_shards,
@@ -183,6 +184,7 @@ class GeneralCheckpointIO(CheckpointIO):
         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
+        save_config_file(model, checkpoint_path, is_master=True)
         logging.info(f"The model is going to be split into checkpoint shards. "
                      f"You can find where each parameter has been saved in the "
                      f"index located at {save_index_file}.")
...
(Two additional file diffs are collapsed and not shown here.)
@@ -94,17 +94,23 @@ class ProcessGroupMesh:
         return np.unravel_index(rank, shape)

     @staticmethod
-    def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...]) -> int:
+    def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = 'raise') -> int:
         """Convert a coordinate to a rank.
+           mode: ['raise', 'wrap', 'clip'], see https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html.
+           With 'wrap', an out-of-range index is wrapped around.
+           For instance, ravel((0, i, 0), (1, 2, 1), 'wrap') returns i % 2.

         Args:
             coords (Tuple[int, ...]): Coordinate to be converted.
             shape (Tuple[int, ...]): Shape of the process group mesh.
+            mode (Optional[str]): The mode for numpy.ravel_multi_index.

         Returns:
             int: Rank of the coordinate.
         """
-        return np.ravel_multi_index(coord, shape)
+        assert mode in ["raise", "wrap", "clip"]
+        return np.ravel_multi_index(coord, shape, mode)

     def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
         """Get the process group with the given ranks. If the process group doesn't exist, it will be created.
...
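A quick numpy illustration of why the 'wrap' mode matters here: it lets a stage ask for the rank "one step past the end" and get the first stage back, which is what turns the pipeline axis into a ring (the mesh shape and coordinates below are made up):

import numpy as np

shape = (2, 4)  # e.g. 2 data-parallel groups x 4 pipeline stages
print(np.ravel_multi_index((1, 3), shape))               # 7 -> last stage of the second group
print(np.ravel_multi_index((1, 4), shape, mode='wrap'))  # 4 -> wraps back to stage 0 of the same group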
...@@ -173,14 +173,10 @@ class PipelineP2PCommunication: ...@@ -173,14 +173,10 @@ class PipelineP2PCommunication:
Returns: Returns:
Any: The input tensor or input tensor list. Any: The input tensor or input tensor list.
""" """
if self.stage_manager.is_first_stage():
input_tensor = None
else:
if prev_rank is None: if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank() prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank() cur_rank = self.stage_manager.get_rank()
input_tensor = _recv_object(prev_rank, cur_rank, input_tensor = _recv_object(prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank))
self.stage_manager.get_p2p_process_group(prev_rank, cur_rank))
return input_tensor return input_tensor
...@@ -193,9 +189,6 @@ class PipelineP2PCommunication: ...@@ -193,9 +189,6 @@ class PipelineP2PCommunication:
Returns: Returns:
Any: The input gradient tensor or gradient tensor list. Any: The input gradient tensor or gradient tensor list.
""" """
if self.stage_manager.is_last_stage():
output_tensor_grad = None
else:
if next_rank is None: if next_rank is None:
next_rank = self.stage_manager.get_next_rank() next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank() cur_rank = self.stage_manager.get_rank()
...@@ -211,12 +204,10 @@ class PipelineP2PCommunication: ...@@ -211,12 +204,10 @@ class PipelineP2PCommunication:
output_object (Any): Object to be sent. output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor. next_rank (int, optional): The rank of the recipient of the tensor.
""" """
if not self.stage_manager.is_last_stage():
if next_rank is None: if next_rank is None:
next_rank = self.stage_manager.get_next_rank() next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank() cur_rank = self.stage_manager.get_rank()
_send_object(output_object, cur_rank, next_rank, _send_object(output_object, cur_rank, next_rank, self.stage_manager.get_p2p_process_group(cur_rank, next_rank))
self.stage_manager.get_p2p_process_group(cur_rank, next_rank))
def send_backward(self, input_object: Any, prev_rank: int = None) -> None: def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline. """Sends the gradient tensor to the previous stage in pipeline.
...@@ -225,9 +216,7 @@ class PipelineP2PCommunication: ...@@ -225,9 +216,7 @@ class PipelineP2PCommunication:
input_object (Any): Object to be sent. input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor prev_rank (int, optional): The rank of the recipient of the tensor
""" """
if not self.stage_manager.is_first_stage():
if prev_rank is None: if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank() prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank() cur_rank = self.stage_manager.get_rank()
_send_object(input_object, cur_rank, prev_rank, _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
-from typing import Any, List, Optional
+from collections import OrderedDict
+from typing import Any, List, Optional, Tuple

 import torch
 import torch.cuda
 from torch.nn import Module
-from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
+from torch.utils._pytree import (
+    SUPPORTED_NODES,
+    LeafSpec,
+    TreeSpec,
+    _is_leaf,
+    _register_pytree_node,
+    tree_flatten,
+    tree_map,
+    tree_unflatten,
+)
# this registration is for torch under version 1.13.1; it may be removed in the future
def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Any]:
return list(d.values()), list(d.keys())
def _odict_unflatten(values: List[Any], context: Any) -> 'OrderedDict[Any, Any]':
return OrderedDict((key, value) for key, value in zip(context, values))
_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
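# A quick illustration of what the registration above enables (a sketch, assuming an
# older torch where OrderedDict is not a registered pytree node by default):
#
#   >>> d = OrderedDict(a=1, b=2)
#   >>> values, spec = tree_flatten(d)
#   >>> values
#   [1, 2]
#   >>> tree_unflatten([v * 10 for v in values], spec)
#   OrderedDict([('a', 10), ('b', 20)])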
def tree_map_hf(fn: Any, pytree: Any):
flat_args, spec = tree_flatten_hf(pytree)
return tree_unflatten([fn(i) for i in flat_args], spec)
# use this flatten function to handle ModelingOutput class instances.
def tree_flatten_hf(pytree: Any) -> Tuple[List[Any], TreeSpec]:
"""Flattens a pytree into a list of values an a TreeSpec that can be used
to reconstruct the pytree.
"""
if isinstance(pytree, OrderedDict):
node_type = OrderedDict
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
child_pytrees, context = flatten_fn(pytree)
# Recursively flatten the children
result: List[Any] = []
children_specs: List['TreeSpec'] = []
for child in child_pytrees:
flat, child_spec = tree_flatten_hf(child)
result += flat
children_specs.append(child_spec)
return result, TreeSpec(node_type, context, children_specs)
else:
result, tree_spec = tree_flatten(pytree)
return result, tree_spec
 def to_device(x: Any, device: Optional[torch.device] = None) -> Any:
@@ -104,7 +154,7 @@ def detach(x: Any) -> Any:
     return x


-def merge_batch(data: List[Any]) -> Any:
+def merge_batch(data: List[Any], batch_size_dim=0) -> Any:
     """Merge micro batches into a batch.

     Args:
@@ -118,12 +168,17 @@ def merge_batch(data: List[Any]) -> Any:
     flattened_data = []
     tree_spec = None
     for d in data:
-        elems, tree_spec = tree_flatten(d)
+        # elems should be an instance of OrderedDict
+        elems, tree_spec = tree_flatten_hf(d)
         flattened_data.append(elems)
     merged_data = []
     for elem_batch in zip(*flattened_data):
         if isinstance(elem_batch[0], torch.Tensor):
-            merged_data.append(torch.cat(elem_batch, dim=0))
+            if len(elem_batch[0].shape) == 0:    # set loss to None in pipeline outputs
+                merged_data.append(None)
+            else:
+                merged_data.append(torch.cat(elem_batch, dim=batch_size_dim))
         else:
             merged_data.append(list(elem_batch))
     return tree_unflatten(merged_data, tree_spec)
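A small sketch of what the new merge_batch behaviour produces, assuming each pipeline micro-batch returned a mapping with a scalar loss and a logits tensor (values are made up):

import torch
from collections import OrderedDict

micro1 = OrderedDict(loss=torch.tensor(0.7), logits=torch.randn(2, 8))
micro2 = OrderedDict(loss=torch.tensor(0.9), logits=torch.randn(2, 8))

merged = merge_batch([micro1, micro2])
# merged["loss"] is None (0-dim tensors are dropped on purpose),
# merged["logits"] has shape (4, 8) after concatenation along batch_size_dim=0.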
from functools import partial
from typing import Any, Callable, Iterable, List, Optional, Union
import torch
import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_map
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.cuda import get_current_device
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
from .base import PipelineSchedule
class InterleavedSchedule(PipelineSchedule):
def __init__(self, num_microbatches: int, num_model_chunks: int, stage_manager: PipelineStageManager) -> None:
self.num_model_chunks = num_model_chunks
assert num_microbatches % self.num_model_chunks == 0, \
"Number of microbatches should be an integer multiple of number of model chunks"
super().__init__(stage_manager)
self.comm = PipelineP2PCommunication(stage_manager)
self.num_microbatches = num_microbatches
self.batch: Optional[Any] = None
self.batch_size: Optional[int] = None
self.microbatch_offset: Optional[int] = None
self.microbatch_size: Optional[int] = None
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.
Args:
data_iter (Iterable): Data iterator.
device (Optional[torch.device], optional): Target device. Defaults to None.
"""
batch = next(data_iter)
if device is not None:
batch = tree_map(partial(to_device, device=device), batch)
self.batch = batch
self.batch_size = get_batch_size(batch)
self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
assert self.batch_size % self.num_microbatches == 0, \
"Batch size should divided by the number of microbatches"
self.microbatch_size = self.batch_size // self.num_microbatches
def load_micro_batch(self, model_chunk_id: int) -> Any:
"""Load a micro batch from the current batch.
Args:
model_chunk_id (int): The current model chunk index.
Returns:
Any: Micro batch.
"""
micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
self.microbatch_offset[model_chunk_id] += self.microbatch_size
return tree_map(partial(to_device, device=get_current_device()), micro_batch)
def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int:
"""Helper method to get the model chunk ID given the iteration number.
Args:
microbatch_id (int): the current microbatch idx
forward (bool): Whether this is the forward pass.
Returns:
int: The model chunk idx of the input microbatch_id
"""
microbatch_id_in_group = (microbatch_id) % (self.stage_manager.num_stages * self.num_model_chunks)
model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
if not forward:
model_chunk_id = (self.num_model_chunks - model_chunk_id - 1)
return model_chunk_id
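# Worked example (an illustration, not part of the original code), assuming
# stage_manager.num_stages == 2 and num_model_chunks == 2, i.e. groups of 4 microbatches:
#   forward:  microbatch 0 -> chunk 0, 1 -> chunk 0, 2 -> chunk 1, 3 -> chunk 1, 4 -> chunk 0, ...
#   backward: the chunk order reverses within each group: 0 -> chunk 1, 1 -> chunk 1, 2 -> chunk 0, 3 -> chunk 0.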
def is_first_stage(self, model_chunk_id: int) -> bool:
"""Is the current virtual stage the first stage
Args:
model_chunk_id (int): The current model chunk idx.
Returns:
bool: Whether the current virtual stage is the first stage.
"""
if self.stage_manager.is_first_stage() and model_chunk_id == 0:
return True
return False
def is_last_stage(self, model_chunk_id: int) -> bool:
"""Is the current virtual stage the last stage
Args:
model_chunk_id (int): The current model chunk idx.
Returns:
bool: Whether the current virtual stage is the last stage.
"""
if self.stage_manager.is_last_stage() and model_chunk_id == self.num_model_chunks - 1:
return True
return False
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
For interleaved 1F1B.
Args:
model_chunk_id (int): The current model chunk idx.
prev_rank (int, optional): The rank of the source of the tensor.
Returns:
Any: The input tensor or input tensor list.
"""
if self.is_first_stage(model_chunk_id):
input_tensor = None
else:
input_tensor = self.comm.recv_forward(prev_rank)
return input_tensor
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
For interleaved 1F1B.
Args:
model_chunk_id (int): The current model chunk idx.
next_rank (int, optional): The rank of the source of the tensor.
Returns:
Any: The input gradient tensor or gradient tensor list.
"""
if self.is_last_stage(model_chunk_id):
output_tensor_grad = None
else:
output_tensor_grad = self.comm.recv_backward(next_rank)
return output_tensor_grad
def send_forward(self, model_chunk_id, output_object: Any, next_rank: int = None) -> None:
"""Sends the input tensor to the next stage in pipeline.
For interleaved 1F1B.
Args:
model_chunk_id (int): The current model chunk idx.
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.is_last_stage(model_chunk_id):
self.comm.send_forward(output_object, next_rank)
def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
For interleaved 1F1B.
Args:
model_chunk_id (int): The current model chunk idx.
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor
"""
if not self.is_first_stage(model_chunk_id):
self.comm.send_backward(input_object, prev_rank)
def forward_step(self,
model_chunk: Module,
model_chunk_id: int,
input_obj: Optional[dict],
criterion: Callable,
accum_loss: Optional[torch.Tensor] = None,
outputs: Optional[List[Any]] = None) -> Union[torch.Tensor, dict]:
"""Forward one step of the pipeline
Args:
model (Module): Model Chunk to be run
input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None.
criterion (Callable): Criterion to calculate loss.
accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
Returns:
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
"""
micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
# for the first stage, input_obj is None
# for the non-first stage, input_obj is the output of the previous stage and it must be a dict
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
if self.is_last_stage(model_chunk_id):
loss = criterion(output_obj, micro_batch) / self.num_microbatches
if accum_loss is not None:
accum_loss.add_(loss.detach())
if outputs is not None:
outputs.append(tree_map(detach, output_obj))
return loss
else:
return output_obj
def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict],
output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict]) -> Optional[dict]:
"""Backward one step of the pipeline
Args:
optimizer (OptimizerWrapper): Optimizer to update the model
input_obj (Optional[dict]): Output of the previous stage. If it is the first stage, the `input_obj` is None.
output_obj (Union[dict, torch.Tensor]): Output of the current stage. If it is the last stage, the output is the loss (Tensor).
output_obj_grad (dict): Gradient of the `output_obj`. If it is the last stage, the `output_obj_grad` is None.
Returns:
Optional[dict]: Gradient of the `input_obj`. If it is the first stage, the `input_obj_grad` is None.
"""
# Retain the grad on the input_obj.
tree_map(retain_grad, input_obj)
# Backward pass.
if output_obj_grad is None:
optimizer.backward(output_obj)
else:
if "backward_tensor_keys" not in output_obj:
for k, grad in output_obj_grad.items():
optimizer.backward_by_grad(output_obj[k], grad)
else:
for k, grad in output_obj_grad.items():
output_obj[k].grad = grad
for k in output_obj["backward_tensor_keys"]:
tensor_to_backward = output_obj[k]
optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
# Collect the grad of the input_obj.
input_obj_grad = None
if input_obj is not None:
input_obj_grad = {}
for k, v in input_obj.items():
if isinstance(v, torch.Tensor) and v.grad is not None:
input_obj_grad[k] = v.grad
return input_obj_grad
def forward_backward_step(self,
model_chunk: Module,
optimizer: OptimizerWrapper,
data_iter: Iterable,
criterion: Callable[..., Any],
return_loss: bool = False,
return_outputs: bool = False) -> dict:
"""Runs interleaved 1F1B schedule, with communication between pipeline stages.
Args:
model_chunk (List[Module]): Model Chunk to be trained.
optimizer (OptimizerWrapper): Optimizer to be used.
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
return_loss (bool, optional): Whether to return loss. Defaults to False.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False.
Returns:
dict: A dict with keys: 'loss' and 'outputs'.
"""
forward_only = not torch.is_grad_enabled()
self.load_batch(data_iter)
num_model_chunks = len(model_chunk)
# num_warmup_microbatches is the number of warmup steps during which not all stages are active yet
num_microbatches = self.num_microbatches * num_model_chunks
if forward_only:
num_warmup_microbatches = num_microbatches
else:
num_warmup_microbatches = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
num_warmup_microbatches += (num_model_chunks - 1) * self.stage_manager.num_stages
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
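# Example (illustrative numbers only): with 4 stages, 2 model chunks and
# self.num_microbatches == 8 (16 chunk-level forward passes in total),
# stage 0 warms up for (4 - 0 - 1) * 2 + (2 - 1) * 4 = 10 passes while
# stage 3 warms up for (4 - 3 - 1) * 2 + 4 = 4, so later stages reach the 1F1B steady state sooner.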
# Input, output tensors only need to be saved when doing backward passes
input_objs = None
output_objs = None
if not forward_only:
input_objs = [[] for _ in range(num_model_chunks)]
output_objs = [[] for _ in range(num_model_chunks)]
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
if return_loss and self.stage_manager.is_last_stage():
accum_loss = torch.zeros(1, device=get_current_device())
else:
accum_loss = None
# all ranks except the first start by receiving; recv_forward returns None on the first stage
# print(self.stage_manager.stage,num_microbatches, num_warmup_microbatches, num_microbatches_remaining)
input_obj = self.recv_forward(0)
input_objs[0].append(input_obj)
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
model_chunk_id = self.get_model_chunk_id(i, forward=True)
# recv first on first rank to avoid sending or recving at the same time
if self.stage_manager.is_first_stage():
input_obj = self.recv_forward(model_chunk_id)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
self.send_forward(model_chunk_id, output_obj)
if not forward_only:
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
else:
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if not forward_only:
output_objs[model_chunk_id].append(output_obj)
self.send_forward(model_chunk_id, output_obj)
if num_microbatches_remaining == 0 and i + 1 == num_warmup_microbatches:
break
else:
model_chunk_id = self.get_model_chunk_id(i + 1, forward=True)
input_obj = self.recv_forward(model_chunk_id)
if not forward_only:
input_objs[model_chunk_id].append(input_obj)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True)
last_iteration = (i == (num_microbatches_remaining - 1))
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if forward_only:
self.send_forward(model_chunk_id, output_obj)
if not last_iteration:
input_obj = self.recv_forward(model_chunk_id)
else:
self.send_forward(model_chunk_id, output_obj)
# Add input_obj and output_obj to end of list.
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
model_chunk_id = self.get_model_chunk_id(i, forward=False)
output_obj_grad = self.recv_backward(model_chunk_id)
# Pop input_obj and output_obj from the start of the list for
# the backward pass.
input_obj = input_objs[model_chunk_id].pop(0)
output_obj = output_objs[model_chunk_id].pop(0)
# backward
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
if last_iteration:
input_obj = None
else:
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches + 1, forward=True)
input_obj = self.recv_forward(model_chunk_id)
model_chunk_id = self.get_model_chunk_id(i, forward=False)
self.send_backward(model_chunk_id, input_obj_grad)
# Run cooldown backward passes.
if not forward_only:
for i in range(num_microbatches_remaining, num_microbatches):
model_chunk_id = self.get_model_chunk_id(i, forward=False)
# print(f"{self.stage_manager.stage}/{model_chunk_id}: {len(input_objs[model_chunk_id])} {len(output_objs[model_chunk_id])} {i}")
input_obj = input_objs[model_chunk_id].pop(0)
output_obj = output_objs[model_chunk_id].pop(0)
output_obj_grad = self.recv_backward(model_chunk_id)
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(model_chunk_id, input_obj_grad)
if outputs is not None:
outputs = merge_batch(outputs)
return {'loss': accum_loss, 'outputs': outputs}
@@ -6,25 +6,47 @@ import torch.cuda
 from torch.nn import Module
 from torch.utils._pytree import tree_map

-from colossalai.interface import OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.utils.cuda import get_current_device

-from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
+from ._utils import (
+    detach,
+    get_batch_size,
+    get_micro_batch,
+    merge_batch,
+    model_forward,
+    retain_grad,
+    to_device,
+    tree_map_hf,
+)
 from .base import PipelineSchedule
 class OneForwardOneBackwardSchedule(PipelineSchedule):

-    def __init__(self, num_microbatches: int, stage_manager: PipelineStageManager) -> None:
+    def __init__(self,
+                 stage_manager: PipelineStageManager,
+                 num_microbatches: Optional[int] = None,
+                 microbatch_size: Optional[int] = None) -> None:
+        """1F1B pipeline schedule.
+
+        Args:
+            stage_manager (PipelineStageManager): Pipeline stage manager
+            num_microbatches (Optional[int], optional): The number of microbatches. If not provided, it will be derived from microbatch_size. Defaults to None.
+            microbatch_size (Optional[int], optional): Microbatch size. If num_microbatches is provided, this will be ignored. Defaults to None.
+        """
         super().__init__(stage_manager)
+        assert num_microbatches is not None or microbatch_size is not None, \
+            "Either num_microbatches or microbatch_size should be provided"
         self.comm = PipelineP2PCommunication(stage_manager)
         self.num_microbatches = num_microbatches
+        self.microbatch_size = microbatch_size
         self.batch: Optional[Any] = None
         self.batch_size: Optional[int] = None
         self.microbatch_offset: Optional[int] = None
-        self.microbatch_size: Optional[int] = None
+        self._use_microbatch_size = num_microbatches is None

     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.
@@ -39,9 +61,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         self.batch = batch
         self.batch_size = get_batch_size(batch)
         self.microbatch_offset = 0
-        assert self.batch_size % self.num_microbatches == 0, \
-            "Batch size should be divisible by the number of microbatches"
-        self.microbatch_size = self.batch_size // self.num_microbatches
+        if not self._use_microbatch_size:
+            assert self.batch_size % self.num_microbatches == 0, \
+                "Batch size should be divisible by the number of microbatches"
+            self.microbatch_size = self.batch_size // self.num_microbatches
+        else:
+            assert self.batch_size % self.microbatch_size == 0, \
+                "Batch size should be divisible by the microbatch size"
+            self.num_microbatches = self.batch_size // self.microbatch_size
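Either argument now determines the other once the first batch is loaded; a quick arithmetic sketch, assuming a global batch of 32 samples (the instantiations below are hypothetical and omit stage_manager setup):

# OneForwardOneBackwardSchedule(stage_manager, num_microbatches=8)   -> microbatch_size = 32 // 8 = 4
# OneForwardOneBackwardSchedule(stage_manager, microbatch_size=4)    -> num_microbatches = 32 // 4 = 8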
def load_micro_batch(self) -> Any: def load_micro_batch(self) -> Any:
"""Load a micro batch from the current batch. """Load a micro batch from the current batch.
...@@ -53,6 +80,62 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -53,6 +80,62 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
self.microbatch_offset += self.microbatch_size self.microbatch_offset += self.microbatch_size
return tree_map(partial(to_device, device=get_current_device()), micro_batch) return tree_map(partial(to_device, device=get_current_device()), micro_batch)
def recv_forward(self, prev_rank: int = None) -> Any:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
For 1F1B.
Args:
prev_rank (int, optional): The rank of the source of the tensor.
Returns:
Any: The input tensor or input tensor list.
"""
if self.stage_manager.is_first_stage():
input_tensor = None
else:
input_tensor = self.comm.recv_forward(prev_rank)
return input_tensor
def recv_backward(self, next_rank: int = None) -> Any:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
For 1F1B.
Args:
next_rank (int, optional): The rank of the source of the tensor.
Returns:
Any: The input gradient tensor or gradient tensor list.
"""
if self.stage_manager.is_last_stage():
output_tensor_grad = None
else:
output_tensor_grad = self.comm.recv_backward(next_rank)
return output_tensor_grad
def send_forward(self, output_object: Any, next_rank: int = None) -> None:
"""Sends the input tensor to the next stage in pipeline.
For 1F1B.
Args:
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
self.comm.send_forward(output_object, next_rank)
def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
For 1F1B.
Args:
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor
"""
if not self.stage_manager.is_first_stage():
self.comm.send_backward(input_object, prev_rank)
def forward_step(self, def forward_step(self,
model: Module, model: Module,
input_obj: Optional[dict], input_obj: Optional[dict],
...@@ -72,16 +155,16 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -72,16 +155,16 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
""" """
micro_batch = self.load_micro_batch() micro_batch = self.load_micro_batch()
# for the first stage, input_obj is None # for the first stage, input_obj is None
# for the non-first stage, input_obj is the output of the previous stage and it's must be a dict # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict
output_obj = model_forward(model, micro_batch, input_obj) output_obj = model_forward(model, micro_batch, input_obj)
if self.stage_manager.is_last_stage(): if self.stage_manager.is_last_stage():
loss = criterion(output_obj, micro_batch) / self.num_microbatches loss = criterion(output_obj, micro_batch) / self.num_microbatches
if accum_loss is not None: if accum_loss is not None:
accum_loss.add_(loss.detach()) accum_loss.add_(loss.detach())
if outputs is not None: if outputs is not None:
outputs.append(tree_map(detach, output_obj)) outputs.append(tree_map_hf(detach, output_obj))
return loss return loss
else: else:
return output_obj return output_obj
...@@ -102,7 +185,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -102,7 +185,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
# Retain the grad on the input_obj. # Retain the grad on the input_obj.
tree_map(retain_grad, input_obj) tree_map(retain_grad, input_obj)
# Backward pass. # Backward pass.
if output_obj_grad is None: if output_obj_grad is None:
optimizer.backward(output_obj) optimizer.backward(output_obj)
...@@ -171,11 +253,11 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -171,11 +253,11 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
# Run warmup forward passes. # Run warmup forward passes.
for i in range(num_warmup_microbatches): for i in range(num_warmup_microbatches):
input_obj = self.comm.recv_forward() input_obj = self.recv_forward()
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
self.comm.send_forward(output_obj) self.send_forward(output_obj)
if not forward_only: if not forward_only:
input_objs.append(input_obj) input_objs.append(input_obj)
...@@ -185,7 +267,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -185,7 +267,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
# If all microbatches are run in warmup / cooldown phase, then no need to # If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here. # receive this tensor here.
if num_microbatches_remaining > 0: if num_microbatches_remaining > 0:
input_obj = self.comm.recv_forward() input_obj = self.recv_forward()
# Run 1F1B in steady state. # Run 1F1B in steady state.
for i in range(num_microbatches_remaining): for i in range(num_microbatches_remaining):
...@@ -193,15 +275,15 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -193,15 +275,15 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
if forward_only: if forward_only:
self.comm.send_forward(output_obj) self.send_forward(output_obj)
if not last_iteration: if not last_iteration:
input_obj = self.comm.recv_forward() input_obj = self.recv_forward()
else: else:
# TODO adjust here # TODO adjust here
self.comm.send_forward(output_obj) self.send_forward(output_obj)
output_obj_grad = self.comm.recv_backward() output_obj_grad = self.recv_backward()
# Add input_obj and output_obj to end of list. # Add input_obj and output_obj to end of list.
input_objs.append(input_obj) input_objs.append(input_obj)
...@@ -216,8 +298,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -216,8 +298,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
if last_iteration: if last_iteration:
input_obj = None input_obj = None
else: else:
input_obj = self.comm.recv_forward() input_obj = self.recv_forward()
self.comm.send_backward(input_obj_grad) self.send_backward(input_obj_grad)
# Run cooldown backward passes. # Run cooldown backward passes.
if not forward_only: if not forward_only:
...@@ -225,10 +307,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -225,10 +307,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
input_obj = input_objs.pop(0) input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0) output_obj = output_objs.pop(0)
output_obj_grad = self.comm.recv_backward() output_obj_grad = self.recv_backward()
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.comm.send_backward(input_obj_grad) self.send_backward(input_obj_grad)
if outputs is not None: if outputs is not None:
outputs = merge_batch(outputs) if isinstance(model, ModelWrapper):
model = model.unwrap()
outputs = merge_batch(outputs, getattr(model, 'batch_size_dim', 0))
return {'loss': accum_loss, 'outputs': outputs} return {'loss': accum_loss, 'outputs': outputs}
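The warmup, steady-state (1F1B), and cooldown loops above follow the standard one-forward-one-backward schedule. As a quick illustration of how those phase sizes usually relate, here is a hedged sketch; the helper and its variable names are hypothetical and not read from this file:

```python
# Minimal sketch (not taken from this file): how a 1F1B schedule typically
# splits microbatches into warmup and steady-state phases per pipeline stage.
def one_f_one_b_phases(num_stages: int, stage: int, num_microbatches: int):
    # Earlier stages must issue more forwards before their first backward
    # can arrive from downstream, hence the larger warmup count.
    num_warmup = min(num_stages - stage - 1, num_microbatches)
    num_remaining = num_microbatches - num_warmup
    return num_warmup, num_remaining

for stage in range(4):
    print(stage, one_f_one_b_phases(num_stages=4, stage=stage, num_microbatches=8))
# stage 0 -> (3, 5), stage 3 (the last stage) -> (0, 8)
```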
...@@ -17,28 +17,24 @@ class PipelineStageManager: ...@@ -17,28 +17,24 @@ class PipelineStageManager:
Attributes: Attributes:
num_stages (int): Number of stages in the pipeline. num_stages (int): Number of stages in the pipeline.
stage (int): The current stage. stage (int): The current stage.
num_virtual_stages (int): Number of virtual stages in the pipeline.
virtual_stage (int): The current virtual stage.
""" """
def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int) -> None: def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bool = False) -> None:
self.pg_mesh = pg_mesh self.pg_mesh = pg_mesh
self.pipeline_axis = pipeline_axis self.pipeline_axis = pipeline_axis
self.num_virtual_stages: Optional[int] = None
self.virtual_stage: Optional[int] = None
self.prev_rank: Optional[Tuple[int, ...]] = None self.prev_rank: Optional[Tuple[int, ...]] = None
self.next_rank: Optional[Tuple[int, ...]] = None self.next_rank: Optional[Tuple[int, ...]] = None
self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {} self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {}
# init prev and next coord # init prev and next coord
coord = self.pg_mesh.coordinate() coord = self.pg_mesh.coordinate()
if self.stage > 0: # the prev rank of rank0 is the last rank
prev_coord = coord[: self.pipeline_axis] + \ prev_coord = coord[: self.pipeline_axis] + \
(coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1:] (coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1:]
self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape) self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape, mode='wrap')
if self.stage < self.num_stages - 1: # the next rank of the last rank is rank0
next_coord = coord[: self.pipeline_axis] + \ next_coord = coord[: self.pipeline_axis] + \
(coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1:] (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1:]
self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape) self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode='wrap')
# init p2p process groups # init p2p process groups
stages = list(range(self.num_stages)) stages = list(range(self.num_stages))
...@@ -48,32 +44,28 @@ class PipelineStageManager: ...@@ -48,32 +44,28 @@ class PipelineStageManager:
ranks_in_group = self.pg_mesh.get_ranks_in_group(group) ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
self.p2p_groups[tuple(ranks_in_group)] = group self.p2p_groups[tuple(ranks_in_group)] = group
def is_first_stage(self, virtual: bool = False) -> bool: if is_virtual:
"""Is the current stage the first stage. # add the process group of the first rank and the last rank
# only used in interleaved pipeline for now
group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [stages[0], stages[-1]])
if self.stage in [stages[0], stages[-1]]:
ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
self.p2p_groups[tuple(ranks_in_group)] = group
Args: def is_first_stage(self) -> bool:
virtual (bool, optional): Whether to consider virtual stages. Defaults to False. """Is the current stage the first stage.
Returns: Returns:
bool: Whether the current stage is the first stage. bool: Whether the current stage is the first stage.
""" """
if virtual:
assert self.num_virtual_stages is not None
return self.virtual_stage == 0
return self.stage == 0 return self.stage == 0
def is_last_stage(self, virtual: bool = False) -> bool: def is_last_stage(self) -> bool:
"""Is the current stage the last stage. """Is the current stage the last stage.
Args:
virtual (bool, optional): Whether to consider virtual stages. Defaults to False.
Returns: Returns:
bool: Whether the current stage is the last stage. bool: Whether the current stage is the last stage.
""" """
if virtual:
assert self.num_virtual_stages is not None
return self.virtual_stage == self.num_virtual_stages - 1
return self.stage == self.num_stages - 1 return self.stage == self.num_stages - 1
@property @property
...@@ -108,7 +100,6 @@ class PipelineStageManager: ...@@ -108,7 +100,6 @@ class PipelineStageManager:
Returns: Returns:
int: Rank of the previous stage. int: Rank of the previous stage.
""" """
assert not self.is_first_stage(), "Cannot get previous rank in the first stage."
return self.prev_rank return self.prev_rank
def get_next_rank(self) -> int: def get_next_rank(self) -> int:
...@@ -117,39 +108,8 @@ class PipelineStageManager: ...@@ -117,39 +108,8 @@ class PipelineStageManager:
Returns: Returns:
int: Rank of the next stage. int: Rank of the next stage.
""" """
assert not self.is_last_stage(), "Cannot get next rank in the last stage."
return self.next_rank return self.next_rank
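The wrap-around ravel used above ("the prev rank of rank0 is the last rank") makes the pipeline axis behave like a ring, which is why the first/last-stage assertions on the rank getters are no longer needed. A minimal stand-alone sketch of that neighbour computation, using plain modular arithmetic instead of the ProcessGroupMesh API (illustrative names only):

```python
# Illustrative only: previous/next stage on a pipeline ring of `num_stages` stages.
def ring_neighbours(stage: int, num_stages: int):
    prev_stage = (stage - 1) % num_stages  # stage 0 wraps around to the last stage
    next_stage = (stage + 1) % num_stages  # the last stage wraps around to stage 0
    return prev_stage, next_stage

assert ring_neighbours(0, 4) == (3, 1)
assert ring_neighbours(3, 4) == (2, 0)
```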
def set_num_virtual_stages(self, num_virtual_stages: int) -> None:
"""Set the number of virtual stages.
Args:
num_virtual_stages (int): Number of virtual stages.
"""
self.num_virtual_stages = num_virtual_stages
def set_virtual_stage(self, virtual_stage: int) -> None:
"""Set the virtual stage.
Args:
virtual_stage (int): Virtual stage.
"""
self.virtual_stage = virtual_stage
@contextmanager
def switch_virtual_stage(self, virtual_stage: int) -> None:
"""A context manager to switch virtual stage.
Args:
virtual_stage (int): Target virtual stage.
"""
old_stage = self.virtual_stage
try:
self.set_virtual_stage(virtual_stage)
yield
finally:
self.set_virtual_stage(old_stage)
def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup: def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup:
"""Get the p2p process group between two ranks. The order of the two ranks does not matter. """Get the p2p process group between two ranks. The order of the two ranks does not matter.
......
...@@ -429,12 +429,13 @@ As shown in the figures above, when the sequence length is around 1000 or greate ...@@ -429,12 +429,13 @@ As shown in the figures above, when the sequence length is around 1000 or greate
### Convergence ### Convergence
To validate that training the model using shardformers does not impact its convergence, we [fine-tuned the BERT model](./examples/convergence_benchmark.py) using both shardformer and non-shardformer approaches and compared the accuracy, loss, and F1 score of the training results. To validate that training the model using shardformers does not impact its convergence, we [fine-tuned the BERT model](../../examples/language/bert/finetune.py) using both shardformer and non-shardformer approaches; the shardformer example uses pipeline parallelism together with ZeRO-1 data parallelism. We then compared the accuracy, loss, and F1 score of the training results.
| accuracy | f1 | loss | GPU number | model shard |
| accuracy | f1 | loss | GPU number | model sharded |
| :------: | :-----: | :-----: | :--------: | :---------: | | :------: | :-----: | :-----: | :--------: | :---------: |
| 0.82594 | 0.87441 | 0.09913 | 4 | True | | 0.84589 | 0.88613 | 0.43414 | 4 | True |
| 0.81884 | 0.87299 | 0.10120 | 2 | True | | 0.83594 | 0.88064 | 0.43298 | 1 | False |
| 0.81855 | 0.87124 | 0.10357 | 1 | False |
Overall, the results demonstrate that using shardformers during model training does not affect the convergence. Overall, the results demonstrate that using shardformers during model training does not affect the convergence.
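For context, the shardformer run in this comparison is driven through the booster API with the hybrid parallel plugin that this PR updates. The snippet below is a hedged sketch of that wiring, not the benchmark script itself; the concrete sizes (pp_size=2, zero_stage=1, num_microbatches=4) are illustrative assumptions.

```python
# Hedged sketch: boosting a model with pipeline parallelism + ZeRO-1 data parallelism.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

def boost_for_finetune(model, optimizer, criterion, dataloader):
    # tp_size=1 disables tensor parallelism; pp_size/zero_stage mirror the
    # "pipeline + ZeRO-1" setting used in the convergence comparison above.
    plugin = HybridParallelPlugin(tp_size=1, pp_size=2, zero_stage=1, num_microbatches=4)
    booster = Booster(plugin=plugin)
    model, optimizer, criterion, dataloader, _ = booster.boost(
        model, optimizer, criterion=criterion, dataloader=dataloader)
    return booster, model, optimizer, criterion, dataloader
```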
from typing import Any
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn.functional as F import torch.nn.functional as F
...@@ -141,6 +143,240 @@ class LinearWithAsyncCommunication(torch.autograd.Function): ...@@ -141,6 +143,240 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
return grad_input, grad_weight, grad_bias, None, None, None return grad_input, grad_weight, grad_bias, None, None, None
class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
overlap (`bool`): Whether to overlap the all_gather op with the gradient computation in backward.
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.overlap = overlap
input_parallel = _gather(input_, dim, process_group)
if bias is not None:
output = F.linear(input_parallel, weight, bias)
else:
output = F.linear(input_parallel, weight)
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight = ctx.saved_tensors
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap
if not overlap:
input_parallel = _gather(input_, dim, process_group)
total_input = input_parallel
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_parallel.dtype,
device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# reduce-scatter scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_reduce_scatter:
handle.wait()
else:
input_ = input_.contiguous()
world_size = dist.get_world_size(process_group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
# do the all-gather asynchronously
gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
# calculate gradient and prepare data asynchronously with all-gather
# calculate
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
grad_bias = grad_output.sum(dim=0) if use_bias else None
# prepare data
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
# wait until all-gather finished
gather_handle.wait()
# do reduce-scatter in async way
reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
# calculate gradient
if len(input_parallel.shape) > 2:
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
grad_weight = grad_output.t().matmul(input_parallel)
# wait until reduce-scatter finished
reducescatter_handle.wait()
return output, grad_weight, grad_bias, None, None, None, None
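The `overlap=True` branch follows the usual async-collective pattern: launch the communication with `async_op=True`, do work that does not depend on its result, and call `wait()` only when the gathered data is actually needed. A stripped-down sketch of that pattern follows (the helper is hypothetical and still needs an initialized process group to run):

```python
import torch
import torch.distributed as dist

def overlapped_all_gather(x, weight, process_group):
    # Launch the all-gather asynchronously ...
    world_size = dist.get_world_size(process_group)
    shards = [torch.empty_like(x) for _ in range(world_size)]
    handle = dist.all_gather(shards, x.contiguous(), group=process_group, async_op=True)
    # ... and overlap it with work that only needs local data.
    local_out = x.matmul(weight)
    handle.wait()  # block only once the gathered shards are required
    gathered = torch.cat(shards, dim=1)
    return local_out, gathered
```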
class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
"""
@staticmethod
def forward(ctx, input_, process_group, dim):
ctx.dim = dim
ctx.process_group = process_group
# do reduce-scatter
new_shape = list(input_.shape)
assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
dist.reduce_scatter(output, input_list, group=process_group)
return output
@staticmethod
def backward(ctx, grad_output):
dim = ctx.dim
process_group = ctx.process_group
return _gather(grad_output, dim, process_group), None, None
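A reduce-scatter along `dim` both sums the contributions from all ranks and leaves each rank with a 1/world_size slice, which is why the forward above shrinks `new_shape[dim]`. The single-process snippet below only emulates that shape-and-sum behaviour; it is not the distributed code path above:

```python
import torch

# Emulate reduce_scatter(dim=1) on one process: every "rank" contributes a full
# tensor, the contributions are summed, and each rank keeps one chunk of the sum.
def emulate_reduce_scatter(inputs_per_rank, dim=1):
    world_size = len(inputs_per_rank)
    summed = torch.stack(inputs_per_rank).sum(dim=0)          # the "reduce" part
    return list(torch.chunk(summed, world_size, dim=dim))     # the "scatter" part

x = [torch.ones(2, 8, 4) for _ in range(4)]                   # 4 ranks, sequence length 8
out = emulate_reduce_scatter(x, dim=1)
assert out[0].shape == (2, 2, 4) and int(out[0].max()) == 4   # 8 -> 8/4, values summed
```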
class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""
This class is designed for matmul operation with gather forward and reduce-scatter backward.
Args:
input_ (`torch.Tensor`): input matrix.
dim (int): the dimension to perform split and gather
process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.overlap = overlap
input_parallel = _gather(input_, dim, process_group)
output = torch.matmul(input_parallel, weight)
if bias is not None:
output = output + bias
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight = ctx.saved_tensors
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap
if not overlap:
input_parallel = _gather(input_, dim, process_group)
total_input = input_parallel
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_parallel.dtype,
device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# reduce-scatter scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_reduce_scatter:
handle.wait()
else:
world_size = dist.get_world_size(process_group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
# do the all-gather asynchronously
gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
# calculate gradient and prepare data asynchronously with all-gather
# calculate
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
grad_bias = grad_output.sum(dim=0) if use_bias else None
# prepare data
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
# wait until all-gather finished
gather_handle.wait()
# do reduce-scatter in async way
reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
# calculate gradient
if len(input_parallel.shape) > 2:
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
grad_weight = input_parallel.t().matmul(grad_output)
# wait until reduce-scatter finished
reducescatter_handle.wait()
return output, grad_weight, grad_bias, None, None, None, None
class _SplitForwardGatherBackward(torch.autograd.Function): class _SplitForwardGatherBackward(torch.autograd.Function):
""" """
Split the input and keep only the chunk corresponding to the rank. Split the input and keep only the chunk corresponding to the rank.
...@@ -200,6 +436,26 @@ class _ReduceBackward(torch.autograd.Function): ...@@ -200,6 +436,26 @@ class _ReduceBackward(torch.autograd.Function):
return _reduce(grad_output, ctx.process_group), None return _reduce(grad_output, ctx.process_group), None
class _GatherForwardSplitBackward(torch.autograd.Function):
"""Gather the input from model parallel region and concatenate.
Args:
input_: input matrix.
dim: the dimension along which to gather.
process_group: the process group used for the collective communication.
"""
@staticmethod
def forward(ctx, input_, dim, process_group):
ctx.process_group = process_group
ctx.dim = dim
return _gather(input_, dim, process_group)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output, ctx.dim, ctx.process_group), None, None
def _reduce(input_, process_group): def _reduce(input_, process_group):
# skip if only one rank involved # skip if only one rank involved
if dist.get_world_size(process_group) == 1: if dist.get_world_size(process_group) == 1:
...@@ -235,9 +491,8 @@ def _gather(input_, dim=-1, process_group=None): ...@@ -235,9 +491,8 @@ def _gather(input_, dim=-1, process_group=None):
return input_ return input_
# all gather # all gather
rank = dist.get_rank(process_group) input_ = input_.contiguous()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)] tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=process_group) torch.distributed.all_gather(tensor_list, input_, group=process_group)
# concat # concat
...@@ -246,24 +501,27 @@ def _gather(input_, dim=-1, process_group=None): ...@@ -246,24 +501,27 @@ def _gather(input_, dim=-1, process_group=None):
return output return output
class _GatherForwardSplitBackward(torch.autograd.Function): def _reduce_scatter(input_, dim=1, process_group=None):
"""Gather the input from model parallel region and concatenate. """ Do reduce-scatter operation.
Args: Args:
input_: input matrix. input_ (`torch.Tensor`): The input tensor from sequence parallel region.
parallel_mode: parallel mode. dim (int): The dimension to perform reduce-scatter.
dim: dimension process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
""" """
world_size = dist.get_world_size(process_group)
if world_size == 1:
return input_
@staticmethod # reduce-scatter
def forward(ctx, input_, dim, process_group): new_shape = list(input_.shape)
ctx.process_group = process_group assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
ctx.dim = dim f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
return _gather(input_, dim, process_group) new_shape[dim] = new_shape[dim] // world_size
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
input_list = [item.contiguous() for item in torch.chunk(input_, world_size, dim=dim)]
dist.reduce_scatter(output, input_list, group=process_group)
@staticmethod return output
def backward(ctx, grad_output):
return _split(grad_output, ctx.dim, ctx.process_group), None, None
def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
...@@ -274,6 +532,22 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre ...@@ -274,6 +532,22 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre
return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim,
overlap):
return _LinearWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group,
async_grad_reduce_scatter, dim, overlap)
def linear_reducescatter_forward_gather_backward(input_, process_group, dim):
return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim)
def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim,
overlap):
return _MatmulWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group,
async_grad_reduce_scatter, dim, overlap)
def gather_forward_split_backward(input_, dim, process_group): def gather_forward_split_backward(input_, dim, process_group):
return _GatherForwardSplitBackward.apply(input_, dim, process_group) return _GatherForwardSplitBackward.apply(input_, dim, process_group)
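These functional wrappers are what a sequence-parallel linear layer would call from its forward pass. A rough, hypothetical sketch of that call pattern, assuming the sequence dimension is 1 (the layer and its attributes are illustrative; only the wrapper name comes from this file):

```python
# Hedged sketch of how a sequence-parallel linear layer might use the new wrapper.
def sequence_parallel_linear_forward(self, input_):
    # forward: all-gather the sequence-sharded activation, then run the linear;
    # backward: reduce-scatter the activation gradient back onto the shards.
    return linear_gather_forward_reducescatter_backward(
        input_, self.weight, self.bias, self.process_group,
        async_grad_reduce_scatter=True, dim=1, overlap=True)
```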
......