Commit d2cd48e0 authored by flybird1111's avatar flybird1111 Committed by Hongxin Liu
Browse files

[shardformer] test all optimizations (#4399)

[shardformer] test all optimizations

[shardformer] test all optimizations

[shardformer] test all optimizations
parent 7a3dfd0c
...@@ -148,7 +148,10 @@ class HybridParallelPlugin(PipelinePluginBase): ...@@ -148,7 +148,10 @@ class HybridParallelPlugin(PipelinePluginBase):
precision: str = 'fp16', precision: str = 'fp16',
zero_stage: int = 0, zero_stage: int = 0,
cpu_offload: bool = False, cpu_offload: bool = False,
enable_all_optimization: bool = False,
enable_fused_normalization: bool = False, enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
enable_jit_fused: bool = False,
num_microbatches: Optional[int] = None, num_microbatches: Optional[int] = None,
initial_scale: float = 2**16, initial_scale: float = 2**16,
min_scale: float = 1, min_scale: float = 1,
...@@ -171,7 +174,10 @@ class HybridParallelPlugin(PipelinePluginBase): ...@@ -171,7 +174,10 @@ class HybridParallelPlugin(PipelinePluginBase):
self.precision = precision self.precision = precision
self.zero_stage = zero_stage self.zero_stage = zero_stage
self.cpu_offload = cpu_offload self.cpu_offload = cpu_offload
self.enable_all_optimization = enable_all_optimization
self.enable_fused_normalization = enable_fused_normalization self.enable_fused_normalization = enable_fused_normalization
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
self.stage_manager = None self.stage_manager = None
self.schedule = None self.schedule = None
...@@ -186,7 +192,10 @@ class HybridParallelPlugin(PipelinePluginBase): ...@@ -186,7 +192,10 @@ class HybridParallelPlugin(PipelinePluginBase):
self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group, self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group,
pipeline_stage_manager=self.stage_manager, pipeline_stage_manager=self.stage_manager,
enable_tensor_parallelism=self.tp_size > 1, enable_tensor_parallelism=self.tp_size > 1,
enable_fused_normalization=self.enable_fused_normalization) enable_all_optimization=self.enable_all_optimization,
enable_fused_normalization=self.enable_fused_normalization,
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused)
self.amp_config = dict( self.amp_config = dict(
initial_scale=initial_scale, initial_scale=initial_scale,
growth_factor=growth_factor, growth_factor=growth_factor,
......
...@@ -19,4 +19,4 @@ ninja ...@@ -19,4 +19,4 @@ ninja
flash_attn>=2.0 flash_attn>=2.0
datasets datasets
ninja ninja
flash-attn flash-attn>=2.0
import copy import copy
from contextlib import nullcontext from contextlib import nullcontext
from typing import Optional
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import torch import torch
...@@ -16,8 +15,8 @@ from colossalai.booster.plugin import HybridParallelPlugin ...@@ -16,8 +15,8 @@ from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.lazy import LazyInitContext from colossalai.lazy import LazyInitContext
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.auto_policy import Policy
from colossalai.shardformer._utils import getattr_ from colossalai.shardformer._utils import getattr_
from colossalai.shardformer.policies.auto_policy import Policy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
...@@ -156,10 +155,12 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo ...@@ -156,10 +155,12 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
else: else:
data = {k: v.cuda() for k, v in data.items()} data = {k: v.cuda() for k, v in data.items()}
sharded_output = sharded_model(**data) sharded_output = sharded_model(**data)
sharded_loss = criterion(sharded_output) sharded_loss = criterion(sharded_output)
sharded_loss.backward() sharded_optimizer.backward(sharded_loss)
org_model.train() org_model.train()
data = {k: v.cuda() for k, v in data.items()}
org_output = org_model(**data) org_output = org_model(**data)
org_loss = criterion(org_output) org_loss = criterion(org_output)
org_loss.backward() org_loss.backward()
...@@ -181,12 +182,12 @@ def check_output_hidden_state(org_output: Tensor, ...@@ -181,12 +182,12 @@ def check_output_hidden_state(org_output: Tensor,
if stage_manager and stage_manager.is_last_stage(): if stage_manager and stage_manager.is_last_stage():
sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=0) sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=0)
assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=atol, rtol=rtol), \ assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \
f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}" f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3): def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3):
assert torch.allclose(org_loss, sharded_loss, atol=atol, rtol=rtol), \ assert torch.allclose(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol), \
f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}" f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}"
...@@ -213,7 +214,7 @@ def check_weight(org_model: Module, ...@@ -213,7 +214,7 @@ def check_weight(org_model: Module,
if verbose and dist.get_rank() == 0: if verbose and dist.get_rank() == 0:
print(f"'{suffix}' weight: {org_weight}, {sharded_weight}") print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")
assert torch.allclose(org_weight, sharded_weight, atol=atol, rtol=rtol), \ assert torch.allclose(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol), \
f"shard model weight is not equal to origin model weight\n{org_weight}\n{sharded_weight}" f"shard model weight is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
...@@ -244,6 +245,7 @@ def check_grad(org_model: Module, ...@@ -244,6 +245,7 @@ def check_grad(org_model: Module,
if verbose and dist.get_rank() == 0: if verbose and dist.get_rank() == 0:
print(f"'{suffix}' grad: {org_grad}, {shard_grad}") print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
assert torch.allclose( assert torch.allclose(
org_grad, shard_grad, rtol=rtol, atol=atol org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol
), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}" ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
...@@ -3,6 +3,7 @@ import torch ...@@ -3,6 +3,7 @@ import torch
from torch import distributed as dist from torch import distributed as dist
import colossalai import colossalai
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
...@@ -38,33 +39,49 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -38,33 +39,49 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# check last hidden state & loss # check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage(): if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'GPT2Model': if org_model.__class__.__name__ == 'GPT2Model':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3) check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3) # check loss
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
def unwrap(module):
if isinstance(module, HybridParallelModule):
module = module.unwrap()
if module.__class__.__name__ == 'GPT2Model':
return module
return module.transformer
# unwrap model # unwrap model
if org_model.__class__.__name__ == 'GPT2Model': gpt2 = unwrap(org_model)
gpt2 = org_model sharded_gpt2 = unwrap(sharded_model)
sharded_gpt2 = sharded_model.unwrap()
else:
gpt2 = org_model.transformer
sharded_gpt2 = sharded_model.unwrap().transformer
col_layer_for_check = ['h[0].mlp.c_fc'] col_layer_for_check = ['h[0].mlp.c_fc']
row_layer_for_check = ['wte', 'h[0].mlp.c_proj'] row_layer_for_check = ['wte', 'h[0].mlp.c_proj']
# check grad # check grad
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage(): if stage_manager is None or stage_manager.is_first_stage():
check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False) check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False) check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
# check weights after optimizer.step() # check weights after optimizer.step()
org_optimizer.step() org_optimizer.step()
sharded_optimizer.step() sharded_optimizer.step()
if test_config['precision'] == 'fp32':
atol, rtol = 5e-3, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage(): if stage_manager is None or stage_manager.is_first_stage():
check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=5e-3, rtol=1e-3, dim=1, verbose=False) check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
torch.cuda.empty_cache() torch.cuda.empty_cache()
...@@ -73,29 +90,31 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -73,29 +90,31 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'tp_size': 2, 'tp_size': 2,
'pp_size': 2, 'pp_size': 2,
'num_microbatches': 4, 'num_microbatches': 4,
'enable_fused_normalization': True, 'enable_all_optimization': True,
'use_lazy_init': True 'use_lazy_init': True,
'precision': 'fp32',
}, { }, {
'tp_size': 1, 'tp_size': 1,
'pp_size': 2, 'pp_size': 2,
'num_microbatches': 4, 'num_microbatches': 4,
'use_lazy_init': False 'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp16',
'initial_scale': 1,
}, { }, {
'tp_size': 4, 'tp_size': 4,
'pp_size': 1, 'pp_size': 1,
'enable_fused_normalization': True, 'enable_all_optimization': True,
'use_lazy_init': False 'use_lazy_init': False,
'precision': 'fp32',
}]) }])
@clear_cache_before_run() @clear_cache_before_run()
def run_gpt2_test(test_config): def run_gpt2_test(test_config):
# TODO: add test_config for TP+DP after supporting & debugging it # TODO: add test_config for TP+DP after supporting & debugging it
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} # TODO: check and debug TP+AMP
# TODO: add test_config for flash attention & jit operator after supporting
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment