Unverified Commit 8739aa7f authored by Jianghai, committed by GitHub

[shardformer] Pipeline/whisper (#4456)

* add some base tests and policies

* finish whisper base model

* add conditional generation

* finish basic tests

* whisper

* finish whisper

* finish whisper

* del useless whisper test

* fix

* add argmin to replace

* finish revision
parent a27e0bb4
@@ -304,15 +304,6 @@ class BlipPolicy(Policy):
         return policy

     def postprocess(self):
-        binding_map = {
-            'language_model.model.decoder.embed_tokens': 'language_model.lm_head',
-        }
-
-        for k, v in binding_map.items():
-            src_mod = getattr_(self.model, k)
-            dst_mod = getattr_(self.model, v)
-            dst_mod.weight = src_mod.weight
-
         return self.model
...
 from functools import partial
 from typing import Callable, Dict, List, Optional, Tuple

+import numpy as np
 from torch import Tensor, nn

 from colossalai.shardformer.layer import (
@@ -228,13 +229,7 @@ class T5BasePolicy(Policy):
         def objective(num_encoder_stages):
             return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages))

-        num_encoder_stages = 0
-        optimal_diff = 2**31 - 1
-        for i in range(1, num_stages):
-            attempt = objective(i)
-            if attempt < optimal_diff:
-                num_encoder_stages = i
-                optimal_diff = attempt
+        num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
         num_decoder_stages = num_stages - num_encoder_stages

         encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages)
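Editor's note: a standalone sketch of what the replaced loop and the new np.argmin one-liner both compute, using the (10 encoder layers, 6 decoder layers, 8 stages) case exercised by the tests further down. This is an illustration, not part of the diff.

import numpy as np

num_encoder_layers, num_decoder_layers, num_stages = 10, 6, 8

def objective(num_encoder_stages):
    # imbalance between the average encoder and decoder layers per stage
    return abs(num_encoder_layers / num_encoder_stages
               - num_decoder_layers / (num_stages - num_encoder_stages))

# argmin over the candidates i = 1 .. num_stages - 1; the +1 shifts the
# 0-based argmin index back to the candidate value
num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
assert num_encoder_stages == 5    # so the decoder starts at stage 5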
...
+from functools import partial
+from typing import Callable, Dict, List, Tuple
+
+import numpy as np
 import torch.nn as nn
+from torch import Tensor

 import colossalai.shardformer.layer as col_nn

 from .._utils import getattr_, setattr_
 from ..modeling.jit import get_jit_fused_dropout_add_func
 from ..modeling.whisper import (
+    WhisperPipelineForwards,
     get_jit_fused_whisper_decoder_layer_forward,
     get_jit_fused_whisper_encoder_layer_forward,
     get_whisper_flash_attention_forward,
@@ -12,7 +18,8 @@ from ..modeling.whisper import (
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = [
-    'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy', 'WhisperForAudioClassification'
+    'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy',
+    'WhisperForAudioClassificationPolicy'
 ]
@@ -223,6 +230,146 @@ class WhisperPolicy(Policy):
     def postprocess(self):
         return self.model
+    @staticmethod
+    def distribute_whisper_layers(num_encoder_layers: int, num_decoder_layers: int,
+                                  num_stages: int) -> Tuple[List[int], int]:
+        """
+        Distribute whisper layers into stages when pipeline parallel is used.
+        Return the layer distribution as a list and the starting stage of the decoder.
+        If the decoder doesn't exist, the returned decoder starting stage is set to num_stages.
+        """
+        # the number of encoder layers must be a positive integer
+        if num_encoder_layers <= 0:
+            raise ValueError("The number of encoder layers for whisper must be a positive integer.")
+
+        # the number of layers should be large enough to fill every stage
+        if num_encoder_layers + num_decoder_layers < num_stages:
+            raise ValueError("The total number of layers can't be smaller than the number of stages.")
+
+        # for an encoder-only whisper model, set the decoder starting stage to num_stages since the decoder doesn't exist
+        if num_decoder_layers == 0:
+            return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages
+
+        # the number of stages distributed between encoder and decoder is optimized in this way:
+        # num_encoder_stages = argmin(abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / num_decoder_stages))
+        # s.t. num_encoder_stages + num_decoder_stages = num_stages, num_encoder_stages >= 1, num_decoder_stages >= 1
+        def objective(num_encoder_stages):
+            return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages))
+
+        num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
+        num_decoder_stages = num_stages - num_encoder_stages
+
+        encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages)
+        decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages)
+        return encoder_distribution + decoder_distribution, num_encoder_stages
+    @staticmethod
+    def get_whisper_stage_index(layers_per_stage: List[int], stage: int,
+                                decoder_starting_stage: int) -> Tuple[int, int]:
+        """
+        Input the distribution of layers among stages, the current stage and the first stage of the decoder.
+        Return the starting/ending index of layers in the encoder/decoder.
+        """
+        if stage < decoder_starting_stage:
+            return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
+        else:
+            return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
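Editor's note: a usage sketch of how the two helpers above combine, not part of the diff. The expected values are taken from test_whisper_pipeline_layers further down; the exact [4, 3, 3, 2] split is inferred from those test expectations.

from colossalai.shardformer.policies.whisper import WhisperPolicy

# 4 encoder layers + 8 decoder layers spread over 4 pipeline stages
layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
    num_encoder_layers=4, num_decoder_layers=8, num_stages=4)
assert decoder_starting_stage == 1           # stage 0 holds the whole encoder
assert layers_per_stage == [4, 3, 3, 2]      # encoder split + decoder split

# stage 2 is a decoder stage, so the indices are relative to the decoder: layers [3, 6)
start_idx, end_idx = WhisperPolicy.get_whisper_stage_index(
    layers_per_stage, stage=2, decoder_starting_stage=decoder_starting_stage)
assert (start_idx, end_idx) == (3, 6)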
+    def get_held_layers(self) -> List[nn.Module]:
+        assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
+
+        stage_manager = self.pipeline_stage_manager
+
+        if self.model.__class__.__name__ == 'WhisperModel':
+            model = self.model
+        elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration':
+            model = self.model.model
+        else:
+            model = None
+
+        if model:
+            encoder = self.model.get_encoder()
+            decoder = self.model.get_decoder()
+        else:
+            # whisper for audio classification holds only the encoder
+            encoder = self.model.encoder
+            decoder = None
+
+        num_encoder_layers = len(encoder.layers)
+        if decoder:
+            num_decoder_layers = len(decoder.layers)
+        else:
+            num_decoder_layers = 0
+
+        held_layers = []
+        layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
+            num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
+        start_idx, end_idx = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage,
+                                                                   decoder_starting_stage)
+
+        if stage_manager.stage < decoder_starting_stage:
+            # the current stage is in whisper's encoder
+            if stage_manager.is_first_stage():
+                held_layers.append(encoder.embed_positions)
+                held_layers.append(encoder.conv1)
+                held_layers.append(encoder.conv2)
+            if stage_manager.stage == decoder_starting_stage - 1:
+                held_layers.append(encoder.layer_norm)
+            held_layers.extend(encoder.layers[start_idx:end_idx])
+        else:
+            # the current stage is in whisper's decoder
+            # TODO(Jianghai): encoder and decoder layers are divided into separate stages here;
+            # the case where encoder and decoder share a stage should be added in the future.
+            if stage_manager.stage == decoder_starting_stage:
+                held_layers.append(decoder.embed_tokens)
+                held_layers.append(decoder.embed_positions)
+            if stage_manager.is_last_stage():
+                held_layers.append(decoder.layer_norm)
+            held_layers.extend(decoder.layers[start_idx:end_idx])
+        return held_layers
+    def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
+        """If under a pipeline-parallel setting, replace the model's original huggingface forward
+        method with a customized forward method and record this change in the policy."""
+        if not self.pipeline_stage_manager:
+            raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
+        stage_manager = self.pipeline_stage_manager
+
+        if self.model.__class__.__name__ == 'WhisperModel':
+            model = self.model
+        elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration':
+            model = self.model.model
+        else:
+            model = None
+
+        if model:
+            encoder = self.model.get_encoder()
+            decoder = self.model.get_decoder()
+        else:
+            encoder = self.model.encoder
+            decoder = None
+
+        num_encoder_layers = len(encoder.layers)
+        if decoder:
+            num_decoder_layers = len(decoder.layers)
+        else:
+            num_decoder_layers = 0
+
+        layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
+            num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
+        stage_index = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage,
+                                                            decoder_starting_stage)
+
+        method_replacement = {
+            'forward':
+                partial(new_forward,
+                        stage_manager=stage_manager,
+                        stage_index=stage_index,
+                        decoder_starting_stage=decoder_starting_stage)
+        }
+        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
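Editor's note: the replacement above relies on functools.partial to pre-bind the stage bookkeeping into the new forward, so the patched method only needs the usual model inputs at call time. A minimal, hypothetical sketch of that pattern (the names below are illustrative, not ColossalAI APIs), not part of the diff:

from functools import partial

def staged_forward(inputs, stage_manager=None, stage_index=None, decoder_starting_stage=None):
    # a pipeline-aware forward only runs its own slice of layers
    start, end = stage_index
    return f"stage {stage_manager['stage']}: run layers [{start}, {end})"

forward = partial(staged_forward,
                  stage_manager={'stage': 2},    # stands in for a real stage manager
                  stage_index=(3, 6),
                  decoder_starting_stage=1)
print(forward("dummy inputs"))    # -> stage 2: run layers [3, 6)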
 # WhisperModel
 class WhisperModelPolicy(WhisperPolicy):

@@ -230,6 +377,24 @@ class WhisperModelPolicy(WhisperPolicy):
     def __init__(self) -> None:
         super().__init__()
+    def module_policy(self):
+        from transformers import WhisperModel
+        policy = super().module_policy()
+        if self.pipeline_stage_manager is not None:
+            self.set_pipeline_forward(model_cls=WhisperModel,
+                                      new_forward=WhisperPipelineForwards.whisper_model_forward,
+                                      policy=policy)
+        return policy
+
+    def get_held_layers(self) -> List[nn.Module]:
+        return super().get_held_layers()
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        "no shared params in whisper model"
+        return []
 # WhisperForConditionalGeneration
 class WhisperForConditionalGenerationPolicy(WhisperPolicy):

@@ -238,20 +403,82 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy):
         super().__init__()

     def module_policy(self):
-        module_policy = super().module_policy()
-        module_policy = self.add_lm_head_policy(module_policy)
-        return module_policy
+        from transformers import WhisperForConditionalGeneration
+        policy = super().module_policy()
+        policy = self.add_lm_head_policy(policy)
+        if self.pipeline_stage_manager is not None:
+            self.set_pipeline_forward(model_cls=WhisperForConditionalGeneration,
+                                      new_forward=WhisperPipelineForwards.whisper_for_conditional_generation_forward,
+                                      policy=policy)
+        return policy

     def postprocess(self):
+        binding_map = {"model.decoder.embed_tokens.weight": "proj_out.weight"}
+        for k, v in binding_map.items():
+            param = getattr_(self.model, k)
+            setattr_(self.model, v, param)
         return self.model
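Editor's note, not part of the diff: the binding in postprocess makes the output projection and the decoder input embedding share one parameter tensor, which is the tying that get_shared_params below then tries to detect across pipeline stages. A standalone sketch of that effect:

import torch.nn as nn

embed_tokens = nn.Embedding(100, 16)
proj_out = nn.Linear(16, 100, bias=False)
proj_out.weight = embed_tokens.weight    # the effect of setattr_(..., param)

assert proj_out.weight is embed_tokens.weight    # one shared parameter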
+    def get_held_layers(self) -> List[nn.Module]:
+        held_layers = super().get_held_layers()
+        if self.pipeline_stage_manager.is_last_stage():
+            held_layers.append(self.model.proj_out)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        module = self.model
+        model = module.model
+
+        if model:
+            encoder = self.model.get_encoder()
+            decoder = self.model.get_decoder()
+        else:
+            encoder = self.model.encoder
+            decoder = None
+
+        num_encoder_layers = len(encoder.layers)
+        if decoder:
+            num_decoder_layers = len(decoder.layers)
+        else:
+            num_decoder_layers = 0
+
+        stage_manager = self.pipeline_stage_manager
+        if stage_manager is not None and stage_manager.num_stages > 1:
+            _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(num_encoder_layers, num_decoder_layers,
+                                                                                stage_manager.num_stages)
+            shared_params = []
+            shared_embedding = {}
+            if id(module.proj_out) == id(model.decoder.embed_tokens):
+                shared_embedding[decoder_starting_stage] = model.decoder.embed_tokens
+                shared_embedding[stage_manager.num_stages - 1] = module.proj_out
+            if len(shared_embedding) > 0:
+                shared_params.append(shared_embedding)
+            return shared_params
+        return []
 # WhisperForAudioClassification
 class WhisperForAudioClassificationPolicy(WhisperPolicy):

     def __init__(self) -> None:
         super().__init__()

+    def preprocess(self):
+        return self.model
+
+    def module_policy(self):
+        from transformers import WhisperForAudioClassification
+        policy = super().module_policy()
+        if self.pipeline_stage_manager is not None:
+            self.set_pipeline_forward(model_cls=WhisperForAudioClassification,
+                                      new_forward=WhisperPipelineForwards.whisper_for_audio_classification_forward,
+                                      policy=policy)
+        return policy
+
+    def get_held_layers(self) -> List[nn.Module]:
+        held_layers = super().get_held_layers()
+        if self.pipeline_stage_manager.is_last_stage():
+            held_layers.append(self.model.projector)
+            held_layers.append(self.model.classifier)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        return []
...
from colossalai.shardformer.policies.t5 import T5BasePolicy


def test_t5_pipeline_distribution():
    num_test_cases = 8
    test_dict = {
        'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5],
        'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22],
        'num_stages': [2, 2, 2, 4, 4, 4, 8, 8],
        'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2]
    }

    for i in range(num_test_cases):
        _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(test_dict['num_encoder_layers'][i],
                                                                      test_dict['num_decoder_layers'][i],
                                                                      test_dict['num_stages'][i])
        assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage


def test_t5_pipeline_layers():
    num_test_cases = 4
    test_dict = {
        'num_encoder_layers': [2, 3, 2, 4],
        'num_decoder_layers': [2, 0, 2, 8],
        'num_stages': [2, 2, 4, 4],
        'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]],
                             [[0, 4], [0, 3], [3, 6], [6, 8]]]
    }

    for i in range(num_test_cases):
        layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
            test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i])

        for stage in range(test_dict['num_stages'][i]):
            start_idx, end_idx = test_dict['layers_per_stage'][i][stage]
            predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage,
                                                                             decoder_starting_stage)
            assert start_idx == predicted_start
            assert end_idx == predicted_end
...
from colossalai.shardformer.policies.whisper import WhisperPolicy


def test_whisper_pipeline_distribution():
    num_test_cases = 8
    test_dict = {
        'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5],
        'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22],
        'num_stages': [2, 2, 2, 4, 4, 4, 8, 8],
        'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2]
    }

    for i in range(num_test_cases):
        _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(test_dict['num_encoder_layers'][i],
                                                                            test_dict['num_decoder_layers'][i],
                                                                            test_dict['num_stages'][i])
        assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage


def test_whisper_pipeline_layers():
    num_test_cases = 4
    test_dict = {
        'num_encoder_layers': [2, 3, 2, 4],
        'num_decoder_layers': [2, 0, 2, 8],
        'num_stages': [2, 2, 4, 4],
        'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]],
                             [[0, 4], [0, 3], [3, 6], [6, 8]]]
    }

    for i in range(num_test_cases):
        layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
            test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i])

        for stage in range(test_dict['num_stages'][i]):
            start_idx, end_idx = test_dict['layers_per_stage'][i][stage]
            predicted_start, predicted_end = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage,
                                                                                   decoder_starting_stage)
            assert start_idx == predicted_start
            assert end_idx == predicted_end


if __name__ == '__main__':
    test_whisper_pipeline_distribution()
    test_whisper_pipeline_layers()
...

@@ -6,6 +6,7 @@ from torch import distributed as dist
 import colossalai
 from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo

@@ -143,6 +144,7 @@ def run_llama_test(test_config):
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

     clear_layout_converter()
+    Randomizer.reset_index()
     torch.cuda.empty_cache()

...
@@ -3,6 +3,8 @@ import torch
 import colossalai
 from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import (
     assert_hf_output_close,
     clear_cache_before_run,
@@ -11,55 +13,145 @@ from colossalai.testing import (
     spawn,
 )
 from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, check_grad, run_forward
+from tests.test_shardformer.test_model._utils import (
+    build_model_from_hybrid_plugin,
+    check_grad,
+    check_loss,
+    check_output_hidden_state,
+    check_weight,
+    run_forward_backward_with_hybrid_plugin,
+)
-def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
+def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
     # check forward
-    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
-                                                                 output_transform_fn, loss_fn)
-    assert_hf_output_close(org_output, shard_output, ignore_keys='past_key_values', atol=1e-5)
-
-    # do backward
-    org_loss.backward()
-    shard_loss.backward()
-
-    assert torch.allclose(org_loss, shard_loss,
-                          atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
+    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
+        build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
+
+    org_loss, org_output, sharded_loss, sharded_output = \
+        run_forward_backward_with_hybrid_plugin(org_model,
+                                                sharded_model,
+                                                sharded_optimizer,
+                                                data_gen_fn,
+                                                output_transform_fn,
+                                                criterion,
+                                                booster)
+
+    stage_manager = booster.plugin.stage_manager
+    tp_group = booster.plugin.tp_group
+
+    # check last hidden state & loss
+    if stage_manager is None or stage_manager.is_last_stage():
+        if test_config['precision'] == 'fp32':
+            atol, rtol = 1e-3, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
+
+        if org_model.__class__.__name__ == 'WhisperModel':
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
     # unwrap the model
     if org_model.__class__.__name__ == 'WhisperForConditionalGeneration':
         whisper = org_model.model
-        sharded_whisper = sharded_model.model
+        sharded_whisper = sharded_model.unwrap().model
     else:
         whisper = org_model
-        sharded_whisper = sharded_model
+        sharded_whisper = sharded_model.unwrap()

     # check grad
     if org_model.__class__.__name__ == 'WhisperForAudioClassification':
         col_layer_for_check = ['encoder.layers[0].self_attn.q_proj']
         row_layer_for_check = ['encoder.layers[0].self_attn.out_proj']
     else:
-        col_layer_for_check = ['encoder.layers[0].self_attn.q_proj', 'decoder.layers[0].self_attn.q_proj']
-        row_layer_for_check = ['encoder.layers[0].self_attn.out_proj', 'decoder.layers[0].self_attn.out_proj']
-    check_grad(whisper, sharded_whisper, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
-    check_grad(whisper, sharded_whisper, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False)
+        col_layer_for_check = [
+            'encoder.layers[0].self_attn.q_proj',
+            # 'decoder.layers[0].self_attn.q_proj'
+        ]
+        row_layer_for_check = [
+            'encoder.layers[0].self_attn.out_proj',
+            # 'decoder.layers[0].self_attn.out_proj'
+        ]
+    # check weights and gradients
+    if test_config['precision'] == 'fp32':
+        atol, rtol = 1e-3, 1e-3
+    else:
+        atol, rtol = 5e-3, 5e-3
+
+    if stage_manager is None or stage_manager.is_first_stage():
+        check_grad(whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0)
+        check_grad(whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
+
+    # check weights after optimizer.step()
+    org_optimizer.step()
+    sharded_optimizer.step()
+    if test_config['precision'] == 'fp32':
+        atol, rtol = 1e-3, 1e-3
+    else:
+        atol, rtol = 5e-3, 5e-3
+    if stage_manager is None or stage_manager.is_first_stage():
+        check_weight(whisper,
+                     sharded_whisper,
+                     row_layer_for_check,
+                     tp_group,
+                     atol=atol,
+                     rtol=rtol,
+                     dim=0,
+                     verbose=False)
+        check_weight(whisper,
+                     sharded_whisper,
+                     col_layer_for_check,
+                     tp_group,
+                     atol=atol,
+                     rtol=rtol,
+                     dim=0,
+                     verbose=False)
+
+    torch.cuda.empty_cache()
-@parameterize('enable_fused_normalization', [True, False])
-@parameterize('enable_tensor_parallelism', [True, False])
-@parameterize('enable_flash_attention', [True, False])
-@parameterize('enable_jit_fused', [True, False])
-def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
+# TODO(jianghai) fix fp16
+@parameterize('test_config', [{
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 2,
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'precision': 'fp32',
+    'initial_scale': 1,
+}, {
+    'tp_size': 1,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'use_lazy_init': False,
+    'precision': 'fp32',
+    'initial_scale': 1,
+}, {
+    'tp_size': 4,
+    'pp_size': 1,
+    'enable_all_optimization': True,
+    'use_lazy_init': False,
+    'precision': 'fp32',
+}, {
+    'tp_size': 1,
+    'pp_size': 4,
+    'num_microbatches': 4,
+    'use_lazy_init': False,
+    'precision': 'fp32',
+}])
+def run_whisper_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')

     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn,
-                                               enable_fused_normalization=enable_fused_normalization,
-                                               enable_tensor_parallelism=enable_tensor_parallelism,
-                                               enable_flash_attention=enable_flash_attention,
-                                               enable_jit_fused=enable_jit_fused)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+        if test_config['pp_size'] > 2 and name == 'transformers_whisper_for_audio_classification':
+            continue
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    Randomizer.reset_index()
     torch.cuda.empty_cache()
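Editor's note on the world-size bump in the hunk below, not part of the diff: each test_config above needs tp_size * pp_size ranks, so spawning 2 processes no longer suffices once (tp=2, pp=2), (tp=4, pp=1) and (tp=1, pp=4) are exercised.

for tp_size, pp_size in [(2, 2), (1, 2), (4, 1), (1, 4)]:
    assert tp_size * pp_size <= 4    # 4 ranks cover every config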
@@ -73,7 +165,7 @@ def check_whisper(rank, world_size, port):
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
 def test_whisper():
-    spawn(check_whisper, 2)
+    spawn(check_whisper, 4)


 if __name__ == "__main__":
...