Unverified Commit 079bf3cb authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
...@@ -22,10 +22,7 @@ def get_data(batch_size: int, seq_len: int = 10) -> dict: ...@@ -22,10 +22,7 @@ def get_data(batch_size: int, seq_len: int = 10) -> dict:
return dict(input_ids=input_ids, attention_mask=attention_mask) return dict(input_ids=input_ids, attention_mask=attention_mask)
def train_step(strategy: Strategy, def train_step(strategy: Strategy, actor: GPTActor, actor_optim: HybridAdam, batch_size: int = 8):
actor: GPTActor,
actor_optim: HybridAdam,
batch_size: int = 8):
data = get_data(batch_size) data = get_data(batch_size)
action_mask = torch.ones_like(data["attention_mask"], dtype=torch.bool) action_mask = torch.ones_like(data["attention_mask"], dtype=torch.bool)
actor_output = actor(data["input_ids"], data["attention_mask"]) actor_output = actor(data["input_ids"], data["attention_mask"])
...@@ -35,8 +32,7 @@ def train_step(strategy: Strategy, ...@@ -35,8 +32,7 @@ def train_step(strategy: Strategy,
strategy.optimizer_step(actor_optim) strategy.optimizer_step(actor_optim)
def run_test_checkpoint(strategy_name: str, def run_test_checkpoint(strategy_name: str, shard: bool):
shard: bool):
if strategy_name == "ddp": if strategy_name == "ddp":
strategy = DDPStrategy() strategy = DDPStrategy()
elif strategy_name == "colossalai_gemini": elif strategy_name == "colossalai_gemini":
...@@ -60,11 +56,9 @@ def run_test_checkpoint(strategy_name: str, ...@@ -60,11 +56,9 @@ def run_test_checkpoint(strategy_name: str,
dist.broadcast_object_list(rank0_dirname) dist.broadcast_object_list(rank0_dirname)
rank0_dirname = rank0_dirname[0] rank0_dirname = rank0_dirname[0]
model_path = os.path.join( model_path = os.path.join(rank0_dirname, "model" if shard else f"model.pt")
rank0_dirname, "model" if shard else f"model.pt")
strategy.save_model(actor, model_path, only_rank0=not shard) strategy.save_model(actor, model_path, only_rank0=not shard)
optim_path = os.path.join( optim_path = os.path.join(rank0_dirname, "optim" if shard else "optim.pt")
rank0_dirname, "optim" if shard else "optim.pt")
strategy.save_optimizer(actor_optim, optim_path, only_rank0=not shard) strategy.save_optimizer(actor_optim, optim_path, only_rank0=not shard)
dist.barrier() dist.barrier()
...@@ -75,11 +69,7 @@ def run_test_checkpoint(strategy_name: str, ...@@ -75,11 +69,7 @@ def run_test_checkpoint(strategy_name: str,
train_step(strategy, actor, actor_optim) train_step(strategy, actor, actor_optim)
def run_dist(rank: int, def run_dist(rank: int, world_size: int, port: int, strategy_name: str, shard: bool):
world_size: int,
port: int,
strategy_name: str,
shard: bool):
os.environ["RANK"] = str(rank) os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank) os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size) os.environ["WORLD_SIZE"] = str(world_size)
...@@ -93,13 +83,8 @@ def run_dist(rank: int, ...@@ -93,13 +83,8 @@ def run_dist(rank: int,
@pytest.mark.parametrize("strategy_name", ["ddp", "colossalai_gemini", "colossalai_zero2"]) @pytest.mark.parametrize("strategy_name", ["ddp", "colossalai_gemini", "colossalai_zero2"])
@pytest.mark.parametrize("shard", [False, True]) @pytest.mark.parametrize("shard", [False, True])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_checkpoint(world_size: int, def test_checkpoint(world_size: int, strategy_name: str, shard: bool):
strategy_name: str, spawn(run_dist, world_size, strategy_name=strategy_name, shard=shard)
shard: bool):
spawn(run_dist,
world_size,
strategy_name=strategy_name,
shard=shard)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -8,62 +8,40 @@ import torch ...@@ -8,62 +8,40 @@ import torch
from coati.dataset.prompt_dataset import PromptDataset from coati.dataset.prompt_dataset import PromptDataset
from coati.dataset.reward_dataset import HhRlhfDataset, RmStaticDataset from coati.dataset.reward_dataset import HhRlhfDataset, RmStaticDataset
from coati.dataset.sft_dataset import IGNORE_INDEX, SFTDataset, SupervisedDataset from coati.dataset.sft_dataset import IGNORE_INDEX, SFTDataset, SupervisedDataset
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from datasets import load_dataset from datasets import load_dataset
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, PreTrainedTokenizer from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, PreTrainedTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
SFT_DATASET = [ SFT_DATASET = [
{ {
"instruction": "instruction": "Provide a list of the top 10 most popular mobile games in Asia",
"Provide a list of the top 10 most popular mobile games in Asia", "input": "",
"input": "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
"", "id": 0,
"output":
"The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
"id":
0
}, },
{ {
"instruction": "instruction": "Please provide an action plan for reducing carbon footprint on a corporate level",
"Please provide an action plan for reducing carbon footprint on a corporate level", "input": "",
"input": "output": "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.",
"", "id": 1,
"output":
"An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.",
"id":
1
}, },
{ {
"instruction": "instruction": "Write a persuasive email to your boss explaining why you should have a pay raise",
"Write a persuasive email to your boss explaining why you should have a pay raise", "input": "",
"input": "output": "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]",
"", "id": 2,
"output":
"Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]",
"id":
2
}, },
] ]
PROMPT_DATASET = [ PROMPT_DATASET = [
{ {
"instruction": "instruction": 'Edit this paragraph to make it more concise: "Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends."',
"Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"", "id": 0,
"id":
0
},
{
"instruction": "Write a descriptive paragraph about a memorable vacation you went on",
"id": 1
},
{
"instruction": "Write a persuasive essay arguing why homework should be banned in schools",
"id": 2
},
{
"instruction": "Create a chart comparing the statistics on student debt in the United States.",
"id": 3
}, },
{"instruction": "Write a descriptive paragraph about a memorable vacation you went on", "id": 1},
{"instruction": "Write a persuasive essay arguing why homework should be banned in schools", "id": 2},
{"instruction": "Create a chart comparing the statistics on student debt in the United States.", "id": 3},
] ]
...@@ -120,10 +98,12 @@ def test_prompt_dataset(model: str, max_datasets_size: int, max_length: int): ...@@ -120,10 +98,12 @@ def test_prompt_dataset(model: str, max_datasets_size: int, max_length: int):
json.dump(PROMPT_DATASET, f) json.dump(PROMPT_DATASET, f)
tokenizer = make_tokenizer(model) tokenizer = make_tokenizer(model)
assert tokenizer.padding_side in ("left", "right") assert tokenizer.padding_side in ("left", "right")
prompt_dataset = PromptDataset(data_path=os.path.join(tmp_dir, dataset_name), prompt_dataset = PromptDataset(
tokenizer=tokenizer, data_path=os.path.join(tmp_dir, dataset_name),
max_datasets_size=max_datasets_size, tokenizer=tokenizer,
max_length=max_length) max_datasets_size=max_datasets_size,
max_length=max_length,
)
assert len(prompt_dataset) == min(max_datasets_size, len(PROMPT_DATASET)) assert len(prompt_dataset) == min(max_datasets_size, len(PROMPT_DATASET))
for i in range(len(prompt_dataset)): for i in range(len(prompt_dataset)):
assert isinstance(prompt_dataset[i], dict) assert isinstance(prompt_dataset[i], dict)
...@@ -137,14 +117,14 @@ def test_prompt_dataset(model: str, max_datasets_size: int, max_length: int): ...@@ -137,14 +117,14 @@ def test_prompt_dataset(model: str, max_datasets_size: int, max_length: int):
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"]) @pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
@pytest.mark.parametrize(["dataset_path", "subset"], [("Anthropic/hh-rlhf", "harmless-base"), @pytest.mark.parametrize(
("Dahoas/rm-static", None)]) ["dataset_path", "subset"], [("Anthropic/hh-rlhf", "harmless-base"), ("Dahoas/rm-static", None)]
)
@pytest.mark.parametrize("max_datasets_size", [32]) @pytest.mark.parametrize("max_datasets_size", [32])
@pytest.mark.parametrize("max_length", [32, 1024]) @pytest.mark.parametrize("max_length", [32, 1024])
def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], max_datasets_size: int, max_length: int): def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], max_datasets_size: int, max_length: int):
data = load_dataset(dataset_path, data_dir=subset) data = load_dataset(dataset_path, data_dir=subset)
assert max_datasets_size <= len(data["train"]) \ assert max_datasets_size <= len(data["train"]) and max_datasets_size <= len(data["test"])
and max_datasets_size <= len(data["test"])
train_data = data["train"].select(range(max_datasets_size)) train_data = data["train"].select(range(max_datasets_size))
test_data = data["test"].select(range(max_datasets_size)) test_data = data["test"].select(range(max_datasets_size))
tokenizer = make_tokenizer(model) tokenizer = make_tokenizer(model)
...@@ -162,8 +142,7 @@ def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], ma ...@@ -162,8 +142,7 @@ def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], ma
assert len(train_dataset) == len(test_dataset) == max_datasets_size assert len(train_dataset) == len(test_dataset) == max_datasets_size
for i in range(max_datasets_size): for i in range(max_datasets_size):
chosen_ids, c_mask, reject_ids, r_mask = train_dataset[i] chosen_ids, c_mask, reject_ids, r_mask = train_dataset[i]
assert chosen_ids.shape == c_mask.shape == \ assert chosen_ids.shape == c_mask.shape == reject_ids.shape == r_mask.shape == torch.Size([max_length])
reject_ids.shape == r_mask.shape == torch.Size([max_length])
c_mask = c_mask.to(torch.bool) c_mask = c_mask.to(torch.bool)
r_mask = r_mask.to(torch.bool) r_mask = r_mask.to(torch.bool)
if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id: if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id:
...@@ -180,8 +159,7 @@ def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], ma ...@@ -180,8 +159,7 @@ def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], ma
assert torch.all(r_mask) assert torch.all(r_mask)
chosen_ids, c_mask, reject_ids, r_mask = test_dataset[i] chosen_ids, c_mask, reject_ids, r_mask = test_dataset[i]
assert chosen_ids.shape == c_mask.shape == \ assert chosen_ids.shape == c_mask.shape == reject_ids.shape == r_mask.shape == torch.Size([max_length])
reject_ids.shape == r_mask.shape == torch.Size([max_length])
c_mask = c_mask.to(torch.bool) c_mask = c_mask.to(torch.bool)
r_mask = r_mask.to(torch.bool) r_mask = r_mask.to(torch.bool)
if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id: if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id:
...@@ -198,7 +176,6 @@ def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], ma ...@@ -198,7 +176,6 @@ def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], ma
assert torch.all(r_mask) assert torch.all(r_mask)
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama", "chatglm"]) @pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama", "chatglm"])
@pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None]) @pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None])
@pytest.mark.parametrize("max_dataset_size", [2]) @pytest.mark.parametrize("max_dataset_size", [2])
...@@ -214,10 +191,12 @@ def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size: ...@@ -214,10 +191,12 @@ def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size:
dataset_name = "sft_dataset.json" dataset_name = "sft_dataset.json"
with open(os.path.join(tmp_dir, dataset_name), "w") as f: with open(os.path.join(tmp_dir, dataset_name), "w") as f:
json.dump(SFT_DATASET, f) json.dump(SFT_DATASET, f)
sft_dataset = SupervisedDataset(tokenizer=tokenizer, sft_dataset = SupervisedDataset(
data_path=os.path.join(tmp_dir, dataset_name), tokenizer=tokenizer,
max_datasets_size=max_dataset_size, data_path=os.path.join(tmp_dir, dataset_name),
max_length=max_length) max_datasets_size=max_dataset_size,
max_length=max_length,
)
assert len(sft_dataset) == min(max_dataset_size, len(SFT_DATASET)) assert len(sft_dataset) == min(max_dataset_size, len(SFT_DATASET))
if isinstance(tokenizer, ChatGLMTokenizer): if isinstance(tokenizer, ChatGLMTokenizer):
...@@ -227,20 +206,19 @@ def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size: ...@@ -227,20 +206,19 @@ def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size:
input_ids = sft_dataset[i]["input_ids"] input_ids = sft_dataset[i]["input_ids"]
labels = sft_dataset[i]["labels"] labels = sft_dataset[i]["labels"]
assert input_ids.shape == labels.shape == torch.Size([max_length]) assert input_ids.shape == labels.shape == torch.Size([max_length])
ignore_mask = labels == IGNORE_INDEX ignore_mask = labels == IGNORE_INDEX
assert input_ids.masked_select(torch.logical_not(ignore_mask))[0] == tokenizer.bos_token_id assert input_ids.masked_select(torch.logical_not(ignore_mask))[0] == tokenizer.bos_token_id
check_content(input_ids.masked_select(torch.logical_not(ignore_mask)), tokenizer, model) check_content(input_ids.masked_select(torch.logical_not(ignore_mask)), tokenizer, model)
return return
for i in range(max_dataset_size): for i in range(max_dataset_size):
assert isinstance(sft_dataset[i], dict) assert isinstance(sft_dataset[i], dict)
assert list(sft_dataset[i].keys()) == ["input_ids", "labels", "attention_mask"] assert list(sft_dataset[i].keys()) == ["input_ids", "labels", "attention_mask"]
input_ids = sft_dataset[i]["input_ids"] input_ids = sft_dataset[i]["input_ids"]
labels = sft_dataset[i]["labels"] labels = sft_dataset[i]["labels"]
attention_mask = sft_dataset[i]["attention_mask"].to(torch.bool) attention_mask = sft_dataset[i]["attention_mask"].to(torch.bool)
assert input_ids.shape == labels.shape == \ assert input_ids.shape == labels.shape == attention_mask.shape == torch.Size([max_length])
attention_mask.shape == torch.Size([max_length])
if input_ids.masked_select(attention_mask)[-1] == tokenizer.eos_token_id: if input_ids.masked_select(attention_mask)[-1] == tokenizer.eos_token_id:
check_content(input_ids.masked_select(attention_mask)[:-1], tokenizer, model) check_content(input_ids.masked_select(attention_mask)[:-1], tokenizer, model)
assert torch.all(input_ids.masked_select(torch.logical_not(attention_mask)) == tokenizer.pad_token_id) assert torch.all(input_ids.masked_select(torch.logical_not(attention_mask)) == tokenizer.pad_token_id)
...@@ -254,13 +232,8 @@ def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size: ...@@ -254,13 +232,8 @@ def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size:
if __name__ == "__main__": if __name__ == "__main__":
test_sft_dataset(model="bloom", dataset_path="yizhongw/self_instruct", max_dataset_size=2, max_length=256) test_sft_dataset(model="bloom", dataset_path="yizhongw/self_instruct", max_dataset_size=2, max_length=256)
test_reward_dataset(model="gpt2", test_reward_dataset(
dataset_path="Anthropic/hh-rlhf", model="gpt2", dataset_path="Anthropic/hh-rlhf", subset="harmless-base", max_datasets_size=8, max_length=256
subset="harmless-base", )
max_datasets_size=8,
max_length=256)
test_prompt_dataset(model="opt",
max_datasets_size=2,
max_length=128)
test_prompt_dataset(model="opt", max_datasets_size=2, max_length=128)
...@@ -18,7 +18,7 @@ GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4) ...@@ -18,7 +18,7 @@ GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
def get_data(batch_size: int, seq_len: int = 10) -> dict: def get_data(batch_size: int, seq_len: int = 10) -> dict:
input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda') input_ids = torch.randint(0, 50257, (batch_size, seq_len), device="cuda")
attention_mask = torch.ones_like(input_ids) attention_mask = torch.ones_like(input_ids)
return dict(input_ids=input_ids, attention_mask=attention_mask) return dict(input_ids=input_ids, attention_mask=attention_mask)
...@@ -37,12 +37,12 @@ def make_and_consume_experience(strategy): ...@@ -37,12 +37,12 @@ def make_and_consume_experience(strategy):
EXPERIENCE_BATCH_SIZE = 4 EXPERIENCE_BATCH_SIZE = 4
SAMPLE_BATCH_SIZE = 2 SAMPLE_BATCH_SIZE = 2
if strategy == 'ddp': if strategy == "ddp":
strategy = DDPStrategy() strategy = DDPStrategy()
elif strategy == 'colossalai-zero2': elif strategy == "colossalai-zero2":
strategy = LowLevelZeroStrategy() strategy = LowLevelZeroStrategy()
elif strategy == 'colossalai-gemini': elif strategy == "colossalai-gemini":
strategy = GeminiStrategy(placement_policy='cuda') strategy = GeminiStrategy(placement_policy="cuda")
else: else:
raise ValueError(f'Unsupported strategy "{strategy}"') raise ValueError(f'Unsupported strategy "{strategy}"')
...@@ -58,13 +58,11 @@ def make_and_consume_experience(strategy): ...@@ -58,13 +58,11 @@ def make_and_consume_experience(strategy):
# experience of all ranks should be the same # experience of all ranks should be the same
for _ in range(2): for _ in range(2):
data = get_data(EXPERIENCE_BATCH_SIZE) data = get_data(EXPERIENCE_BATCH_SIZE)
assert gather_and_equal(data['input_ids']) assert gather_and_equal(data["input_ids"])
assert gather_and_equal(data['attention_mask']) assert gather_and_equal(data["attention_mask"])
experience = experience_maker.make_experience(**data, experience = experience_maker.make_experience(
do_sample=True, **data, do_sample=True, max_length=16, eos_token_id=50256, pad_token_id=50256
max_length=16, )
eos_token_id=50256,
pad_token_id=50256)
assert gather_and_equal(experience.sequences) assert gather_and_equal(experience.sequences)
assert gather_and_equal(experience.action_log_probs) assert gather_and_equal(experience.action_log_probs)
assert gather_and_equal(experience.values) assert gather_and_equal(experience.values)
...@@ -75,7 +73,7 @@ def make_and_consume_experience(strategy): ...@@ -75,7 +73,7 @@ def make_and_consume_experience(strategy):
data_buffer.append(experience) data_buffer.append(experience)
# data buffer's data should be the same # data buffer's data should be the same
buffer_size = torch.tensor([len(data_buffer)], device='cuda') buffer_size = torch.tensor([len(data_buffer)], device="cuda")
assert gather_and_equal(buffer_size) assert gather_and_equal(buffer_size)
for item in data_buffer.items: for item in data_buffer.items:
assert gather_and_equal(item.sequences) assert gather_and_equal(item.sequences)
...@@ -88,7 +86,7 @@ def make_and_consume_experience(strategy): ...@@ -88,7 +86,7 @@ def make_and_consume_experience(strategy):
# dataloader of each rank should have the same size and different batch # dataloader of each rank should have the same size and different batch
dataloader = strategy.setup_dataloader(data_buffer) dataloader = strategy.setup_dataloader(data_buffer)
dataloader_size = torch.tensor([len(dataloader)], device='cuda') dataloader_size = torch.tensor([len(dataloader)], device="cuda")
assert gather_and_equal(dataloader_size) assert gather_and_equal(dataloader_size)
for experience in dataloader: for experience in dataloader:
assert not gather_and_equal(experience.sequences) assert not gather_and_equal(experience.sequences)
...@@ -100,21 +98,21 @@ def make_and_consume_experience(strategy): ...@@ -100,21 +98,21 @@ def make_and_consume_experience(strategy):
def run_dist(rank, world_size, port, strategy): def run_dist(rank, world_size, port, strategy):
os.environ['RANK'] = str(rank) os.environ["RANK"] = str(rank)
os.environ['LOCAL_RANK'] = str(rank) os.environ["LOCAL_RANK"] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size) os.environ["WORLD_SIZE"] = str(world_size)
os.environ['MASTER_ADDR'] = 'localhost' os.environ["MASTER_ADDR"] = "localhost"
os.environ['MASTER_PORT'] = str(port) os.environ["MASTER_PORT"] = str(port)
make_and_consume_experience(strategy) make_and_consume_experience(strategy)
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize('world_size', [2]) @pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai-zero2', 'colossalai-gemini']) @pytest.mark.parametrize("strategy", ["ddp", "colossalai-zero2", "colossalai-gemini"])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_experience(world_size, strategy): def test_experience(world_size, strategy):
spawn(run_dist, world_size, strategy=strategy) spawn(run_dist, world_size, strategy=strategy)
if __name__ == '__main__': if __name__ == "__main__":
test_experience(2, 'colossalai') test_experience(2, "colossalai")
...@@ -6,15 +6,16 @@ import torch ...@@ -6,15 +6,16 @@ import torch
import torch.nn as nn import torch.nn as nn
from coati.models.base import Actor, Critic, RewardModel, get_base_model from coati.models.base import Actor, Critic, RewardModel, get_base_model
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from coati.models.chatglm import ChatGLMActor
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from coati.models.generation import generate from coati.models.generation import generate
from coati.models.gpt import GPTRM, GPTActor, GPTCritic from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM from coati.models.llama import LlamaActor
from coati.models.chatglm import ChatGLMActor
from coati.models.lora import LoraLinear, convert_to_lora_module from coati.models.lora import LoraLinear, convert_to_lora_module
from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from coati.models.opt import OPTRM, OPTActor, OPTCritic from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
@pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seq_len", [32]) @pytest.mark.parametrize("seq_len", [32])
...@@ -23,19 +24,24 @@ from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer ...@@ -23,19 +24,24 @@ from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
[ [
lambda: BLOOMActor(), lambda: BLOOMActor(),
lambda: GPTActor(), lambda: GPTActor(),
# HACK: skip llama due to long execution time # HACK: skip llama due to long execution time
# lambda: LlamaActor(), # lambda: LlamaActor(),
lambda: OPTActor(), lambda: OPTActor(),
# lambda: ChatGLMActor(), # lambda: ChatGLMActor(),
]) ],
)
@pytest.mark.parametrize("generate_kwargs", [{ @pytest.mark.parametrize(
"max_length": 64, "generate_kwargs",
"use_cache": True, [
"do_sample": True, {
"temperature": 1.0, "max_length": 64,
"top_k": 50, "use_cache": True,
}]) "do_sample": True,
"temperature": 1.0,
"top_k": 50,
}
],
)
def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]): def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]):
actor = actor_maker() actor = actor_maker()
input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda() input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda()
...@@ -56,7 +62,7 @@ def test_utils(): ...@@ -56,7 +62,7 @@ def test_utils():
"kl_coef": 1.0, "kl_coef": 1.0,
"log_probs": torch.randn((batch_size, num_labels)), "log_probs": torch.randn((batch_size, num_labels)),
"log_probs_base": torch.randn((batch_size, num_labels)), "log_probs_base": torch.randn((batch_size, num_labels)),
"action_mask": torch.randint(0, 2, (batch_size, num_labels)) "action_mask": torch.randint(0, 2, (batch_size, num_labels)),
} }
fn_output = compute_reward(**fn_input) fn_output = compute_reward(**fn_input)
assert fn_output.shape == (batch_size,) assert fn_output.shape == (batch_size,)
...@@ -66,9 +72,7 @@ def test_utils(): ...@@ -66,9 +72,7 @@ def test_utils():
num_labels = 10 num_labels = 10
num_actions = 2 num_actions = 2
fn_input = { fn_input = {
"output": { "output": {"logits": torch.randn((batch_size, seq_len, num_labels))},
"logits": torch.randn((batch_size, seq_len, num_labels))
},
"sequences": torch.randint(0, num_labels, (batch_size, seq_len)), "sequences": torch.randint(0, num_labels, (batch_size, seq_len)),
"num_actions": num_actions, "num_actions": num_actions,
} }
...@@ -105,8 +109,9 @@ def test_lora(lora_rank: int, num_dim: int, num_layers: int): ...@@ -105,8 +109,9 @@ def test_lora(lora_rank: int, num_dim: int, num_layers: int):
assert isinstance(lora_model[i], LoraLinear) assert isinstance(lora_model[i], LoraLinear)
assert torch.allclose(old_model[i].weight, lora_model[i].weight) assert torch.allclose(old_model[i].weight, lora_model[i].weight)
assert torch.allclose(old_model[i].bias, lora_model[i].bias) assert torch.allclose(old_model[i].bias, lora_model[i].bias)
assert not torch.allclose(old_model[i].lora_B @ old_model[i].lora_A, assert not torch.allclose(
lora_model[i].lora_B @ lora_model[i].lora_A) old_model[i].lora_B @ old_model[i].lora_A, lora_model[i].lora_B @ lora_model[i].lora_A
)
@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("batch_size", [8])
...@@ -116,54 +121,60 @@ def test_lora(lora_rank: int, num_dim: int, num_layers: int): ...@@ -116,54 +121,60 @@ def test_lora(lora_rank: int, num_dim: int, num_layers: int):
[ [
lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()),
lambda: (GPTActor(), GPTCritic(), GPTRM()), lambda: (GPTActor(), GPTCritic(), GPTRM()),
# HACK: skip llama due to long execution time # HACK: skip llama due to long execution time
# lambda: (LlamaActor(), LlamaCritic(), LlamaRM()), # lambda: (LlamaActor(), LlamaCritic(), LlamaRM()),
lambda: (OPTActor(), OPTCritic(), OPTRM()), lambda: (OPTActor(), OPTCritic(), OPTRM()),
lambda: (ChatGLMActor(), None, None), lambda: (ChatGLMActor(), None, None),
]) ],
)
@torch.no_grad() @torch.no_grad()
def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], batch_size: int, seq_len: int):
batch_size: int,
seq_len: int):
actor_input = { actor_input = {
"input_ids": torch.randint(0, 100, (batch_size, seq_len)), "input_ids": torch.randint(0, 100, (batch_size, seq_len)),
"attention_mask": torch.randint(0, 2, (batch_size, seq_len)) "attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
} }
critic_input = { critic_input = {
"sequences": torch.randint(0, 100, (batch_size, seq_len)), "sequences": torch.randint(0, 100, (batch_size, seq_len)),
"action_mask": torch.randint(0, 2, (batch_size, seq_len)), "action_mask": torch.randint(0, 2, (batch_size, seq_len)),
"attention_mask": torch.randint(0, 2, (batch_size, seq_len)) "attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
} }
rm_input = { rm_input = {
"sequences": torch.randint(0, 100, (batch_size, seq_len)), "sequences": torch.randint(0, 100, (batch_size, seq_len)),
"attention_mask": torch.randint(0, 2, (batch_size, seq_len)) "attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
} }
actor, critic, rm = models_maker() actor, critic, rm = models_maker()
if isinstance(actor, ChatGLMActor): if isinstance(actor, ChatGLMActor):
actor = actor.float() actor = actor.float()
tokenizer = ChatGLMTokenizer.from_pretrained( "THUDM/chatglm-6b", trust_remote_code=True) tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
chatglm_special_token = torch.tensor([tokenizer.gmask_token_id, tokenizer.bos_token_id]).repeat(batch_size, 1) chatglm_special_token = torch.tensor([tokenizer.gmask_token_id, tokenizer.bos_token_id]).repeat(batch_size, 1)
actor_input ={ actor_input = {
"input_ids": torch.cat((torch.randint(0, 100, (batch_size, seq_len//2)), chatglm_special_token, torch.randint(0, 100, (batch_size, seq_len//2 - 2))), dim=1), "input_ids": torch.cat(
"attention_mask": torch.randint(0, 2, (batch_size, 1, seq_len, seq_len)) (
} torch.randint(0, 100, (batch_size, seq_len // 2)),
chatglm_special_token,
torch.randint(0, 100, (batch_size, seq_len // 2 - 2)),
),
dim=1,
),
"attention_mask": torch.randint(0, 2, (batch_size, 1, seq_len, seq_len)),
}
assert isinstance(actor, Actor) assert isinstance(actor, Actor)
base_actor_model = get_base_model(actor) get_base_model(actor)
actor_output = actor(**actor_input) actor_output = actor(**actor_input)
assert actor_output.logits.shape[:2] == (batch_size, seq_len) assert actor_output.logits.shape[:2] == (batch_size, seq_len)
if critic: if critic:
assert isinstance(critic, Critic) assert isinstance(critic, Critic)
base_critic_model = get_base_model(critic) get_base_model(critic)
critic_output = critic(**critic_input) critic_output = critic(**critic_input)
assert critic_output.shape == (batch_size, ) assert critic_output.shape == (batch_size,)
if rm: if rm:
assert isinstance(rm, RewardModel) assert isinstance(rm, RewardModel)
base_rm_model = get_base_model(rm) get_base_model(rm)
rm_output = rm(**rm_input) rm_output = rm(**rm_input)
assert rm_output.shape == (batch_size, ) assert rm_output.shape == (batch_size,)
@pytest.mark.parametrize("batch_size", [16]) @pytest.mark.parametrize("batch_size", [16])
...@@ -173,39 +184,59 @@ def test_loss(batch_size: int, seq_len: int, num_labels: int): ...@@ -173,39 +184,59 @@ def test_loss(batch_size: int, seq_len: int, num_labels: int):
loss = GPTLMLoss() loss = GPTLMLoss()
loss_input = { loss_input = {
"logits": torch.randn(batch_size, seq_len, num_labels), "logits": torch.randn(batch_size, seq_len, num_labels),
"labels": torch.randint(0, num_labels, (batch_size, seq_len)) "labels": torch.randint(0, num_labels, (batch_size, seq_len)),
} }
loss_output = loss(**loss_input) loss(**loss_input)
loss = PolicyLoss() loss = PolicyLoss()
loss_input = { loss_input = {
"log_probs": torch.randn(batch_size,), "log_probs": torch.randn(
"old_log_probs": torch.randn(batch_size,), batch_size,
"advantages": torch.randn(batch_size,) ),
"old_log_probs": torch.randn(
batch_size,
),
"advantages": torch.randn(
batch_size,
),
} }
loss_output = loss(**loss_input) loss(**loss_input)
loss = ValueLoss() loss = ValueLoss()
loss_input = { loss_input = {
"values": torch.randn(batch_size,), "values": torch.randn(
"old_values": torch.randn(batch_size,), batch_size,
"reward": torch.randn(batch_size,) ),
"old_values": torch.randn(
batch_size,
),
"reward": torch.randn(
batch_size,
),
} }
loss_output = loss(**loss_input) loss(**loss_input)
loss = LogSigLoss() loss = LogSigLoss()
loss_input = { loss_input = {
"chosen_reward": torch.randn(batch_size,), "chosen_reward": torch.randn(
"reject_reward": torch.randn(batch_size,), batch_size,
),
"reject_reward": torch.randn(
batch_size,
),
} }
loss_output = loss(**loss_input) loss(**loss_input)
loss = LogExpLoss() loss = LogExpLoss()
loss_input = { loss_input = {
"chosen_reward": torch.randn(batch_size,), "chosen_reward": torch.randn(
"reject_reward": torch.randn(batch_size,), batch_size,
),
"reject_reward": torch.randn(
batch_size,
),
} }
loss_output = loss(**loss_input) loss(**loss_input)
if __name__ == "__main__": if __name__ == "__main__":
...@@ -218,4 +249,4 @@ if __name__ == "__main__": ...@@ -218,4 +249,4 @@ if __name__ == "__main__":
test_models(models_maker=lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), batch_size=8, seq_len=128) test_models(models_maker=lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), batch_size=8, seq_len=128)
test_loss(batch_size=8, seq_len=128, num_labels=100) test_loss(batch_size=8, seq_len=128, num_labels=100)
\ No newline at end of file
...@@ -6,7 +6,7 @@ try: ...@@ -6,7 +6,7 @@ try:
except ModuleNotFoundError: except ModuleNotFoundError:
# this will only happen if the user did not run `pip install` # this will only happen if the user did not run `pip install`
# and directly set PYTHONPATH to use Colossal-AI which is a bad practice # and directly set PYTHONPATH to use Colossal-AI which is a bad practice
__version__ = '0.0.0' __version__ = "0.0.0"
print('please install Colossal-AI from https://www.colossalai.org/download or from source') print("please install Colossal-AI from https://www.colossalai.org/download or from source")
__all__ = ['launch', 'launch_from_openmpi', 'launch_from_slurm', 'launch_from_torch', '__version__'] __all__ = ["launch", "launch_from_openmpi", "launch_from_slurm", "launch_from_torch", "__version__"]
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
# for more meta_registrations # for more meta_registrations
from typing import Callable, List, Optional, Tuple, Union from typing import List, Optional, Union
import torch import torch
from packaging import version from packaging import version
...@@ -24,25 +24,23 @@ orig_empty_like = torch.empty_like ...@@ -24,25 +24,23 @@ orig_empty_like = torch.empty_like
def new(*args, **kwargs): def new(*args, **kwargs):
return orig_empty(*args, **kwargs, device=torch.device('meta')) return orig_empty(*args, **kwargs, device=torch.device("meta"))
def new_strided(*args, **kwargs): def new_strided(*args, **kwargs):
return orig_empty_strided(*args, **kwargs, device=torch.device('meta')) return orig_empty_strided(*args, **kwargs, device=torch.device("meta"))
def new_like(*args, **kwargs): def new_like(*args, **kwargs):
return orig_empty_like(*args, **kwargs, device=torch.device('meta')) return orig_empty_like(*args, **kwargs, device=torch.device("meta"))
def register_meta(op, register_dispatcher=True): def register_meta(op, register_dispatcher=True):
def wrapper(f): def wrapper(f):
def add_func(op): def add_func(op):
meta_table[op] = f meta_table[op] = f
if register_dispatcher: if register_dispatcher:
name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__) name = op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__
try: try:
meta_lib.impl(name, f) meta_lib.impl(name, f)
except: except:
...@@ -54,7 +52,7 @@ def register_meta(op, register_dispatcher=True): ...@@ -54,7 +52,7 @@ def register_meta(op, register_dispatcher=True):
return wrapper return wrapper
if version.parse(torch.__version__) >= version.parse('1.12.0'): if version.parse(torch.__version__) >= version.parse("1.12.0"):
# ============================== Convolutions ====================================== # ============================== Convolutions ======================================
# https://github.com/pytorch/pytorch/pull/79834 # https://github.com/pytorch/pytorch/pull/79834
@register_meta(aten.convolution.default) @register_meta(aten.convolution.default)
...@@ -69,7 +67,6 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): ...@@ -69,7 +67,6 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
output_padding: List[int], output_padding: List[int],
groups: int, groups: int,
): ):
def _formula(ln: int, p: int, d: int, k: int, s: int) -> int: def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
""" """
Formula to apply to calculate the length of some dimension of the output Formula to apply to calculate the length of some dimension of the output
...@@ -146,7 +143,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): ...@@ -146,7 +143,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
kernel_size[i], kernel_size[i],
stride[i], stride[i],
output_padding_list[i], output_padding_list[i],
)) )
)
else: else:
ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])) ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
return ret_shape return ret_shape
...@@ -180,19 +178,39 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): ...@@ -180,19 +178,39 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation) shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out)) out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
mem_fmt = pick_memory_format() mem_fmt = pick_memory_format()
out = out.to(memory_format=mem_fmt) # type: ignore[call-overload] out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
return out return out
@register_meta(aten._convolution.default) @register_meta(aten._convolution.default)
def meta__conv(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int], def meta__conv(
padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int, input_tensor: torch.Tensor,
*extra_args): weight: torch.Tensor,
bias: torch.Tensor,
stride: List[int],
padding: List[int],
dilation: List[int],
is_transposed: bool,
output_padding: List[int],
groups: int,
*extra_args,
):
out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups) out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
return out return out
@register_meta(aten.convolution_backward.default) @register_meta(aten.convolution_backward.default)
def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride, def meta_conv_backward(
padding, dilation, transposed, output_padding, groups, output_mask): grad_output: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
bias_sizes,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
output_mask,
):
return new_like(input), new_like(weight), new((bias_sizes)) return new_like(input), new_like(weight), new((bias_sizes))
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
...@@ -224,7 +242,6 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): ...@@ -224,7 +242,6 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
batch_sizes, batch_sizes,
dropout_state, dropout_state,
): ):
is_input_packed = len(batch_sizes) != 0 is_input_packed = len(batch_sizes) != 0
if is_input_packed: if is_input_packed:
seq_length = len(batch_sizes) seq_length = len(batch_sizes)
...@@ -240,8 +257,11 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): ...@@ -240,8 +257,11 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
if is_input_packed: if is_input_packed:
out_shape = [batch_sizes_sum, out_size * num_directions] out_shape = [batch_sizes_sum, out_size * num_directions]
else: else:
out_shape = ([mini_batch, seq_length, out_size * out_shape = (
num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions]) [mini_batch, seq_length, out_size * num_directions]
if batch_first
else [seq_length, mini_batch, out_size * num_directions]
)
output = input.new_empty(out_shape) output = input.new_empty(out_shape)
cell_shape = [num_layers * num_directions, mini_batch, hidden_size] cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
...@@ -257,15 +277,21 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): ...@@ -257,15 +277,21 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
@register_meta(aten._cudnn_rnn_backward.default) @register_meta(aten._cudnn_rnn_backward.default)
def meta_cudnn_rnn_backward(input: torch.Tensor, def meta_cudnn_rnn_backward(
weight: torch.Tensor, input: torch.Tensor,
weight_stride0: int, weight: torch.Tensor,
hx: torch.Tensor, weight_stride0: int,
cx: Optional[torch.Tensor] = None, hx: torch.Tensor,
*args, cx: Optional[torch.Tensor] = None,
**kwargs): *args,
return new_like(input), new_like(weight), new_like(hx), new_like(cx) if cx is not None else new( **kwargs,
()) # (grad_input, grad_weight, grad_hx, grad_cx) ):
return (
new_like(input),
new_like(weight),
new_like(hx),
new_like(cx) if cx is not None else new(()),
) # (grad_input, grad_weight, grad_hx, grad_cx)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
# ============================== Activations ======================================= # ============================== Activations =======================================
...@@ -278,7 +304,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): ...@@ -278,7 +304,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten.hardtanh_backward.default, aten.hardtanh_backward.default,
] ]
if version.parse(torch.__version__) < version.parse('2.0.0'): if version.parse(torch.__version__) < version.parse("2.0.0"):
_unregistered_ewise += [ _unregistered_ewise += [
aten.prelu_backward.default, aten.prelu_backward.default,
] ]
...@@ -296,37 +322,61 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): ...@@ -296,37 +322,61 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.native_batch_norm_backward.default) @register_meta(aten.native_batch_norm_backward.default)
def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, def meta_bn_backward(
save_mean, save_invstd, train, eps, output_mask): dY: torch.Tensor,
return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta) input: torch.Tensor,
weight: torch.Tensor,
running_mean,
running_var,
save_mean,
save_invstd,
train,
eps,
output_mask,
):
return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.cudnn_batch_norm.default) @register_meta(aten.cudnn_batch_norm.default)
def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps): def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
n_input = input.size(1) n_input = input.size(1)
return new_like(input), new((n_input)), new((n_input)), new( return (
(0), dtype=torch.uint8) # (output, running_mean, running_var, reserve) new_like(input),
new((n_input)),
new((n_input)),
new((0), dtype=torch.uint8),
) # (output, running_mean, running_var, reserve)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
# NB: CuDNN only implements the backward algorithm for batchnorm # NB: CuDNN only implements the backward algorithm for batchnorm
# in training mode (evaluation mode batchnorm has a different algorithm), # in training mode (evaluation mode batchnorm has a different algorithm),
# which is why this doesn't accept a 'training' parameter. # which is why this doesn't accept a 'training' parameter.
@register_meta(aten.cudnn_batch_norm_backward.default) @register_meta(aten.cudnn_batch_norm_backward.default)
def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, def meta_cudnn_bn_backward(
save_mean, save_invstd, eps, reserve): dY: torch.Tensor,
return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta) input: torch.Tensor,
weight: torch.Tensor,
running_mean,
running_var,
save_mean,
save_invstd,
eps,
reserve,
):
return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm.default) @register_meta(aten.native_layer_norm.default)
def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps): def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
bs, n_input = input.size(0), input.size(1) bs, n_input = input.size(0), input.size(1)
return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1)) # (output, running_mean, running_var) return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1)) # (output, running_mean, running_var)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm_backward.default) @register_meta(aten.native_layer_norm_backward.default)
def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, def meta_ln_backward(
grad_input_mask): dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, grad_input_mask
return new_like(input), new_like(weight), new_like(bias) # (dX, dgamma, dbeta) ):
return new_like(input), new_like(weight), new_like(bias) # (dX, dgamma, dbeta)
# ================================== Misc ========================================== # ================================== Misc ==========================================
# Maybe incorrect # Maybe incorrect
...@@ -355,8 +405,9 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): ...@@ -355,8 +405,9 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
@register_meta(aten.embedding_dense_backward.default) @register_meta(aten.embedding_dense_backward.default)
def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, def meta_embedding_dense_backward(
scale_grad_by_freq): grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, scale_grad_by_freq
):
return new((num_weights, grad_output.size(-1)), dtype=grad_output.dtype, layout=grad_output.layout) return new((num_weights, grad_output.size(-1)), dtype=grad_output.dtype, layout=grad_output.layout)
# ============================== Dropout =========================================== # ============================== Dropout ===========================================
...@@ -364,14 +415,14 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): ...@@ -364,14 +415,14 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
@register_meta(aten.native_dropout.default) @register_meta(aten.native_dropout.default)
def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False): def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
# notice that mask is bool # notice that mask is bool
return new_like(input), new_like(input, dtype=torch.bool) # (output, mask) return new_like(input), new_like(input, dtype=torch.bool) # (output, mask)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
@register_meta(aten.native_dropout_backward.default) @register_meta(aten.native_dropout_backward.default)
def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float): def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
return new_like(grad) # (grad_in) return new_like(grad) # (grad_in)
if version.parse(torch.__version__) < version.parse('1.13.0'): if version.parse(torch.__version__) < version.parse("1.13.0"):
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
@register_meta(aten.eye.m_out) @register_meta(aten.eye.m_out)
def meta_eye(n: int, m: int, out: torch.Tensor): def meta_eye(n: int, m: int, out: torch.Tensor):
...@@ -385,24 +436,28 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): ...@@ -385,24 +436,28 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
result: List[Optional[torch.Tensor]] = [] result: List[Optional[torch.Tensor]] = []
for i, index in enumerate(indices): for i, index in enumerate(indices):
if index is not None: if index is not None:
assert index.dtype in [torch.long, torch.int8, torch.bool],\ assert index.dtype in [
"tensors used as indices must be long, byte or bool tensors" torch.long,
torch.int8,
torch.bool,
], "tensors used as indices must be long, byte or bool tensors"
if index.dtype in [torch.int8, torch.bool]: if index.dtype in [torch.int8, torch.bool]:
nonzero = index.nonzero() nonzero = index.nonzero()
k = len(result) k = len(result)
assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}" assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
for j in range(index.ndim): for j in range(index.ndim):
assert index.shape[j] == self.shape[ assert (
k + index.shape[j] == self.shape[k + j]
j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}" ), f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
result.append(nonzero.select(1, j)) result.append(nonzero.select(1, j))
else: else:
result.append(index) result.append(index)
else: else:
result.append(index) result.append(index)
indices = result indices = result
assert len( assert (
indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})" len(indices) <= self.ndim
), f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
# expand_outplace # expand_outplace
import torch._refs as refs import torch._refs as refs
......
import torch import torch
import torch.distributed as dist
from packaging import version from packaging import version
__all__ = [ __all__ = [
...@@ -48,7 +47,7 @@ _DistCommMethod = [ ...@@ -48,7 +47,7 @@ _DistCommMethod = [
"scatter", "scatter",
] ]
if version.parse(torch.__version__) >= version.parse('1.12.0'): if version.parse(torch.__version__) >= version.parse("1.12.0"):
aten = torch.ops.aten aten = torch.ops.aten
# TODO: dive deep here # TODO: dive deep here
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
......
...@@ -8,7 +8,7 @@ from contextlib import contextmanager ...@@ -8,7 +8,7 @@ from contextlib import contextmanager
from enum import Enum, auto from enum import Enum, auto
from functools import partial, reduce from functools import partial, reduce
from numbers import Number from numbers import Number
from typing import Any, Callable, List, Optional, Union from typing import Any, Callable, List, Union
import torch import torch
from packaging import version from packaging import version
...@@ -36,15 +36,15 @@ def _format_flops(flop): ...@@ -36,15 +36,15 @@ def _format_flops(flop):
B = 1e9 B = 1e9
T = 1e12 T = 1e12
if flop < K: if flop < K:
return f'{flop:.2f}' return f"{flop:.2f}"
elif flop < M: elif flop < M:
return f'{flop / K:.2f}K' return f"{flop / K:.2f}K"
elif flop < B: elif flop < B:
return f'{flop / M:.2f}M' return f"{flop / M:.2f}M"
elif flop < T: elif flop < T:
return f'{flop / B:.2f}B' return f"{flop / B:.2f}B"
else: else:
return f'{flop / T:.2f}T' return f"{flop / T:.2f}T"
def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: bool = False, **kwargs) -> Number: def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: bool = False, **kwargs) -> Number:
...@@ -59,11 +59,13 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: ...@@ -59,11 +59,13 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
Returns: Returns:
Number: The total number of floating point operations (FWD + BWD). Number: The total number of floating point operations (FWD + BWD).
""" """
maybe_inplace = (getattr(module, 'inplace', False) or kwargs.get('inplace', False) maybe_inplace = (
or getattr(module, '__name__', None) in ('add_', 'mul_', 'div_', 'sub_')) getattr(module, "inplace", False)
or kwargs.get("inplace", False)
or getattr(module, "__name__", None) in ("add_", "mul_", "div_", "sub_")
)
class DummyModule(torch.nn.Module): class DummyModule(torch.nn.Module):
def __init__(self, func): def __init__(self, func):
super().__init__() super().__init__()
self.func = func self.func = func
...@@ -74,21 +76,20 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: ...@@ -74,21 +76,20 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
total_flop_count = {Phase.FWD: 0, Phase.BWD: 0} total_flop_count = {Phase.FWD: 0, Phase.BWD: 0}
flop_counts = defaultdict(lambda: defaultdict(int)) flop_counts = defaultdict(lambda: defaultdict(int))
parents = ['Global'] parents = ["Global"]
module = module if isinstance(module, torch.nn.Module) else DummyModule(module) module = module if isinstance(module, torch.nn.Module) else DummyModule(module)
class FlopTensor(MetaTensor): class FlopTensor(MetaTensor):
_tensor: torch.Tensor _tensor: torch.Tensor
def __repr__(self): def __repr__(self):
name = 'FlopParameter' if getattr(self, '_is_param', False) else 'FlopTensor' name = "FlopParameter" if getattr(self, "_is_param", False) else "FlopTensor"
if self.grad_fn: if self.grad_fn:
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})" return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})" return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
@classmethod @classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None): def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
# no_dispatch is only needed if you use enable_python_mode. # no_dispatch is only needed if you use enable_python_mode.
# It prevents infinite recursion. # It prevents infinite recursion.
rs = super().__torch_dispatch__(func, types, args, kwargs) rs = super().__torch_dispatch__(func, types, args, kwargs)
...@@ -115,9 +116,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: ...@@ -115,9 +116,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
return isinstance(x, torch.Tensor) and x.is_floating_point() return isinstance(x, torch.Tensor) and x.is_floating_point()
def create_backwards_push(name): def create_backwards_push(name):
class PushState(torch.autograd.Function): class PushState(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, *args): def forward(ctx, *args):
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args) args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
...@@ -134,9 +133,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: ...@@ -134,9 +133,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
return PushState.apply return PushState.apply
def create_backwards_pop(name): def create_backwards_pop(name):
class PopState(torch.autograd.Function): class PopState(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, *args): def forward(ctx, *args):
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args) args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
...@@ -147,14 +144,13 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: ...@@ -147,14 +144,13 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
@staticmethod @staticmethod
def backward(ctx, *grad_outs): def backward(ctx, *grad_outs):
nonlocal parents nonlocal parents
assert (parents[-1] == name) assert parents[-1] == name
parents.pop() parents.pop()
return grad_outs return grad_outs
return PopState.apply return PopState.apply
def enter_module(name): def enter_module(name):
def f(module, inputs): def f(module, inputs):
nonlocal parents nonlocal parents
parents.append(name) parents.append(name)
...@@ -165,10 +161,9 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: ...@@ -165,10 +161,9 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
return f return f
def exit_module(name): def exit_module(name):
def f(module, inputs, outputs): def f(module, inputs, outputs):
nonlocal parents nonlocal parents
assert (parents[-1] == name) assert parents[-1] == name
parents.pop() parents.pop()
outputs = normalize_tuple(outputs) outputs = normalize_tuple(outputs)
return create_backwards_push(name)(*outputs) return create_backwards_push(name)(*outputs)
...@@ -189,7 +184,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: ...@@ -189,7 +184,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
for mod in flop_counts.keys(): for mod in flop_counts.keys():
print(f"Module: ", mod) print(f"Module: ", mod)
for k, v in flop_counts[mod].items(): for k, v in flop_counts[mod].items():
print('\t', k, _format_flops(v)) print("\t", k, _format_flops(v))
print() print()
def detach_variables(r): def detach_variables(r):
...@@ -201,7 +196,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: ...@@ -201,7 +196,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
def wrap(r): def wrap(r):
if isinstance(r, torch.Tensor): if isinstance(r, torch.Tensor):
data_ptr_fn = getattr(r, '_tensor', r).data_ptr data_ptr_fn = getattr(r, "_tensor", r).data_ptr
r = FlopTensor(detach_variables(r)) r = FlopTensor(detach_variables(r))
if maybe_inplace: if maybe_inplace:
r = r + 0 r = r + 0
...@@ -375,8 +370,11 @@ def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable: ...@@ -375,8 +370,11 @@ def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable:
# Inputs[0] contains the shape of the input. # Inputs[0] contains the shape of the input.
input_shape = inputs[input_arg_index].shape input_shape = inputs[input_arg_index].shape
has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index], has_affine = (
'shape') else inputs[affine_arg_index] inputs[affine_arg_index].shape is not None
if hasattr(inputs[affine_arg_index], "shape")
else inputs[affine_arg_index]
)
assert 2 <= len(input_shape) <= 5, input_shape assert 2 <= len(input_shape) <= 5, input_shape
# 5 is just a rough estimate # 5 is just a rough estimate
flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4) flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4)
...@@ -390,7 +388,7 @@ def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = N ...@@ -390,7 +388,7 @@ def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = N
training = inputs[-3] training = inputs[-3]
assert isinstance(training, bool), "Signature of aten::batch_norm has changed!" assert isinstance(training, bool), "Signature of aten::batch_norm has changed!"
if training: if training:
return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
has_affine = inputs[1].shape is not None has_affine = inputs[1].shape is not None
input_shape = reduce(operator.mul, inputs[0].shape) input_shape = reduce(operator.mul, inputs[0].shape)
return input_shape * (2 if has_affine else 1) return input_shape * (2 if has_affine else 1)
...@@ -420,33 +418,30 @@ def ewise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> Calla ...@@ -420,33 +418,30 @@ def ewise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> Calla
def zero_flop_jit(*args): def zero_flop_jit(*args):
""" """
Count flops for zero flop layers. Count flops for zero flop layers.
""" """
return 0 return 0
if version.parse(torch.__version__) >= version.parse('1.12.0'): if version.parse(torch.__version__) >= version.parse("1.12.0"):
flop_mapping = { flop_mapping = {
# gemm # gemm
aten.mm.default: matmul_flop_jit, aten.mm.default: matmul_flop_jit,
aten.matmul.default: matmul_flop_jit, aten.matmul.default: matmul_flop_jit,
aten.addmm.default: addmm_flop_jit, aten.addmm.default: addmm_flop_jit,
aten.bmm.default: bmm_flop_jit, aten.bmm.default: bmm_flop_jit,
# convolution
# convolution
aten.convolution.default: conv_flop_jit, aten.convolution.default: conv_flop_jit,
aten._convolution.default: conv_flop_jit, aten._convolution.default: conv_flop_jit,
aten.convolution_backward.default: conv_backward_flop_jit, aten.convolution_backward.default: conv_backward_flop_jit,
# normalization
# normalization
aten.native_batch_norm.default: batchnorm_flop_jit, aten.native_batch_norm.default: batchnorm_flop_jit,
aten.native_batch_norm_backward.default: batchnorm_flop_jit, aten.native_batch_norm_backward.default: batchnorm_flop_jit,
aten.cudnn_batch_norm.default: batchnorm_flop_jit, aten.cudnn_batch_norm.default: batchnorm_flop_jit,
aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True), aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
aten.native_layer_norm.default: norm_flop_counter(2, 0), aten.native_layer_norm.default: norm_flop_counter(2, 0),
aten.native_layer_norm_backward.default: norm_flop_counter(2, 0), aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
# pooling
# pooling
aten.avg_pool1d.default: ewise_flop_counter(1, 0), aten.avg_pool1d.default: ewise_flop_counter(1, 0),
aten.avg_pool2d.default: ewise_flop_counter(1, 0), aten.avg_pool2d.default: ewise_flop_counter(1, 0),
aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1), aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1),
...@@ -469,7 +464,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): ...@@ -469,7 +464,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
} }
ewise_flop_aten = [ ewise_flop_aten = [
# basic op # basic op
aten.add.Tensor, aten.add.Tensor,
aten.add_.Tensor, aten.add_.Tensor,
aten.div.Tensor, aten.div.Tensor,
...@@ -485,8 +480,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): ...@@ -485,8 +480,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten.sum.default, aten.sum.default,
aten.sum.dim_IntList, aten.sum.dim_IntList,
aten.mean.dim, aten.mean.dim,
# activation op
# activation op
aten.hardswish.default, aten.hardswish.default,
aten.hardswish_.default, aten.hardswish_.default,
aten.hardswish_backward.default, aten.hardswish_backward.default,
...@@ -509,15 +503,12 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): ...@@ -509,15 +503,12 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten.tanh.default, aten.tanh.default,
aten.tanh_backward.default, aten.tanh_backward.default,
aten.threshold_backward.default, aten.threshold_backward.default,
# dropout
# dropout
aten.native_dropout.default, aten.native_dropout.default,
aten.native_dropout_backward.default, aten.native_dropout_backward.default,
# distribution
# distribution
aten.bernoulli_.float, aten.bernoulli_.float,
# where
# where
aten.where.self, aten.where.self,
] ]
for op in ewise_flop_aten: for op in ewise_flop_aten:
......
...@@ -3,12 +3,12 @@ from functools import partial ...@@ -3,12 +3,12 @@ from functools import partial
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.types import _bool, _device, _dtype from torch.types import _device
from torch.utils._pytree import tree_flatten, tree_map from torch.utils._pytree import tree_map
from ._monkey_patch import _AliasATen, _DistCommMethod, _InplaceATen, _MaybeInplaceATen, _TorchOverrideableFactoryMethod from ._monkey_patch import _AliasATen, _DistCommMethod, _InplaceATen, _MaybeInplaceATen, _TorchOverrideableFactoryMethod
__all__ = ['MetaTensor', 'MetaTensorMode'] __all__ = ["MetaTensor", "MetaTensorMode"]
def register_storage(r, data_ptr_fn=None): def register_storage(r, data_ptr_fn=None):
...@@ -28,8 +28,7 @@ def _normalize_tuple(x): ...@@ -28,8 +28,7 @@ def _normalize_tuple(x):
# a hack of inplace execution in PyTorch # a hack of inplace execution in PyTorch
def _assert_alias(func): def _assert_alias(func):
return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen # TODO: check if should be this aggressive return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen) # TODO: check if should be this aggressive
)
class MetaTensor(torch.Tensor): class MetaTensor(torch.Tensor):
...@@ -65,14 +64,15 @@ class MetaTensor(torch.Tensor): ...@@ -65,14 +64,15 @@ class MetaTensor(torch.Tensor):
storage_offset=elem.storage_offset(), storage_offset=elem.storage_offset(),
dtype=elem.dtype, dtype=elem.dtype,
layout=elem.layout, layout=elem.layout,
device=device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')), device=device or (elem.device if elem.device.type != "meta" else torch.device("cpu")),
requires_grad=requires_grad) # deceive the frontend for aten selections requires_grad=requires_grad,
) # deceive the frontend for aten selections
r._tensor = elem r._tensor = elem
# ...the real tensor is held as an element on the tensor. # ...the real tensor is held as an element on the tensor.
if not r._tensor.is_meta: if not r._tensor.is_meta:
val = elem.data_ptr() val = elem.data_ptr()
data_ptr_fn = lambda: val data_ptr_fn = lambda: val
r._tensor = r._tensor.to(torch.device('meta')) r._tensor = r._tensor.to(torch.device("meta"))
# only tensor not on `meta` should be copied to `meta` # only tensor not on `meta` should be copied to `meta`
register_storage(r._tensor, data_ptr_fn) register_storage(r._tensor, data_ptr_fn)
...@@ -81,7 +81,7 @@ class MetaTensor(torch.Tensor): ...@@ -81,7 +81,7 @@ class MetaTensor(torch.Tensor):
return r return r
def __repr__(self): def __repr__(self):
name = 'MetaParameter' if getattr(self, '_is_param', False) else 'MetaTensor' name = "MetaParameter" if getattr(self, "_is_param", False) else "MetaTensor"
if self.grad_fn: if self.grad_fn:
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})" return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})" return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
...@@ -97,15 +97,15 @@ class MetaTensor(torch.Tensor): ...@@ -97,15 +97,15 @@ class MetaTensor(torch.Tensor):
x = x._tensor x = x._tensor
elif isinstance(x, torch.Tensor): elif isinstance(x, torch.Tensor):
device = x.device device = x.device
x = x.to(torch.device('meta')) x = x.to(torch.device("meta"))
return x return x
args = tree_map(unwrap, args) args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs) kwargs = tree_map(unwrap, kwargs)
if 'device' in kwargs: if "device" in kwargs:
device = kwargs['device'] device = kwargs["device"]
kwargs['device'] = torch.device('meta') kwargs["device"] = torch.device("meta")
# run aten for backend=CPU but actually on backend=Meta # run aten for backend=CPU but actually on backend=Meta
# here we detect whether or not the execution generates a physical copy # here we detect whether or not the execution generates a physical copy
...@@ -143,21 +143,21 @@ class MetaTensor(torch.Tensor): ...@@ -143,21 +143,21 @@ class MetaTensor(torch.Tensor):
nonlocal device nonlocal device
if isinstance(x, str) or isinstance(x, _device): if isinstance(x, str) or isinstance(x, _device):
device = x device = x
return torch.device('meta') return torch.device("meta")
return x return x
elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs)) elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
return MetaTensor(elem, device=device) return MetaTensor(elem, device=device)
def cpu(self, *args, **kwargs): def cpu(self, *args, **kwargs):
if self.device.type == 'cpu': if self.device.type == "cpu":
return self.to(*args, **kwargs) return self.to(*args, **kwargs)
return self.to(*args, device='cpu', **kwargs) return self.to(*args, device="cpu", **kwargs)
def cuda(self, device=None, non_blocking=False): def cuda(self, device=None, non_blocking=False):
if device is not None: if device is not None:
return self.to(device=device, non_blocking=non_blocking) return self.to(device=device, non_blocking=non_blocking)
return self.to(device='cuda:0', non_blocking=non_blocking) return self.to(device="cuda:0", non_blocking=non_blocking)
def data_ptr(self): def data_ptr(self):
return self._tensor.data_ptr() return self._tensor.data_ptr()
...@@ -177,19 +177,17 @@ class MetaTensorMode(object): ...@@ -177,19 +177,17 @@ class MetaTensorMode(object):
""" """
def __init__(self): def __init__(self):
self.torch_overrides = {} # override torch.xxx self.torch_overrides = {} # override torch.xxx
self.dist_overrides = {} # override torch.distributed.xxx self.dist_overrides = {} # override torch.distributed.xxx
def __enter__(self): def __enter__(self):
def _dummy(*args, **kwargs): def _dummy(*args, **kwargs):
pass pass
def _new(*args, orig_new=torch.empty, **kwargs): def _new(*args, orig_new=torch.empty, **kwargs):
return MetaTensor(orig_new(*args, **{ return MetaTensor(
**kwargs, 'device': 'meta' orig_new(*args, **{**kwargs, "device": "meta"}), device=kwargs.get("device", torch.device("cpu"))
}), )
device=kwargs.get('device', torch.device('cpu')))
for func in _TorchOverrideableFactoryMethod: for func in _TorchOverrideableFactoryMethod:
self.torch_overrides[func] = getattr(torch, func) self.torch_overrides[func] = getattr(torch, func)
......
from typing import Any, Callable, Dict, Iterable, List, Tuple from typing import Any, Dict, List, Tuple
import torch import torch
...@@ -22,7 +22,7 @@ from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_a ...@@ -22,7 +22,7 @@ from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_a
import colossalai import colossalai
from colossalai.fx._compatibility import compatibility from colossalai.fx._compatibility import compatibility
_register_custom_builtin('colossalai', 'import colossalai', colossalai) _register_custom_builtin("colossalai", "import colossalai", colossalai)
def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str: def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
...@@ -43,17 +43,17 @@ def _gen_ckpt_usage(label, input_vars, output_vars, use_reentrant=True): ...@@ -43,17 +43,17 @@ def _gen_ckpt_usage(label, input_vars, output_vars, use_reentrant=True):
""" """
Generate the checkpoint function call code text Generate the checkpoint function call code text
""" """
outputs = ', '.join(output_vars) outputs = ", ".join(output_vars)
inputs = ', '.join(input_vars) inputs = ", ".join(input_vars)
return f'{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})' return f"{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})"
def _end_of_ckpt(node: Node, ckpt_level: int) -> bool: def _end_of_ckpt(node: Node, ckpt_level: int) -> bool:
""" """
Check if the node could end the ckpt region at `ckpt_level` Check if the node could end the ckpt region at `ckpt_level`
""" """
if len(node.meta['info'].activation_checkpoint) > ckpt_level: if len(node.meta["info"].activation_checkpoint) > ckpt_level:
return node.meta['info'].activation_checkpoint[ckpt_level] is not None return node.meta["info"].activation_checkpoint[ckpt_level] is not None
return True return True
...@@ -94,8 +94,8 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0): ...@@ -94,8 +94,8 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
current_region = None current_region = None
for idx, node in enumerate(node_list): for idx, node in enumerate(node_list):
if len(node.meta['info'].activation_checkpoint) > ckpt_level: if len(node.meta["info"].activation_checkpoint) > ckpt_level:
act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level] act_ckpt_label = node.meta["info"].activation_checkpoint[ckpt_level]
# this activation checkpoint label is not set yet # this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region # meaning this is the first node of the activation ckpt region
...@@ -131,13 +131,9 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0): ...@@ -131,13 +131,9 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
return ckpt_regions return ckpt_regions
def emit_ckpt_func(body, def emit_ckpt_func(
ckpt_func, body, ckpt_func, node_list: List[Node], emit_node_func, delete_unused_value_func, ckpt_level=0, in_ckpt=False
node_list: List[Node], ):
emit_node_func,
delete_unused_value_func,
ckpt_level=0,
in_ckpt=False):
"""Emit ckpt function in nested way """Emit ckpt function in nested way
Args: Args:
...@@ -156,12 +152,12 @@ def emit_ckpt_func(body, ...@@ -156,12 +152,12 @@ def emit_ckpt_func(body,
# label given by each layer, e.g. if you are currently at level (0, 1, 1) # label given by each layer, e.g. if you are currently at level (0, 1, 1)
# the label will be '0_1_1' # the label will be '0_1_1'
label = "_".join([str(idx) for idx in node_list[0].meta['info'].activation_checkpoint[:ckpt_level + 1]]) label = "_".join([str(idx) for idx in node_list[0].meta["info"].activation_checkpoint[: ckpt_level + 1]])
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
ckpt_func.append(f'{ckpt_fn_def}\n') ckpt_func.append(f"{ckpt_fn_def}\n")
# if there is more level to fetch # if there is more level to fetch
if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].activation_checkpoint), node_list)): if ckpt_level + 1 < max(map(lambda node: len(node.meta["info"].activation_checkpoint), node_list)):
ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1) ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1)
start_idx = [item[0] for item in ckpt_regions] start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions] end_idx = [item[1] for item in ckpt_regions]
...@@ -174,33 +170,40 @@ def emit_ckpt_func(body, ...@@ -174,33 +170,40 @@ def emit_ckpt_func(body,
break break
if node_idx in start_idx: if node_idx in start_idx:
ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, delete_unused_value_func, emit_ckpt_func(
ckpt_level + 1, True) ckpt_func,
ckpt_func_buffer,
ckpt_node_list,
emit_node_func,
delete_unused_value_func,
ckpt_level + 1,
True,
)
node_idx += len(ckpt_node_list) node_idx += len(ckpt_node_list)
else: else:
node = node_list[node_idx] node = node_list[node_idx]
emit_node_func(node, ckpt_func) emit_node_func(node, ckpt_func)
ckpt_func[-1] = ' ' + ckpt_func[-1] ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func) delete_unused_value_func(node, ckpt_func)
node_idx += 1 node_idx += 1
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
ckpt_func += ckpt_func_buffer ckpt_func += ckpt_func_buffer
# last level # last level
else: else:
for node in node_list: for node in node_list:
emit_node_func(node, ckpt_func) emit_node_func(node, ckpt_func)
ckpt_func[-1] = ' ' + ckpt_func[-1] ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func) delete_unused_value_func(node, ckpt_func)
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
usage = _gen_ckpt_usage(label, inputs, outputs, False) + '\n' usage = _gen_ckpt_usage(label, inputs, outputs, False) + "\n"
if in_ckpt: if in_ckpt:
usage = ' ' + usage usage = " " + usage
body.append(usage) body.append(usage)
...@@ -229,7 +232,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, ...@@ -229,7 +232,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# process ckpt_regions # process ckpt_regions
if node_idx in start_idx: if node_idx in start_idx:
ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func) emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
node_idx += len(ckpt_node_list) node_idx += len(ckpt_node_list)
...@@ -243,7 +246,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, ...@@ -243,7 +246,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
class ActivationCheckpointCodeGen(CodeGen): class ActivationCheckpointCodeGen(CodeGen):
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
free_vars: List[str] = [] free_vars: List[str] = []
body: List[str] = [] body: List[str] = []
...@@ -251,7 +253,7 @@ class ActivationCheckpointCodeGen(CodeGen): ...@@ -251,7 +253,7 @@ class ActivationCheckpointCodeGen(CodeGen):
wrapped_fns: Dict[str, None] = {} wrapped_fns: Dict[str, None] = {}
# Wrap string in list to pass by reference # Wrap string in list to pass by reference
maybe_return_annotation: List[str] = [''] maybe_return_annotation: List[str] = [""]
def add_global(name_hint: str, obj: Any): def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global. """Add an obj to be tracked as a global.
...@@ -259,7 +261,7 @@ class ActivationCheckpointCodeGen(CodeGen): ...@@ -259,7 +261,7 @@ class ActivationCheckpointCodeGen(CodeGen):
Graph, like functions or types. Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source. Returns: the global name that should be used to reference 'obj' in generated source.
""" """
if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We # HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their # can't import them like normal modules so they must retain their
# fully qualified name. # fully qualified name.
...@@ -281,16 +283,16 @@ class ActivationCheckpointCodeGen(CodeGen): ...@@ -281,16 +283,16 @@ class ActivationCheckpointCodeGen(CodeGen):
def type_repr(o: Any): def type_repr(o: Any):
if o == (): if o == ():
# Empty tuple is used for empty tuple type annotation Tuple[()] # Empty tuple is used for empty tuple type annotation Tuple[()]
return '()' return "()"
typename = _type_repr(o) typename = _type_repr(o)
if hasattr(o, '__origin__'): if hasattr(o, "__origin__"):
# This is a generic type, e.g. typing.List[torch.Tensor] # This is a generic type, e.g. typing.List[torch.Tensor]
origin_type = _origin_type_map.get(o.__origin__, o.__origin__) origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
origin_typename = add_global(_type_repr(origin_type), origin_type) origin_typename = add_global(_type_repr(origin_type), origin_type)
if hasattr(o, '__args__'): if hasattr(o, "__args__"):
# Assign global names for each of the inner type variables. # Assign global names for each of the inner type variables.
args = [type_repr(arg) for arg in o.__args__] args = [type_repr(arg) for arg in o.__args__]
...@@ -309,19 +311,18 @@ class ActivationCheckpointCodeGen(CodeGen): ...@@ -309,19 +311,18 @@ class ActivationCheckpointCodeGen(CodeGen):
return add_global(typename, o) return add_global(typename, o)
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
def _get_repr(arg): def _get_repr(arg):
# Handle NamedTuples (if it has `_fields`) via add_global. # Handle NamedTuples (if it has `_fields`) via add_global.
if isinstance(arg, tuple) and hasattr(arg, '_fields'): if isinstance(arg, tuple) and hasattr(arg, "_fields"):
qualified_name = _get_qualified_name(type(arg)) qualified_name = _get_qualified_name(type(arg))
global_name = add_global(qualified_name, type(arg)) global_name = add_global(qualified_name, type(arg))
return f"{global_name}{repr(tuple(arg))}" return f"{global_name}{repr(tuple(arg))}"
return repr(arg) return repr(arg)
args_s = ', '.join(_get_repr(a) for a in args) args_s = ", ".join(_get_repr(a) for a in args)
kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items()) kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
if args_s and kwargs_s: if args_s and kwargs_s:
return f'{args_s}, {kwargs_s}' return f"{args_s}, {kwargs_s}"
return args_s or kwargs_s return args_s or kwargs_s
# Run through reverse nodes and record the first instance of a use # Run through reverse nodes and record the first instance of a use
...@@ -347,82 +348,94 @@ class ActivationCheckpointCodeGen(CodeGen): ...@@ -347,82 +348,94 @@ class ActivationCheckpointCodeGen(CodeGen):
not used in the remainder of the code are freed and the memory usage not used in the remainder of the code are freed and the memory usage
of the code is optimal. of the code is optimal.
""" """
if user.op == 'placeholder': if user.op == "placeholder":
return return
if user.op == 'output': if user.op == "output":
body.append('\n') body.append("\n")
return return
nodes_to_delete = user_to_last_uses.get(user, []) nodes_to_delete = user_to_last_uses.get(user, [])
if len(nodes_to_delete): if len(nodes_to_delete):
to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
body.append(f'; {to_delete_str}\n') body.append(f"; {to_delete_str}\n")
else: else:
body.append('\n') body.append("\n")
# NOTE: we add a variable to distinguish body and ckpt_func # NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body): def emit_node(node: Node, body):
maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
if node.op == 'placeholder': if node.op == "placeholder":
assert isinstance(node.target, str) assert isinstance(node.target, str)
maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}' maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
raw_name = node.target.replace('*', '') raw_name = node.target.replace("*", "")
if raw_name != repr(node): if raw_name != repr(node):
body.append(f'{repr(node)} = {raw_name}\n') body.append(f"{repr(node)} = {raw_name}\n")
return return
elif node.op == 'call_method': elif node.op == "call_method":
assert isinstance(node.target, str) assert isinstance(node.target, str)
body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}' body.append(
f'({_format_args(node.args[1:], node.kwargs)})') f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
f"({_format_args(node.args[1:], node.kwargs)})"
)
return return
elif node.op == 'call_function': elif node.op == "call_function":
assert callable(node.target) assert callable(node.target)
# pretty print operators # pretty print operators
if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods: if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
assert isinstance(node.args, tuple) assert isinstance(node.args, tuple)
body.append(f'{repr(node)}{maybe_type_annotation} = ' body.append(
f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}') f"{repr(node)}{maybe_type_annotation} = "
f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
)
return return
# pretty print inplace operators; required for jit.script to work properly # pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo # not currently supported in normal FX graphs, but generated by torchdynamo
if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods: if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods:
body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; ' body.append(
f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}') f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
)
return return
qualified_name = _get_qualified_name(node.target) qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target) global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument # special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
if global_name == 'getattr' and \ if (
isinstance(node.args, tuple) and \ global_name == "getattr"
isinstance(node.args[1], str) and \ and isinstance(node.args, tuple)
node.args[1].isidentifier() and \ and isinstance(node.args[1], str)
len(node.args) == 2: and node.args[1].isidentifier()
and len(node.args) == 2
):
body.append( body.append(
f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}') f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
)
return return
body.append( body.append(
f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
if node.meta.get('is_wrapped', False): )
if node.meta.get("is_wrapped", False):
wrapped_fns.setdefault(global_name) wrapped_fns.setdefault(global_name)
return return
elif node.op == 'call_module': elif node.op == "call_module":
assert isinstance(node.target, str) assert isinstance(node.target, str)
body.append(f'{repr(node)}{maybe_type_annotation} = ' body.append(
f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') f"{repr(node)}{maybe_type_annotation} = "
f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
)
return return
elif node.op == 'get_attr': elif node.op == "get_attr":
assert isinstance(node.target, str) assert isinstance(node.target, str)
body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
return return
elif node.op == 'output': elif node.op == "output":
if node.type is not None: if node.type is not None:
maybe_return_annotation[0] = f" -> {type_repr(node.type)}" maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
body.append(self.generate_output(node.args[0])) body.append(self.generate_output(node.args[0]))
return return
raise NotImplementedError(f'node: {node.op} {node.target}') raise NotImplementedError(f"node: {node.op} {node.target}")
# Modified for activation checkpointing # Modified for activation checkpointing
ckpt_func = [] ckpt_func = []
...@@ -432,13 +445,13 @@ class ActivationCheckpointCodeGen(CodeGen): ...@@ -432,13 +445,13 @@ class ActivationCheckpointCodeGen(CodeGen):
# If the Graph has no non-placeholder nodes, no lines for the body # If the Graph has no non-placeholder nodes, no lines for the body
# have been emitted. To continue to have valid Python code, emit a # have been emitted. To continue to have valid Python code, emit a
# single pass statement # single pass statement
body.append('pass\n') body.append("pass\n")
if len(wrapped_fns) > 0: if len(wrapped_fns) > 0:
wrap_name = add_global('wrap', torch.fx.wrap) wrap_name = add_global("wrap", torch.fx.wrap)
wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
else: else:
wrap_stmts = '' wrap_stmts = ""
if self._body_transformer: if self._body_transformer:
body = self._body_transformer(body) body = self._body_transformer(body)
...@@ -447,11 +460,11 @@ class ActivationCheckpointCodeGen(CodeGen): ...@@ -447,11 +460,11 @@ class ActivationCheckpointCodeGen(CodeGen):
add_global(name, value) add_global(name, value)
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
prologue = ''.join(ckpt_func) + prologue prologue = "".join(ckpt_func) + prologue
prologue = prologue prologue = prologue
code = ''.join(body) code = "".join(body)
code = '\n'.join(' ' + line for line in code.split('\n')) code = "\n".join(" " + line for line in code.split("\n"))
fn_code = f""" fn_code = f"""
{wrap_stmts} {wrap_stmts}
{prologue} {prologue}
......
...@@ -13,6 +13,7 @@ from torch.fx.graph import PythonCode ...@@ -13,6 +13,7 @@ from torch.fx.graph import PythonCode
try: try:
from torch.fx.graph import _PyTreeCodeGen from torch.fx.graph import _PyTreeCodeGen
SUPPORT_PT_CODEGEN = True SUPPORT_PT_CODEGEN = True
except ImportError: except ImportError:
SUPPORT_PT_CODEGEN = False SUPPORT_PT_CODEGEN = False
...@@ -24,7 +25,6 @@ from torch.nn.modules.module import _addindent ...@@ -24,7 +25,6 @@ from torch.nn.modules.module import _addindent
# This is a copy of torch.fx.graph_module._WrappedCall. # This is a copy of torch.fx.graph_module._WrappedCall.
# It should be removed when we stop supporting torch < 1.12.0. # It should be removed when we stop supporting torch < 1.12.0.
class _WrappedCall: class _WrappedCall:
def __init__(self, cls, cls_call): def __init__(self, cls, cls_call):
self.cls = cls self.cls = cls
self.cls_call = cls_call self.cls_call = cls_call
...@@ -50,12 +50,14 @@ class _WrappedCall: ...@@ -50,12 +50,14 @@ class _WrappedCall:
# constituent substrings of the error message # constituent substrings of the error message
tb_repr = traceback.format_exc() tb_repr = traceback.format_exc()
custom_msg = ("Call using an FX-traced Module, " custom_msg = (
f"line {err_lineno} of the traced Module's " "Call using an FX-traced Module, "
"generated forward function:") f"line {err_lineno} of the traced Module's "
before_err = "".join(all_src_lines[err_lineno - 2:err_lineno]) "generated forward function:"
)
before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno])
marker = "~" * err_line_len + "~~~ <--- HERE" marker = "~" * err_line_len + "~~~ <--- HERE"
err_and_after_err = "\n".join(all_src_lines[err_lineno:err_lineno + 2]) err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2])
# joined message # joined message
return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err]) return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
...@@ -65,11 +67,14 @@ class _WrappedCall: ...@@ -65,11 +67,14 @@ class _WrappedCall:
if self.cls_call is not None: if self.cls_call is not None:
return self.cls_call(obj, *args, **kwargs) return self.cls_call(obj, *args, **kwargs)
else: else:
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
except Exception as e: except Exception as e:
assert e.__traceback__ assert e.__traceback__
topmost_framesummary: traceback.FrameSummary = \ topmost_framesummary: traceback.FrameSummary = traceback.StackSummary.extract(
traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] # type: ignore[arg-type] traceback.walk_tb(e.__traceback__)
)[
-1
] # type: ignore[arg-type]
if "eval_with_key" in topmost_framesummary.filename: if "eval_with_key" in topmost_framesummary.filename:
print(_WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr) print(_WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr)
raise e.with_traceback(None) raise e.with_traceback(None)
...@@ -99,10 +104,9 @@ class ColoGraphModule(torch.fx.GraphModule): ...@@ -99,10 +104,9 @@ class ColoGraphModule(torch.fx.GraphModule):
code. code.
""" """
def __init__(self, def __init__(
root: Union[torch.nn.Module, Dict[str, Any]], self, root: Union[torch.nn.Module, Dict[str, Any]], graph: torch.fx.Graph, class_name: str = "GraphModule"
graph: torch.fx.Graph, ):
class_name: str = 'GraphModule'):
super().__init__(root, graph, class_name) super().__init__(root, graph, class_name)
def bind(self, ckpt_def, globals): def bind(self, ckpt_def, globals):
...@@ -134,7 +138,7 @@ class ColoGraphModule(torch.fx.GraphModule): ...@@ -134,7 +138,7 @@ class ColoGraphModule(torch.fx.GraphModule):
if SUPPORT_PT_CODEGEN and isinstance(self._graph._codegen, _PyTreeCodeGen): if SUPPORT_PT_CODEGEN and isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec self._out_spec = self._graph._codegen.pytree_info.out_spec
python_code = self._graph.python_code(root_module='self') python_code = self._graph.python_code(root_module="self")
self._code = python_code.src self._code = python_code.src
# To split ckpt functions code and forward code # To split ckpt functions code and forward code
...@@ -157,8 +161,8 @@ class ColoGraphModule(torch.fx.GraphModule): ...@@ -157,8 +161,8 @@ class ColoGraphModule(torch.fx.GraphModule):
# bypass patching of torch.nn.Module.__call__ done while symbolic tracing. # bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
cls_call = cls.__call__ if "__call__" in vars(cls) else None cls_call = cls.__call__ if "__call__" in vars(cls) else None
if '_wrapped_call' not in vars(cls): if "_wrapped_call" not in vars(cls):
cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
def call_wrapped(self, *args, **kwargs): def call_wrapped(self, *args, **kwargs):
return self._wrapped_call(self, *args, **kwargs) return self._wrapped_call(self, *args, **kwargs)
...@@ -182,7 +186,7 @@ class ColoGraphModule(torch.fx.GraphModule): ...@@ -182,7 +186,7 @@ class ColoGraphModule(torch.fx.GraphModule):
""" """
folder = Path(folder) folder = Path(folder)
Path(folder).mkdir(exist_ok=True) Path(folder).mkdir(exist_ok=True)
torch.save(self.state_dict(), folder / 'state_dict.pt') torch.save(self.state_dict(), folder / "state_dict.pt")
tab = " " * 4 tab = " " * 4
# we add import colossalai here # we add import colossalai here
...@@ -208,10 +212,10 @@ class {module_name}(torch.nn.Module): ...@@ -208,10 +212,10 @@ class {module_name}(torch.nn.Module):
for module_name, module in self.named_children(): for module_name, module in self.named_children():
module_str = _gen_model_repr(module_name, module) module_str = _gen_model_repr(module_name, module)
if module_str is None: if module_str is None:
module_file = folder / f'{module_name}.pt' module_file = folder / f"{module_name}.pt"
torch.save(module, module_file) torch.save(module, module_file)
blobified_modules.append(module_name) blobified_modules.append(module_name)
module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ') module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
module_str = f"torch.load(r'{module_file}') # {module_repr}" module_str = f"torch.load(r'{module_file}') # {module_repr}"
model_str += f"{tab*2}self.{module_name} = {module_str}\n" model_str += f"{tab*2}self.{module_name} = {module_str}\n"
...@@ -228,12 +232,14 @@ class {module_name}(torch.nn.Module): ...@@ -228,12 +232,14 @@ class {module_name}(torch.nn.Module):
model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
model_str += f"{_addindent(self.code, 4)}\n" model_str += f"{_addindent(self.code, 4)}\n"
module_file = folder / 'module.py' module_file = folder / "module.py"
module_file.write_text(model_str) module_file.write_text(model_str)
init_file = folder / '__init__.py' init_file = folder / "__init__.py"
init_file.write_text('from .module import *') init_file.write_text("from .module import *")
if len(blobified_modules) > 0: if len(blobified_modules) > 0:
warnings.warn("Was not able to save the following children modules as reprs -" warnings.warn(
f"saved as pickled files instead: {blobified_modules}") "Was not able to save the following children modules as reprs -"
f"saved as pickled files instead: {blobified_modules}"
)
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import torch import torch
from torch.autograd.profiler_util import _format_memory, _format_time from torch.autograd.profiler_util import _format_memory
from torch.fx import Graph, GraphModule, Node from torch.fx import Node
from colossalai._analyzer.envs import MeshConfig from colossalai._analyzer.envs import MeshConfig
...@@ -85,12 +85,12 @@ class MetaInfo: ...@@ -85,12 +85,12 @@ class MetaInfo:
node: Node node: Node
# directory # directory
mod_dir: str = '' mod_dir: str = ""
# ctx[data_ptr] = Tensor # ctx[data_ptr] = Tensor
# mark the storage for ctx.save_for_backward # mark the storage for ctx.save_for_backward
global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # globally shared global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # globally shared
curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # global_ctx till this node curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # global_ctx till this node
# should be updated after each graph manipulation # should be updated after each graph manipulation
# ============================== Update ==================================== # ============================== Update ====================================
...@@ -100,7 +100,7 @@ class MetaInfo: ...@@ -100,7 +100,7 @@ class MetaInfo:
inputs: Tuple[torch.Tensor] = () inputs: Tuple[torch.Tensor] = ()
outputs: Tuple[torch.Tensor] = () outputs: Tuple[torch.Tensor] = ()
is_alias: Tuple[bool] = () # whether the output is an alias of input is_alias: Tuple[bool] = () # whether the output is an alias of input
# compute cost # compute cost
fwd_flop: Optional[int] = 0 fwd_flop: Optional[int] = 0
...@@ -112,29 +112,29 @@ class MetaInfo: ...@@ -112,29 +112,29 @@ class MetaInfo:
# should keep the same whenever manipulated # should keep the same whenever manipulated
# ============================= Invariant ================================== # ============================= Invariant ==================================
activation_checkpoint: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen activation_checkpoint: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen
to_offload: Optional[bool] = False to_offload: Optional[bool] = False
sharding_spec: str = 'RR' sharding_spec: str = "RR"
def __new__(cls, node: Node, **kwargs): def __new__(cls, node: Node, **kwargs):
orig_init = cls.__init__ orig_init = cls.__init__
# if initialized, return the existing one # if initialized, return the existing one
# should disable the __init__ function # should disable the __init__ function
if node.meta.get('info', None) is not None: if node.meta.get("info", None) is not None:
def _dummy(self, *args, **kwargs): def _dummy(self, *args, **kwargs):
if getattr(self, '_is_init', False): if getattr(self, "_is_init", False):
self._is_init = True self._is_init = True
orig_init(self, *args, **kwargs) orig_init(self, *args, **kwargs)
cls.__init__ = orig_init cls.__init__ = orig_init
cls.__init__ = _dummy cls.__init__ = _dummy
return node.meta['info'] return node.meta["info"]
return super().__new__(cls) return super().__new__(cls)
def __post_init__(self): def __post_init__(self):
self.node.meta['info'] = self self.node.meta["info"] = self
@property @property
def fwd_time(self, tflops: float = MeshConfig.TFLOPS, bandwidth: float = MeshConfig.BANDWIDTH): def fwd_time(self, tflops: float = MeshConfig.TFLOPS, bandwidth: float = MeshConfig.BANDWIDTH):
...@@ -188,24 +188,26 @@ class MetaInfo: ...@@ -188,24 +188,26 @@ class MetaInfo:
return compute_size_in_bytes(self.inputs) return compute_size_in_bytes(self.inputs)
def __repr__(self): def __repr__(self):
s = f'Node {self.node.name}' s = f"Node {self.node.name}"
if self.parameters: if self.parameters:
s += f'\n\thas parameter of size {_format_memory(self.param_size)}' s += f"\n\thas parameter of size {_format_memory(self.param_size)}"
if self.buffers: if self.buffers:
s += f'\n\thas buffer of size {_format_memory(self.buffer_size)}' s += f"\n\thas buffer of size {_format_memory(self.buffer_size)}"
if self.output_size: if self.output_size:
s += f'\n\thas output activation of size {_format_memory(self.output_size)}' s += f"\n\thas output activation of size {_format_memory(self.output_size)}"
# if self.total_size: # if self.total_size:
# s += f'\n\thas total activation of size {_format_memory(self.total_size)}' # s += f'\n\thas total activation of size {_format_memory(self.total_size)}'
if self.temp_size: if self.temp_size:
s += f'\n\thas temp activation of size {_format_memory(self.temp_size)}' s += f"\n\thas temp activation of size {_format_memory(self.temp_size)}"
if self.backward_size: if self.backward_size:
s += f'\n\thas backward activation of size {_format_memory(self.backward_size)}' s += f"\n\thas backward activation of size {_format_memory(self.backward_size)}"
s += f'\n\tfwd_flop = {self.fwd_flop}'\ s += (
f'\n\tbwd_flop = {self.bwd_flop}'\ f"\n\tfwd_flop = {self.fwd_flop}"
f'\n\tfwd_comm = {self.fwd_comm}'\ f"\n\tbwd_flop = {self.bwd_flop}"
f'\n\tbwd_comm = {self.bwd_comm}'\ f"\n\tfwd_comm = {self.fwd_comm}"
f'\n\tto_recompute = {self.to_recompute}'\ f"\n\tbwd_comm = {self.bwd_comm}"
f'\n\tto_offload = {self.to_offload}'\ f"\n\tto_recompute = {self.to_recompute}"
f'\n\tsharding_spec = {self.sharding_spec}' f"\n\tto_offload = {self.to_offload}"
f"\n\tsharding_spec = {self.sharding_spec}"
)
return s return s
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union from typing import Any, Dict, Iterator, List, Optional, Tuple
import torch import torch
import torch.fx import torch.fx
from torch.autograd.profiler_util import _format_memory, _format_time from torch.autograd.profiler_util import _format_memory
from torch.fx import GraphModule from torch.fx import GraphModule
from torch.fx.node import Argument, Node, Target from torch.fx.node import Argument, Node, Target
...@@ -13,14 +13,14 @@ from colossalai._analyzer.fx.node_util import MetaInfo ...@@ -13,14 +13,14 @@ from colossalai._analyzer.fx.node_util import MetaInfo
def _format_flops(flops: float) -> str: def _format_flops(flops: float) -> str:
"""Returns a formatted FLOP size string""" """Returns a formatted FLOP size string"""
if flops > 1e12: if flops > 1e12:
return f'{flops / 1e12:.2f} TFLOPs' return f"{flops / 1e12:.2f} TFLOPs"
elif flops > 1e9: elif flops > 1e9:
return f'{flops / 1e9:.2f} GFLOPs' return f"{flops / 1e9:.2f} GFLOPs"
elif flops > 1e6: elif flops > 1e6:
return f'{flops / 1e6:.2f} MFLOPs' return f"{flops / 1e6:.2f} MFLOPs"
elif flops > 1e3: elif flops > 1e3:
return f'{flops / 1e3:.2f} kFLOPs' return f"{flops / 1e3:.2f} kFLOPs"
return f'{flops} FLOPs' return f"{flops} FLOPs"
def _denormalize_tuple(t: Tuple[int, ...]) -> Tuple[int, ...]: def _denormalize_tuple(t: Tuple[int, ...]) -> Tuple[int, ...]:
...@@ -42,10 +42,11 @@ class GraphProfiler(torch.fx.Interpreter): ...@@ -42,10 +42,11 @@ class GraphProfiler(torch.fx.Interpreter):
Fetch shape argument from ``ShapeProp`` without re-executing Fetch shape argument from ``ShapeProp`` without re-executing
the ``GraphModule`` from scratch. the ``GraphModule`` from scratch.
""" """
_profileable = [ _profileable = [
'call_function', "call_function",
'call_module', "call_module",
'call_method', "call_method",
] ]
def __init__(self, module: GraphModule, garbage_collect_values: bool = True): def __init__(self, module: GraphModule, garbage_collect_values: bool = True):
...@@ -77,14 +78,13 @@ class GraphProfiler(torch.fx.Interpreter): ...@@ -77,14 +78,13 @@ class GraphProfiler(torch.fx.Interpreter):
self.args_iter: Iterator[Any] = iter(args) self.args_iter: Iterator[Any] = iter(args)
for node in self.module.graph.nodes: for node in self.module.graph.nodes:
self.run_node(node) # No need to store.
self.run_node(node) # No need to store.
if self.garbage_collect_values: if self.garbage_collect_values:
for to_delete in self.user_to_last_uses.get(node, []): for to_delete in self.user_to_last_uses.get(node, []):
del self.env[to_delete] del self.env[to_delete]
if node.op == 'output': if node.op == "output":
output_val = self.env[node] output_val = self.env[node]
return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val
...@@ -133,9 +133,11 @@ class GraphProfiler(torch.fx.Interpreter): ...@@ -133,9 +133,11 @@ class GraphProfiler(torch.fx.Interpreter):
try: try:
from tabulate import tabulate from tabulate import tabulate
except ImportError: except ImportError:
print("`summary` relies on the library `tabulate`, " print(
"which could not be found on this machine. Run `pip " "`summary` relies on the library `tabulate`, "
"install tabulate` to install the library.") "which could not be found on this machine. Run `pip "
"install tabulate` to install the library."
)
# Build up a list of summary information for each node # Build up a list of summary information for each node
node_summaries: List[List[Any]] = [] node_summaries: List[List[Any]] = []
...@@ -145,36 +147,38 @@ class GraphProfiler(torch.fx.Interpreter): ...@@ -145,36 +147,38 @@ class GraphProfiler(torch.fx.Interpreter):
node: Node node: Node
n_info = MetaInfo(node) n_info = MetaInfo(node)
last_n_info = last_n_info or n_info last_n_info = last_n_info or n_info
node_summaries.append([ node_summaries.append(
node.op, [
str(node), node.op,
_format_memory(n_info.accumulate_size), str(node),
_format_memory(n_info.accumulate_size - last_n_info.accumulate_size), _format_memory(n_info.accumulate_size),
_format_memory(n_info.output_size), _format_memory(n_info.accumulate_size - last_n_info.accumulate_size),
_format_memory(n_info.temp_size), _format_memory(n_info.output_size),
_format_memory(n_info.param_size), _format_memory(n_info.temp_size),
_format_memory(n_info.backward_size), _format_memory(n_info.param_size),
_format_flops(n_info.fwd_flop), _format_memory(n_info.backward_size),
_format_flops(n_info.bwd_flop), _format_flops(n_info.fwd_flop),
]) _format_flops(n_info.bwd_flop),
]
)
last_n_info = n_info last_n_info = n_info
# Use the ``tabulate`` library to create a well-formatted table # Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information # presenting our summary information
headers: List[str] = [ headers: List[str] = [
'Op type', "Op type",
'Op', "Op",
'Accumulate size', "Accumulate size",
'Incremental size', "Incremental size",
'Output size', "Output size",
'Temp size', "Temp size",
'Param size', "Param size",
'Backward size', "Backward size",
'Fwd FLOPs', "Fwd FLOPs",
'Bwd FLOPs', "Bwd FLOPs",
] ]
return tabulate(node_summaries, headers=headers, stralign='right') return tabulate(node_summaries, headers=headers, stralign="right")
class CommunicationProfiler(GraphProfiler): class CommunicationProfiler(GraphProfiler):
...@@ -222,6 +226,7 @@ class FlopProfiler(GraphProfiler): ...@@ -222,6 +226,7 @@ class FlopProfiler(GraphProfiler):
>>> def my_fn_flop_count_impl(*args, **kwargs): >>> def my_fn_flop_count_impl(*args, **kwargs):
>>> return 0, 0 >>> return 0, 0
""" """
_custom_flop_count_impl = {} _custom_flop_count_impl = {}
def run_node(self, n: torch.fx.Node) -> Any: def run_node(self, n: torch.fx.Node) -> Any:
...@@ -246,11 +251,13 @@ class FlopProfiler(GraphProfiler): ...@@ -246,11 +251,13 @@ class FlopProfiler(GraphProfiler):
( (
n_info.fwd_flop, n_info.fwd_flop,
n_info.bwd_flop, n_info.bwd_flop,
) = getattr(self, n.op)(n.target, args, kwargs) ) = getattr(
self, n.op
)(n.target, args, kwargs)
except Exception as e: except Exception as e:
raise RuntimeError( raise RuntimeError(
f'Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. ' f"Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. "
f'Please refer to function\'s docstring to register the relevant profile_impl for this node!' f"Please refer to function's docstring to register the relevant profile_impl for this node!"
) from e ) from e
# retain the autograd graph # retain the autograd graph
...@@ -259,7 +266,7 @@ class FlopProfiler(GraphProfiler): ...@@ -259,7 +266,7 @@ class FlopProfiler(GraphProfiler):
return _denormalize_tuple(n_info.outputs) return _denormalize_tuple(n_info.outputs)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
""" """
Execute a ``call_function`` node and return the profiling result. Execute a ``call_function`` node and return the profiling result.
Dispatch to ``_custom_flop_count_impl`` if ``call_function`` should be Dispatch to ``_custom_flop_count_impl`` if ``call_function`` should be
...@@ -283,7 +290,7 @@ class FlopProfiler(GraphProfiler): ...@@ -283,7 +290,7 @@ class FlopProfiler(GraphProfiler):
else: else:
return flop_count(target, *args, **kwargs) return flop_count(target, *args, **kwargs)
def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
""" """
Execute a ``call_method`` node and return the profiling result. Execute a ``call_method`` node and return the profiling result.
...@@ -301,7 +308,7 @@ class FlopProfiler(GraphProfiler): ...@@ -301,7 +308,7 @@ class FlopProfiler(GraphProfiler):
assert isinstance(target, str) assert isinstance(target, str)
return flop_count(getattr(torch.Tensor, target), *args, **kwargs) return flop_count(getattr(torch.Tensor, target), *args, **kwargs)
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
""" """
Execute a ``call_module`` node and return the profiling result. Execute a ``call_module`` node and return the profiling result.
...@@ -336,9 +343,10 @@ def graph_profile_pass(module: GraphModule, *args, verbose=False) -> GraphModule ...@@ -336,9 +343,10 @@ def graph_profile_pass(module: GraphModule, *args, verbose=False) -> GraphModule
Returns: Returns:
GraphModule: The same GraphModule with profiling information GraphModule: The same GraphModule with profiling information
""" """
for profiler_cls in (FlopProfiler, for profiler_cls in (
# CommunicationProfiler, # TODO: add communication profiling FlopProfiler,
): # CommunicationProfiler, # TODO: add communication profiling
):
profiler = profiler_cls(module) profiler = profiler_cls(module)
profiler.propagate(*args, device=_current_device(module)) profiler.propagate(*args, device=_current_device(module))
......
...@@ -54,7 +54,7 @@ def _current_device(module): ...@@ -54,7 +54,7 @@ def _current_device(module):
try: try:
return next(module.parameters()).device return next(module.parameters()).device
except StopIteration: except StopIteration:
return torch.device('cpu') return torch.device("cpu")
@compatibility(is_backward_compatible=False) @compatibility(is_backward_compatible=False)
...@@ -90,6 +90,7 @@ class ShapeProp(torch.fx.Interpreter): ...@@ -90,6 +90,7 @@ class ShapeProp(torch.fx.Interpreter):
>>> # do something here >>> # do something here
>>> return torch.empty(output_shape, device=output_device) >>> return torch.empty(output_shape, device=output_device)
""" """
_custom_dispatch_func = {} _custom_dispatch_func = {}
_mode = MetaTensorMode() _mode = MetaTensorMode()
...@@ -115,15 +116,14 @@ class ShapeProp(torch.fx.Interpreter): ...@@ -115,15 +116,14 @@ class ShapeProp(torch.fx.Interpreter):
r = getattr(self, n.op)(n.target, args, kwargs) r = getattr(self, n.op)(n.target, args, kwargs)
def unwrap_fn(elem): def unwrap_fn(elem):
def _convert_meta(t: torch.Tensor): def _convert_meta(t: torch.Tensor):
if t.device == 'meta': if t.device == "meta":
return t return t
else: else:
return t.to('meta') return t.to("meta")
if isinstance(elem, MetaTensor): if isinstance(elem, MetaTensor):
if getattr(self, '_is_param', False): if getattr(self, "_is_param", False):
return torch.nn.Parameter(_convert_meta(elem._tensor)) return torch.nn.Parameter(_convert_meta(elem._tensor))
return _convert_meta(elem._tensor) return _convert_meta(elem._tensor)
...@@ -139,21 +139,24 @@ class ShapeProp(torch.fx.Interpreter): ...@@ -139,21 +139,24 @@ class ShapeProp(torch.fx.Interpreter):
n_info = MetaInfo(n) n_info = MetaInfo(n)
n_info.outputs = _normalize_tuple(r) n_info.outputs = _normalize_tuple(r)
if n.op == 'call_module': if n.op == "call_module":
submod = self.fetch_attr(n.target) submod = self.fetch_attr(n.target)
n_info.parameters.update({k: MetaTensor(v) for k, v in submod.named_parameters()}) n_info.parameters.update({k: MetaTensor(v) for k, v in submod.named_parameters()})
n_info.buffers.update({k: MetaTensor(v) for k, v in submod.named_buffers()}) n_info.buffers.update({k: MetaTensor(v) for k, v in submod.named_buffers()})
else: else:
n_info.parameters.update({ n_info.parameters.update(
k.name: MetaTensor(v) {
for k, v in zip(n.args, args) k.name: MetaTensor(v)
if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter) for k, v in zip(n.args, args)
}) if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
}
)
n_info.parameters.update({k: MetaTensor(v) for k, v in kwargs.items() if isinstance(v, torch.nn.Parameter)}) n_info.parameters.update({k: MetaTensor(v) for k, v in kwargs.items() if isinstance(v, torch.nn.Parameter)})
n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \ n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + tuple(
tuple(v for v in kwargs.values() if is_pure_tensor(v)) v for v in kwargs.values() if is_pure_tensor(v)
)
# align with SPMD # align with SPMD
if isinstance(r, (tuple, list)): if isinstance(r, (tuple, list)):
...@@ -168,7 +171,7 @@ class ShapeProp(torch.fx.Interpreter): ...@@ -168,7 +171,7 @@ class ShapeProp(torch.fx.Interpreter):
n_info.is_alias = _normalize_tuple(tree_map(crit, n_info.outputs)) n_info.is_alias = _normalize_tuple(tree_map(crit, n_info.outputs))
return r return r
def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: def call_function(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
""" """
Execute a ``call_function`` node and return the result. Execute a ``call_function`` node and return the result.
If the target of ``Node`` is registered with ``@register_shape_impl``, If the target of ``Node`` is registered with ``@register_shape_impl``,
...@@ -197,7 +200,7 @@ class ShapeProp(torch.fx.Interpreter): ...@@ -197,7 +200,7 @@ class ShapeProp(torch.fx.Interpreter):
else: else:
return res return res
def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: def call_method(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
""" """
Execute a ``call_method`` node and return the result. Execute a ``call_method`` node and return the result.
...@@ -218,7 +221,8 @@ class ShapeProp(torch.fx.Interpreter): ...@@ -218,7 +221,8 @@ class ShapeProp(torch.fx.Interpreter):
convert_to_parameter = False convert_to_parameter = False
if target_method in (torch.Tensor.view, torch.Tensor.transpose) and isinstance( if target_method in (torch.Tensor.view, torch.Tensor.transpose) and isinstance(
args[0], torch.nn.parameter.Parameter): args[0], torch.nn.parameter.Parameter
):
convert_to_parameter = True convert_to_parameter = True
# Execute the method and return the result # Execute the method and return the result
assert isinstance(target, str) assert isinstance(target, str)
......
import torch
import torch.fx
from torch.fx import GraphModule from torch.fx import GraphModule
from .passes import ShapeProp, graph_profile_pass, shape_prop_pass from .passes import ShapeProp, graph_profile_pass, shape_prop_pass
...@@ -7,7 +5,6 @@ from .passes.graph_profile import FlopProfiler ...@@ -7,7 +5,6 @@ from .passes.graph_profile import FlopProfiler
def register_flop_count_impl(func): def register_flop_count_impl(func):
def wrapper(impl): def wrapper(impl):
FlopProfiler._custom_flop_count_impl[func] = impl FlopProfiler._custom_flop_count_impl[func] = impl
return impl return impl
...@@ -16,7 +13,6 @@ def register_flop_count_impl(func): ...@@ -16,7 +13,6 @@ def register_flop_count_impl(func):
def register_shape_impl(func): def register_shape_impl(func):
def wrapper(impl): def wrapper(impl):
ShapeProp._custom_dispatch_func[func] = impl ShapeProp._custom_dispatch_func[func] = impl
return impl return impl
......
...@@ -12,7 +12,7 @@ from .tracer import register_tracer_impl ...@@ -12,7 +12,7 @@ from .tracer import register_tracer_impl
__all__ = [] __all__ = []
@register_tracer_impl(F.linear, name='_bias_addition_impl') @register_tracer_impl(F.linear, name="_bias_addition_impl")
def linear_impl(input, weight, bias=None): def linear_impl(input, weight, bias=None):
if bias is None: if bias is None:
return F.linear(input, weight) return F.linear(input, weight)
...@@ -20,116 +20,130 @@ def linear_impl(input, weight, bias=None): ...@@ -20,116 +20,130 @@ def linear_impl(input, weight, bias=None):
return F.linear(input, weight) + bias return F.linear(input, weight) + bias
@register_tracer_impl(F.conv1d, name='_bias_addition_impl') @register_tracer_impl(F.conv1d, name="_bias_addition_impl")
def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1): def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1):
if bias is None: if bias is None:
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else: else:
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1)) (-1, 1)
)
@register_tracer_impl(F.conv2d, name='_bias_addition_impl') @register_tracer_impl(F.conv2d, name="_bias_addition_impl")
def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1): def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1):
if bias is None: if bias is None:
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else: else:
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1, 1)) (-1, 1, 1)
)
@register_tracer_impl(F.conv3d, name='_bias_addition_impl') @register_tracer_impl(F.conv3d, name="_bias_addition_impl")
def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1): def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1):
if bias is None: if bias is None:
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else: else:
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1, 1, 1)) (-1, 1, 1, 1)
)
@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
def conv_transpose1d_impl(input, @register_tracer_impl(F.conv_transpose1d, name="_bias_addition_impl")
weight, def conv_transpose1d_impl(
bias=None, input,
stride=_single(1), weight,
padding=_single(0), bias=None,
output_padding=_single(0), stride=_single(1),
groups=1, padding=_single(0),
dilation=_single(1)): output_padding=_single(0),
groups=1,
dilation=_single(1),
):
if bias is None: if bias is None:
return F.conv_transpose1d(input, return F.conv_transpose1d(
weight, input,
stride=stride, weight,
padding=padding, stride=stride,
output_padding=output_padding, padding=padding,
groups=groups, output_padding=output_padding,
dilation=dilation) groups=groups,
dilation=dilation,
)
else: else:
return F.conv_transpose1d(input, return F.conv_transpose1d(
weight, input,
stride=stride, weight,
padding=padding, stride=stride,
output_padding=output_padding, padding=padding,
groups=groups, output_padding=output_padding,
dilation=dilation) + bias.reshape((-1, 1)) groups=groups,
dilation=dilation,
) + bias.reshape((-1, 1))
@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
def conv_transpose2d_impl(input,
weight, @register_tracer_impl(F.conv_transpose2d, name="_bias_addition_impl")
bias=None, def conv_transpose2d_impl(
stride=_pair(1), input, weight, bias=None, stride=_pair(1), padding=_pair(0), output_padding=_pair(0), groups=1, dilation=_pair(1)
padding=_pair(0), ):
output_padding=_pair(0),
groups=1,
dilation=_pair(1)):
if bias is None: if bias is None:
return F.conv_transpose2d(input, return F.conv_transpose2d(
weight, input,
stride=stride, weight,
padding=padding, stride=stride,
output_padding=output_padding, padding=padding,
groups=groups, output_padding=output_padding,
dilation=dilation) groups=groups,
dilation=dilation,
)
else: else:
return F.conv_transpose2d(input, return F.conv_transpose2d(
weight, input,
stride=stride, weight,
padding=padding, stride=stride,
output_padding=output_padding, padding=padding,
groups=groups, output_padding=output_padding,
dilation=dilation) + bias.reshape((-1, 1, 1)) groups=groups,
dilation=dilation,
) + bias.reshape((-1, 1, 1))
@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
def conv_transpose3d_impl(input,
weight, @register_tracer_impl(F.conv_transpose3d, name="_bias_addition_impl")
bias=None, def conv_transpose3d_impl(
stride=_triple(1), input,
padding=_triple(0), weight,
output_padding=_triple(0), bias=None,
groups=1, stride=_triple(1),
dilation=_triple(1)): padding=_triple(0),
output_padding=_triple(0),
groups=1,
dilation=_triple(1),
):
if bias is None: if bias is None:
return F.conv_transpose3d(input, return F.conv_transpose3d(
weight, input,
stride=stride, weight,
padding=padding, stride=stride,
output_padding=output_padding, padding=padding,
groups=groups, output_padding=output_padding,
dilation=dilation) groups=groups,
dilation=dilation,
)
else: else:
return F.conv_transpose3d(input, return F.conv_transpose3d(
weight, input,
stride=stride, weight,
padding=padding, stride=stride,
output_padding=output_padding, padding=padding,
groups=groups, output_padding=output_padding,
dilation=dilation) + bias.reshape((-1, 1, 1, 1)) groups=groups,
dilation=dilation,
) + bias.reshape((-1, 1, 1, 1))
@register_tracer_impl(torch.addmm, name='_bias_addition_impl')
@register_tracer_impl(torch.Tensor.addmm, name='_bias_addition_impl')
@register_tracer_impl(torch.addmm, name="_bias_addition_impl")
@register_tracer_impl(torch.Tensor.addmm, name="_bias_addition_impl")
def addmm_impl(input, mat1, mat2, beta=1, alpha=1): def addmm_impl(input, mat1, mat2, beta=1, alpha=1):
if alpha != 1 and beta != 1: if alpha != 1 and beta != 1:
return F.linear(mat1, mat2.transpose(0, 1)) * alpha + input * beta return F.linear(mat1, mat2.transpose(0, 1)) * alpha + input * beta
...@@ -141,8 +155,8 @@ def addmm_impl(input, mat1, mat2, beta=1, alpha=1): ...@@ -141,8 +155,8 @@ def addmm_impl(input, mat1, mat2, beta=1, alpha=1):
return F.linear(mat1, mat2.transpose(0, 1)) + input return F.linear(mat1, mat2.transpose(0, 1)) + input
@register_tracer_impl(torch.addbmm, name='_bias_addition_impl') @register_tracer_impl(torch.addbmm, name="_bias_addition_impl")
@register_tracer_impl(torch.Tensor.addbmm, name='_bias_addition_impl') @register_tracer_impl(torch.Tensor.addbmm, name="_bias_addition_impl")
def addbmm_impl(input, batch1, batch2, beta=1, alpha=1): def addbmm_impl(input, batch1, batch2, beta=1, alpha=1):
if alpha != 1 and beta != 1: if alpha != 1 and beta != 1:
return torch.bmm(batch1, batch2.transpose(1, 2)) * alpha + input * beta return torch.bmm(batch1, batch2.transpose(1, 2)) * alpha + input * beta
......
...@@ -4,6 +4,7 @@ from .tracer import register_leaf_module, register_leaf_module_impl ...@@ -4,6 +4,7 @@ from .tracer import register_leaf_module, register_leaf_module_impl
try: try:
import apex import apex
register_leaf_module(apex.normalization.FusedLayerNorm) register_leaf_module(apex.normalization.FusedLayerNorm)
register_leaf_module(apex.normalization.FusedRMSNorm) register_leaf_module(apex.normalization.FusedRMSNorm)
register_leaf_module(apex.normalization.MixedFusedLayerNorm) register_leaf_module(apex.normalization.MixedFusedLayerNorm)
......
import operator import operator
from typing import Any, Callable, Dict, Optional, Set, Union from typing import Any, Callable, Dict, Optional, Union
import torch import torch
import torch.nn as nn from torch.fx import Node, Proxy
from torch.fx import Graph, Node, Proxy, Tracer
from torch.fx.graph import _Namespace
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
from colossalai._analyzer._subclasses import MetaTensor from colossalai._analyzer._subclasses import MetaTensor
...@@ -32,7 +30,7 @@ class ColoProxy(Proxy): ...@@ -32,7 +30,7 @@ class ColoProxy(Proxy):
def __torch_function__(cls, orig_method, types, args=(), kwargs=None): def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs kwargs = {} if kwargs is None else kwargs
if orig_method in cls._func_dispatch: if orig_method in cls._func_dispatch:
impl = cls._func_dispatch.pop(orig_method) # avoid recursion impl = cls._func_dispatch.pop(orig_method) # avoid recursion
proxy = impl(*args, **kwargs) proxy = impl(*args, **kwargs)
cls._func_dispatch[orig_method] = impl cls._func_dispatch[orig_method] = impl
return proxy return proxy
...@@ -72,7 +70,7 @@ class ColoProxy(Proxy): ...@@ -72,7 +70,7 @@ class ColoProxy(Proxy):
return ColoAttribute(self, k, getattr(self._meta_data, k, None)) return ColoAttribute(self, k, getattr(self._meta_data, k, None))
def __setitem__(self, key, value): def __setitem__(self, key, value):
proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {}) proxy = self.tracer.create_proxy("call_function", operator.setitem, (self, key, value), {})
proxy.meta_data = self._meta_data proxy.meta_data = self._meta_data
return proxy return proxy
...@@ -89,7 +87,6 @@ class ColoProxy(Proxy): ...@@ -89,7 +87,6 @@ class ColoProxy(Proxy):
class ColoAttribute(ColoProxy): class ColoAttribute(ColoProxy):
def __init__(self, root, attr: str, data=None): def __init__(self, root, attr: str, data=None):
self.root = root self.root = root
self.attr = attr self.attr = attr
...@@ -102,11 +99,11 @@ class ColoAttribute(ColoProxy): ...@@ -102,11 +99,11 @@ class ColoAttribute(ColoProxy):
# the node for attributes is added lazily, since most will just be method calls # the node for attributes is added lazily, since most will just be method calls
# which do not rely on the getitem call # which do not rely on the getitem call
if self._node is None: if self._node is None:
self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
return self._node return self._node
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
def __repr__(self): def __repr__(self):
return f"ColoAttribute({self.node.name}, attr={self.attr})" return f"ColoAttribute({self.node.name}, attr={self.attr})"
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union from typing import Any, Callable, Dict, Optional, Union
import torch import torch
from torch.fx import Tracer from torch.fx import Tracer
...@@ -8,6 +8,7 @@ from colossalai._analyzer._subclasses import MetaTensor ...@@ -8,6 +8,7 @@ from colossalai._analyzer._subclasses import MetaTensor
try: try:
from ..codegen import ActivationCheckpointCodeGen from ..codegen import ActivationCheckpointCodeGen
SUPPORT_ACTIVATION = True SUPPORT_ACTIVATION = True
except: except:
SUPPORT_ACTIVATION = False SUPPORT_ACTIVATION = False
...@@ -16,7 +17,7 @@ from .tracer import ColoTracer ...@@ -16,7 +17,7 @@ from .tracer import ColoTracer
def _default_device(): def _default_device():
return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def _current_device(module: torch.nn.Module): def _current_device(module: torch.nn.Module):
...@@ -144,10 +145,9 @@ def symbolic_trace( ...@@ -144,10 +145,9 @@ def symbolic_trace(
if meta_args: if meta_args:
device, orig_device = _default_device(), _current_device(root) device, orig_device = _default_device(), _current_device(root)
wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem
graph = ColoTracer(trace_act_ckpt=trace_act_ckpt, graph = ColoTracer(trace_act_ckpt=trace_act_ckpt, bias_addition_split=bias_addition_split).trace(
bias_addition_split=bias_addition_split).trace(root.to(device), root.to(device), concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args)
concrete_args=concrete_args, )
meta_args=tree_map(wrap_fn, meta_args))
if trace_act_ckpt and SUPPORT_ACTIVATION: if trace_act_ckpt and SUPPORT_ACTIVATION:
graph.set_codegen(ActivationCheckpointCodeGen()) graph.set_codegen(ActivationCheckpointCodeGen())
root.to(orig_device) root.to(orig_device)
......
...@@ -20,11 +20,10 @@ def _truncate_suffix(s: str): ...@@ -20,11 +20,10 @@ def _truncate_suffix(s: str):
import re import re
# FIXME: don't know why but torch.fx always gets a suffix like '_1' in the name # FIXME: don't know why but torch.fx always gets a suffix like '_1' in the name
return re.sub(r'_\d+$', '', s) return re.sub(r"_\d+$", "", s)
def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custom_impl'): def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = "_custom_impl"):
def wrapper(impl): def wrapper(impl):
assert hasattr(ColoTracer, name), f"Cannot register {func.__name__} in ColoTracer.{name}" assert hasattr(ColoTracer, name), f"Cannot register {func.__name__} in ColoTracer.{name}"
getattr(ColoTracer, name)[func] = impl getattr(ColoTracer, name)[func] = impl
...@@ -34,7 +33,6 @@ def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custo ...@@ -34,7 +33,6 @@ def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custo
def register_leaf_module_impl(module: nn.Module): def register_leaf_module_impl(module: nn.Module):
def wrapper(impl): def wrapper(impl):
ColoTracer._custom_leaf_module_impl[module] = impl ColoTracer._custom_leaf_module_impl[module] = impl
return impl return impl
...@@ -76,7 +74,7 @@ class ColoTracer(Tracer): ...@@ -76,7 +74,7 @@ class ColoTracer(Tracer):
self.ckpt_regions = [] self.ckpt_regions = []
self.ckpt_idx = 0 self.ckpt_idx = 0
self.mod_dir = '' self.mod_dir = ""
# whether the tracer should split the bias_add ops into two ops # whether the tracer should split the bias_add ops into two ops
self.bias_addition_split = bias_addition_split self.bias_addition_split = bias_addition_split
...@@ -87,35 +85,41 @@ class ColoTracer(Tracer): ...@@ -87,35 +85,41 @@ class ColoTracer(Tracer):
if self.bias_addition_split and type(m) in self._bias_addition_module and m.bias is not None: if self.bias_addition_split and type(m) in self._bias_addition_module and m.bias is not None:
return False return False
# user can specify which modules are leaf modules and which are not # user can specify which modules are leaf modules and which are not
return (type(m) not in self._custom_non_leaf_module return type(m) not in self._custom_non_leaf_module and (
and (type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name))) type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)
)
def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], def call_module(
kwargs: Dict[str, Any]) -> Any: self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Any:
curr_dir = self.mod_dir curr_dir = self.mod_dir
self.mod_dir = 'self.' + self.path_of_module(m) self.mod_dir = "self." + self.path_of_module(m)
rst = super().call_module(m, forward, args, kwargs) rst = super().call_module(m, forward, args, kwargs)
self.mod_dir = curr_dir self.mod_dir = curr_dir
return rst return rst
def proxy(self, node: Node) -> 'ColoProxy': def proxy(self, node: Node) -> "ColoProxy":
return ColoProxy(node, self) return ColoProxy(node, self)
def create_proxy(self, def create_proxy(
kind: str, self,
target: Target, kind: str,
args: Tuple[Any, ...], target: Target,
kwargs: Dict[str, Any], args: Tuple[Any, ...],
name: Optional[str] = None, kwargs: Dict[str, Any],
type_expr: Optional[Any] = None, name: Optional[str] = None,
proxy_factory_fn: Callable[[Node], 'Proxy'] = None): type_expr: Optional[Any] = None,
proxy_factory_fn: Callable[[Node], "Proxy"] = None,
):
proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
if kind == 'placeholder': if kind == "placeholder":
proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get( proxy.meta_data = (
_truncate_suffix(target), None) self.meta_args[target]
elif kind == 'get_attr': if target in self.meta_args
else self.concrete_args.get(_truncate_suffix(target), None)
)
elif kind == "get_attr":
self.disable_module_getattr = True self.disable_module_getattr = True
try: try:
attr_itr = self.root attr_itr = self.root
...@@ -125,20 +129,21 @@ class ColoTracer(Tracer): ...@@ -125,20 +129,21 @@ class ColoTracer(Tracer):
proxy.meta_data = attr_itr proxy.meta_data = attr_itr
finally: finally:
self.disable_module_getattr = False self.disable_module_getattr = False
elif kind == 'call_function': elif kind == "call_function":
proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
elif kind == 'call_method': elif kind == "call_method":
self.disable_module_getattr = True self.disable_module_getattr = True
try: try:
if target == '__call__': if target == "__call__":
proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)) proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
else: else:
if target not in _TensorPropertyMethod: if target not in _TensorPropertyMethod:
proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]), proxy._meta_data = getattr(unwrap_fn(args[0]), target)(
**tree_map(unwrap_fn, kwargs)) *tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)
)
finally: finally:
self.disable_module_getattr = False self.disable_module_getattr = False
elif kind == 'call_module': elif kind == "call_module":
mod = self.root.get_submodule(target) mod = self.root.get_submodule(target)
self.disable_module_getattr = True self.disable_module_getattr = True
try: try:
...@@ -158,11 +163,12 @@ class ColoTracer(Tracer): ...@@ -158,11 +163,12 @@ class ColoTracer(Tracer):
n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions)) n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions))
return node return node
def trace(self, def trace(
root: torch.nn.Module, self,
concrete_args: Optional[Dict[str, torch.Tensor]] = None, root: torch.nn.Module,
meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph: concrete_args: Optional[Dict[str, torch.Tensor]] = None,
meta_args: Optional[Dict[str, torch.Tensor]] = None,
) -> Graph:
if meta_args is None: if meta_args is None:
meta_args = {} meta_args = {}
...@@ -177,9 +183,7 @@ class ColoTracer(Tracer): ...@@ -177,9 +183,7 @@ class ColoTracer(Tracer):
non_concrete_arg_names = sig_names - concrete_arg_names non_concrete_arg_names = sig_names - concrete_arg_names
# update concrete args with default values # update concrete args with default values
for k, v in sig.parameters.items(): for k, v in sig.parameters.items():
if k in sig_names - meta_arg_names and \ if k in sig_names - meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty:
k not in concrete_args and \
v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default concrete_args[k] = v.default
def _check_arg_name_valid(names: Iterable[str]): def _check_arg_name_valid(names: Iterable[str]):
...@@ -194,9 +198,9 @@ class ColoTracer(Tracer): ...@@ -194,9 +198,9 @@ class ColoTracer(Tracer):
self.meta_args = meta_args self.meta_args = meta_args
with self._torch_factory_override(), self._tracer_override(), torch.no_grad(): with self._torch_factory_override(), self._tracer_override(), torch.no_grad():
self.mod_dir = 'self' self.mod_dir = "self"
self.graph = super().trace(root, concrete_args=concrete_args) self.graph = super().trace(root, concrete_args=concrete_args)
self.mod_dir = '' self.mod_dir = ""
self.graph.lint() self.graph.lint()
for node in self.graph.nodes: for node in self.graph.nodes:
...@@ -266,17 +270,17 @@ class ColoTracer(Tracer): ...@@ -266,17 +270,17 @@ class ColoTracer(Tracer):
# override the torch factory functions to create a proxy when the method # override the torch factory functions to create a proxy when the method
# is called during ``symbolic_trace()``. # is called during ``symbolic_trace()``.
def wrap_factory_method(target): def wrap_factory_method(target):
@functools.wraps(target) @functools.wraps(target)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
is_proxy = any(isinstance(p, ColoProxy) for p in args) | any( is_proxy = any(isinstance(p, ColoProxy) for p in args) | any(
isinstance(p, ColoProxy) for p in kwargs.values()) isinstance(p, ColoProxy) for p in kwargs.values()
)
if is_proxy: if is_proxy:
# if the arg is a proxy, then need to record this function called on this proxy # if the arg is a proxy, then need to record this function called on this proxy
# e.g. torch.ones(size) where size is an input proxy # e.g. torch.ones(size) where size is an input proxy
self.disable_module_getattr = True self.disable_module_getattr = True
try: try:
proxy = self.create_proxy('call_function', target, args, kwargs) proxy = self.create_proxy("call_function", target, args, kwargs)
finally: finally:
self.disable_module_getattr = False self.disable_module_getattr = False
return proxy return proxy
...@@ -341,10 +345,13 @@ class ColoTracer(Tracer): ...@@ -341,10 +345,13 @@ class ColoTracer(Tracer):
if attr_val is p: if attr_val is p:
if n not in parameter_proxy_cache: if n not in parameter_proxy_cache:
kwargs = {} kwargs = {}
if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters: if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else kwargs["proxy_factory_fn"] = (
lambda node: ColoProxy(self, node, n, attr_val)) None
val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type] if not self.param_shapes_constant
else lambda node: ColoProxy(self, node, n, attr_val)
)
val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n] return parameter_proxy_cache[n]
return None return None
...@@ -355,8 +362,9 @@ class ColoTracer(Tracer): ...@@ -355,8 +362,9 @@ class ColoTracer(Tracer):
return maybe_buffer_proxy return maybe_buffer_proxy
if isinstance(attr_val, torch.nn.Parameter): if isinstance(attr_val, torch.nn.Parameter):
maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(), maybe_parameter_proxy = maybe_get_proxy_for_attr(
parameter_proxy_cache) attr_val, self.root.named_parameters(), parameter_proxy_cache
)
if maybe_parameter_proxy is not None: if maybe_parameter_proxy is not None:
return maybe_parameter_proxy return maybe_parameter_proxy
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment