Commit 58df7205 authored by Frank Lee

[shardformer] adapted T5 and LLaMa test to use kit (#4049)

* [shardformer] adapted T5 and LLaMa test to use kit

* polish code
parent 4021b9a8

New shared test helpers, tests/test_shardformer/test_model/_utils.py (new file):

import copy

from colossalai.shardformer import ShardConfig, ShardFormer


def build_model(world_size, model_fn):
    # create new model
    org_model = model_fn().cuda()

    # shard model
    shard_config = ShardConfig(tensor_parallel_size=world_size)
    model_copy = copy.deepcopy(org_model)
    shard_former = ShardFormer(shard_config=shard_config)
    shard_former.init_distributed()
    sharded_model = shard_former.shard_model(model_copy)
    return org_model, sharded_model


def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    # prepare input
    data = data_gen_fn()
    data = {k: v.cuda() for k, v in data.items()}

    # switch to train mode
    original_model.train()
    sharded_model.train()

    # run forward
    org_output = original_model(**data)
    org_output = output_transform_fn(org_output)
    org_loss = loss_fn(org_output)

    shard_output = sharded_model(**data)
    shard_output = output_transform_fn(shard_output)
    shard_loss = loss_fn(shard_output)
    return org_output, org_loss, shard_output, shard_loss
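
For orientation, run_forward consumes the per-model callables that the model zoo kit registers next to each model (the same data_gen_fn, output_transform_fn and loss_fn unpacked from sub_model_zoo.items() in the tests below). A minimal, self-contained sketch of such a trio follows; the tiny BERT config, tensor shapes and helper names are illustrative assumptions, not the kit's actual registry entries:

# Hypothetical callables in the style run_forward() expects; everything below
# (model choice, config sizes, loss) is an assumption for illustration only.
import torch
from transformers import BertConfig, BertModel


def model_fn():
    # a deliberately tiny model so build_model() can move it to a single GPU
    config = BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
    return BertModel(config)


def data_gen_fn():
    # run_forward() moves every tensor in this dict onto CUDA before the forward pass
    return dict(input_ids=torch.randint(0, 1000, (2, 16)),
                attention_mask=torch.ones(2, 16, dtype=torch.long))


def output_transform_fn(output):
    # flatten the HuggingFace ModelOutput into a plain dict for the loss and the closeness check
    return {'last_hidden_state': output.last_hidden_state}


def loss_fn(transformed_output):
    # any scalar works; both the original and the sharded model run backward on it
    return transformed_output['last_hidden_state'].mean()

With entries like these, a test calls org_model, sharded_model = build_model(world_size, model_fn) and then run_forward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn), which is exactly what the adapted LLaMa and T5 tests below do.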

Diff of the LLaMa shardformer test:

-import copy
 import os
-import random
 import pytest
 import torch
-from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaTokenizerFast
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+from tests.test_shardformer.test_model._utils import build_model, run_forward
 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
-tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
-def build_model(world_size, model_fn):
-    # create new model
-    config = LlamaConfig(num_hidden_layers=4,
-                         hidden_size=128,
-                         intermediate_size=256,
-                         num_attention_heads=4,
-                         max_position_embeddings=128)
-    org_model = model_fn(config).cuda()
-    # shard model
-    shard_config = ShardConfig(tensor_parallel_size=world_size)
-    model_copy = copy.deepcopy(org_model)
-    shard_former = ShardFormer(shard_config=shard_config)
-    shard_former.init_distributed()
-    sharded_model = shard_former.shard_model(model_copy)
-    return org_model, sharded_model
-def check_forward_backward(org_model, sharded_model):
-    # prepare input
-    input = 'Hello, my dog is cute'
-    tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
-    del tokenized_input["token_type_ids"]
-    del tokenized_input["attention_mask"]
-    # switch to train mode
-    org_model.train()
-    sharded_model.train()
-    if isinstance(org_model, (LlamaModel, LlamaForSequenceClassification)):
-        org_output = org_model(**tokenized_input)
-        org_loss = org_output.last_hidden_state.mean()
-        shard_output = sharded_model(**tokenized_input)
-        shard_loss = shard_output.last_hidden_state.mean()
-    elif isinstance(org_model, LlamaForCausalLM):
-        labels = tokenized_input['input_ids'].clone()
-        labels[labels == tokenizer.pad_token_id] = -100
-        tokenized_input['labels'] = labels
-        org_output = org_model(**tokenized_input)
-        org_loss = org_output.loss
-        shard_output = sharded_model(**tokenized_input)
-        shard_loss = shard_output.loss
+def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
+    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
+                                                                 output_transform_fn, loss_fn)
+    # forward check
     assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-4)
     # run backward
@@ -66,12 +24,12 @@ def check_forward_backward(org_model, sharded_model):
     shard_loss.backward()
     # check grad
-    if isinstance(org_model, LlamaModel):
-        llama_model = org_model
-        shard_llama_model = sharded_model
-    else:
+    if hasattr(org_model, 'model'):
         llama_model = org_model.model
         shard_llama_model = sharded_model.model
+    else:
+        llama_model = org_model
+        shard_llama_model = sharded_model
     org_grad = llama_model.layers[0].self_attn.q_proj.weight.grad
     shard_grad = shard_llama_model.layers[0].self_attn.q_proj.weight.grad
@@ -89,17 +47,11 @@ def check_llama(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    model_list = [
-        LlamaModel,
-        # LlamaForCausalLM,
-        # TODO: do not work yet
-        # LlamaForSequenceClassification
-    ]
-    for model_fn in model_list:
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         org_model, sharded_model = build_model(world_size, model_fn)
-        check_forward_backward(org_model, sharded_model)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     torch.cuda.empty_cache()
...
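
The hunk above stops before the actual comparison between org_grad and shard_grad. Purely as a sketch of what checking a tensor-parallel-sharded gradient against the unsharded one can involve, under the assumption that q_proj.weight is split along its output dimension (dim 0) across the tensor-parallel group (the helper name and tolerances below are made up, not the test's real code):

# Sketch only: gather the sharded gradient from every tensor-parallel rank and
# compare it with the unsharded gradient. Assumes a dim-0 weight split; the real
# comparison in the test is elided in the diff above.
import torch
import torch.distributed as dist


def grads_close(org_grad, shard_grad, group=None, rtol=1e-3, atol=1e-5):
    world_size = dist.get_world_size(group)
    shards = [torch.zeros_like(shard_grad) for _ in range(world_size)]
    dist.all_gather(shards, shard_grad.contiguous(), group=group)
    full_grad = torch.cat(shards, dim=0)    # undo the dim-0 split of q_proj.weight
    return torch.allclose(org_grad, full_grad, rtol=rtol, atol=atol)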

Diff of the T5 shardformer test:

-import copy
 import os
 import pytest
 import torch
-from transformers import T5Config, T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Tokenizer, T5TokenizerFast
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer.shard import ShardConfig, ShardFormer
 from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+from tests.test_shardformer.test_model._utils import build_model, run_forward
+os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
-CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
-tokenizer = T5Tokenizer.from_pretrained("t5-small")
-def build_model(world_size, model_fn):
-    config = T5Config(decoder_start_token_id=0)
-    config.dropout_rate = 0
-    org_model = model_fn(config=config).to('cuda')
-    shard_config = ShardConfig(tensor_parallel_size=world_size)
-    # shard model
-    shard_config = ShardConfig(tensor_parallel_size=world_size)
-    model_copy = copy.deepcopy(org_model)
-    shard_former = ShardFormer(shard_config=shard_config)
-    shard_former.init_distributed()
-    sharded_model = shard_former.shard_model(model_copy)
-    return org_model, sharded_model
-def check_forward_backward(org_model, sharded_model):
-    # prepare input
-    input_ids = tokenizer("translate English to German: The house is wonderful.",
-                          return_tensors="pt").input_ids.to('cuda')
-    labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids.to('cuda')
-    # switch to train mode
-    org_model.train()
-    sharded_model.train()
-    if isinstance(org_model, T5ForConditionalGeneration):
-        org_output = org_model(input_ids=input_ids, labels=labels)
-        org_loss = org_output.loss
-        shard_output = sharded_model(input_ids=input_ids, labels=labels)
-        shard_loss = shard_output.loss
-    elif isinstance(org_model, T5Model):
-        decoder_input_ids = org_model._shift_right(input_ids)
-        org_output = org_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
-        org_loss = org_output.last_hidden_state.mean()
-        shard_output = sharded_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
-        shard_loss = shard_output.last_hidden_state.mean()
-    elif isinstance(org_model, T5EncoderModel):
-        org_output = org_model(input_ids=input_ids)
-        org_loss = org_output.last_hidden_state.mean()
-        shard_output = sharded_model(input_ids=input_ids)
-        shard_loss = shard_output.last_hidden_state.mean()
-    # key is sharded, so we ignore
+def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
+    # check forward
+    # the value "past_key_values" is sharded, so we ignore
+    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
+                                                                 output_transform_fn, loss_fn)
     assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
     # do backward
@@ -81,18 +37,15 @@ def check_forward_backward(org_model, sharded_model):
 def check_t5(rank, world_size, port):
     disable_existing_loggers()
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    model_fn_list = [
-        T5Model,
-        T5ForConditionalGeneration,
-        T5EncoderModel,
-    ]
-    for model_fn in model_fn_list:
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         org_model, sharded_model = build_model(world_size, model_fn)
-        check_forward_backward(org_model, sharded_model)
-        torch.cuda.empty_cache()
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+    torch.cuda.empty_cache()
 @pytest.mark.dist
...
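
The T5 diff truncates right after the @pytest.mark.dist decorator (the LLaMa diff is cut even earlier), so the pytest entry points themselves are not visible here. For readers unfamiliar with the colossalai.testing helpers imported above, such an entry point usually looks roughly like the following; the function name and the world size of 2 are assumptions, not taken from this commit:

# Hypothetical entry point; check_t5 is the per-rank function defined in the diff
# above, and spawn/rerun_if_address_is_in_use/clear_cache_before_run come from
# colossalai.testing exactly as imported in the test files.
import pytest
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_t5():
    # launch 2 worker processes; each one runs check_t5(rank, world_size, port)
    spawn(check_t5, 2)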