Unverified Commit 3353e55c authored by flybird11111, committed by GitHub

[shardformer] vit/llama/t5 ignore the sequence parallelism flag and some fixes. (#4498)

* [shardformer] chatglm support sequence parallel

* fix

* [shardformer] jit fused fix

* activate checks
parent e04436a8
@@ -187,6 +187,9 @@ class BertPipelineForwards:
             hidden_states = split_forward_gather_backward(hidden_states,
                                                           dim=1,
                                                           process_group=shard_config.tensor_parallel_process_group)
+            if encoder_hidden_states is not None:
+                encoder_hidden_states = split_forward_gather_backward(
+                    encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group)
 
         for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx):
             if stage_manager.is_first_stage() and idx == 0:
@@ -1241,6 +1244,9 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
         embedding_output = split_forward_gather_backward(embedding_output,
                                                          dim=1,
                                                          process_group=shard_config.tensor_parallel_process_group)
+        if encoder_hidden_states is not None:
+            encoder_hidden_states = split_forward_gather_backward(
+                encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group)
 
         encoder_outputs = self.encoder(
             embedding_output,
......
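The two BERT hunks above extend sequence parallelism to cross-attention inputs: besides the hidden states, `encoder_hidden_states` is now also split along the sequence dimension (`dim=1`). The sketch below illustrates the split-forward/gather-backward pattern these calls rely on. It is a simplified stand-in for illustration, not ColossalAI's actual `split_forward_gather_backward`, and it assumes `torch.distributed` is already initialized (e.g. via `torchrun`).

```python
import torch
import torch.distributed as dist


# Minimal sketch of the split-forward / gather-backward pattern.
# Forward: each rank keeps only its chunk of the given dimension, so
# per-layer activation memory scales with the sequence shard.
# Backward: gradient chunks are all-gathered so every rank recovers the
# gradient for the full sequence.
class _SplitForwardGatherBackward(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input_, dim, process_group):
        ctx.dim = dim
        ctx.process_group = process_group
        rank = dist.get_rank(process_group)
        world_size = dist.get_world_size(process_group)
        # Split along `dim` and keep this rank's chunk.
        return input_.chunk(world_size, dim=dim)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.process_group)
        # Gather gradient chunks from all ranks and stitch them back together.
        chunks = [torch.empty_like(grad_output) for _ in range(world_size)]
        dist.all_gather(chunks, grad_output.contiguous(), group=ctx.process_group)
        # One gradient per forward argument; `dim` and `process_group` get None.
        return torch.cat(chunks, dim=ctx.dim), None, None
```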
+import warnings
 from functools import partial
 from typing import Callable, Dict, List, Union
@@ -35,6 +36,10 @@ class LlamaPolicy(Policy):
         policy = {}
 
+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
+
         if self.shard_config.enable_tensor_parallelism:
             policy[LlamaDecoderLayer] = ModulePolicyDescription(
                 attribute_replacement={
......
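The guard added here (and mirrored in the T5, ViT, and Whisper policies below) downgrades an unsupported flag to a warning instead of an error: the policy resets the flag and continues. A minimal self-contained sketch of that behavior, using a dummy config class rather than ColossalAI's `ShardConfig`:

```python
import warnings


# Stand-in for ShardConfig, reduced to the one flag the guard touches.
class DummyShardConfig:

    def __init__(self, enable_sequence_parallelism=False):
        self.enable_sequence_parallelism = enable_sequence_parallelism


def preprocess(shard_config):
    # Mirror the policy preprocessing: silently disable the unsupported
    # feature and tell the user, rather than raising an error.
    if shard_config.enable_sequence_parallelism:
        shard_config.enable_sequence_parallelism = False
        warnings.warn("This model doesn't support sequence parallelism now, "
                      "will ignore the sequence parallelism flag.")


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    cfg = DummyShardConfig(enable_sequence_parallelism=True)
    preprocess(cfg)

assert cfg.enable_sequence_parallelism is False
assert len(caught) == 1
```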
@@ -104,16 +104,20 @@ class OPTPolicy(Policy):
         # use flash attention
         if self.shard_config.enable_flash_attention:
-            policy[OPTAttention] = ModulePolicyDescription(method_replacement={
-                'forward': get_opt_flash_attention_forward(),
-            })
+            self.append_or_create_method_replacement(description={
+                'forward': get_opt_flash_attention_forward(),
+            },
+                                                     policy=policy,
+                                                     target_key=OPTAttention)
 
         # use jit fused operator
         if self.shard_config.enable_jit_fused:
-            policy[OPTDecoderLayer] = ModulePolicyDescription(method_replacement={
-                'forward': get_jit_fused_opt_decoder_layer_forward(),
-                'dropout_add': get_jit_fused_dropout_add_func(),
-            })
+            self.append_or_create_method_replacement(description={
+                'forward': get_jit_fused_opt_decoder_layer_forward(),
+                'dropout_add': get_jit_fused_dropout_add_func(),
+            },
+                                                     policy=policy,
+                                                     target_key=OPTDecoderLayer)
 
         return policy
......
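The refactor above replaces direct `policy[...] = ModulePolicyDescription(...)` assignments, which overwrite any entry already registered for that module, with `self.append_or_create_method_replacement`, which merges into it. A plausible sketch of what such a helper does, using a reduced stand-in for `ModulePolicyDescription` (the real helper lives on ColossalAI's `Policy` base class and may differ in detail):

```python
from dataclasses import dataclass
from typing import Callable, Dict, Optional


# Reduced stand-in for ModulePolicyDescription, keeping only the field
# this refactor touches.
@dataclass
class ModulePolicyDescriptionSketch:
    method_replacement: Optional[Dict[str, Callable]] = None


def append_or_create_method_replacement(description, policy, target_key):
    # Merge into an existing entry instead of overwriting it, so that e.g.
    # flash attention and jit fusion can both register method replacements
    # on the same module.
    if target_key in policy:
        if policy[target_key].method_replacement is None:
            policy[target_key].method_replacement = {}
        policy[target_key].method_replacement.update(description)
    else:
        policy[target_key] = ModulePolicyDescriptionSketch(method_replacement=dict(description))
    return policy


policy = {}
append_or_create_method_replacement({'forward': lambda: 'flash'}, policy, 'Attention')
append_or_create_method_replacement({'dropout_add': lambda: 'fused'}, policy, 'Attention')
assert set(policy['Attention'].method_replacement) == {'forward', 'dropout_add'}
```

With the direct-assignment style, the second registration would have clobbered the first; the helper makes the two feature flags composable.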
+import warnings
 from functools import partial
 from typing import Callable, Dict, List, Optional, Tuple
@@ -59,6 +60,10 @@ class T5BasePolicy(Policy):
         policy = {}
 
+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            warnings.warn("T5 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
+
         if self.shard_config.enable_tensor_parallelism:
             policy[T5Stack] = ModulePolicyDescription(sub_module_replacement=[
                 SubModuleReplacementDescription(
......
+import warnings
 from typing import Callable, Dict, List, Union
 
 import torch.nn as nn
@@ -32,6 +33,10 @@ class ViTPolicy(Policy):
         policy = {}
 
+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            warnings.warn("ViT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
+
         if self.shard_config.enable_tensor_parallelism:
             policy[ViTEmbeddings] = ModulePolicyDescription(attribute_replacement={},
                                                             param_replacement=[],
......
+import warnings
 from functools import partial
 from typing import Callable, Dict, List, Tuple
@@ -33,7 +34,6 @@ class WhisperPolicy(Policy):
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
-        # TODO:
         vocab_size = self.model.config.vocab_size
         world_size = self.shard_config.tensor_parallel_size
         if vocab_size % world_size != 0:
@@ -52,6 +52,14 @@ class WhisperPolicy(Policy):
         policy = {}
 
+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            warnings.warn(
+                "Whisper doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
+
+        if self.shard_config.enable_jit_fused:
+            self.shard_config.enable_jit_fused = False
+            warnings.warn("Whisper doesn't support jit fused operator now, will ignore the jit fused flag.")
+
         if self.shard_config.enable_tensor_parallelism:
             policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={
                 "self_attn.embed_dim":
@@ -198,20 +206,11 @@ class WhisperPolicy(Policy):
         # enable flash attention
         if self.shard_config.enable_flash_attention:
-            policy[WhisperAttention] = ModulePolicyDescription(method_replacement={
-                'forward': get_whisper_flash_attention_forward(),
-            })
+            self.append_or_create_method_replacement(description={
+                'forward': get_whisper_flash_attention_forward(),
+            },
+                                                     policy=policy,
+                                                     target_key=WhisperAttention)
-
-        # use jit fused operator
-        if self.shard_config.enable_jit_fused:
-            policy[WhisperEncoderLayer] = ModulePolicyDescription(method_replacement={
-                'forward': get_jit_fused_whisper_encoder_layer_forward(),
-                'dropout_add': get_jit_fused_dropout_add_func(),
-            })
-            policy[WhisperDecoderLayer] = ModulePolicyDescription(method_replacement={
-                'forward': get_jit_fused_whisper_decoder_layer_forward(),
-                'dropout_add': get_jit_fused_dropout_add_func(),
-            })
 
         return policy
......
...@@ -44,7 +44,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -44,7 +44,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# check last hidden state & loss # check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage(): if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32': if test_config['precision'] == 'fp32':
atol, rtol = 1e-3, 1e-3 atol, rtol = 2e-4, 2e-4
else: else:
atol, rtol = 5e-3, 5e-3 atol, rtol = 5e-3, 5e-3
...@@ -77,7 +77,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -77,7 +77,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# check weights and gradients # check weights and gradients
if test_config['precision'] == 'fp32': if test_config['precision'] == 'fp32':
atol, rtol = 1e-3, 1e-3 atol, rtol = 2e-4, 2e-4
else: else:
atol, rtol = 5e-3, 5e-3 atol, rtol = 5e-3, 5e-3
...@@ -89,7 +89,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -89,7 +89,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
org_optimizer.step() org_optimizer.step()
sharded_optimizer.step() sharded_optimizer.step()
if test_config['precision'] == 'fp32': if test_config['precision'] == 'fp32':
atol, rtol = 1e-3, 1e-3 atol, rtol = 2e-4, 2e-4
else: else:
atol, rtol = 5e-3, 5e-3 atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage(): if stage_manager is None or stage_manager.is_first_stage():
...@@ -114,6 +114,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -114,6 +114,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# TODO(jianghai) fix fp16 # TODO(jianghai) fix fp16
#TODO fix WhisperForConditionalGeneration enable jit fused operator
@parameterize('test_config', [{ @parameterize('test_config', [{
'tp_size': 2, 'tp_size': 2,
'pp_size': 2, 'pp_size': 2,
......
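The test changes tighten the fp32 tolerances from 1e-3 to 2e-4 (activating stricter checks, per the commit message) and add a TODO noting that WhisperForConditionalGeneration's jit-fused path is disabled for now. A minimal sketch of the comparison these tests perform, with tolerance values taken from the diff; `check_close` is illustrative, not the test suite's actual helper:

```python
import torch


# fp32 comparisons now use atol = rtol = 2e-4 (previously 1e-3);
# fp16/bf16 keeps the looser 5e-3.
def check_close(org_tensor, sharded_tensor, precision='fp32'):
    atol, rtol = (2e-4, 2e-4) if precision == 'fp32' else (5e-3, 5e-3)
    torch.testing.assert_close(sharded_tensor, org_tensor, atol=atol, rtol=rtol)


# Passes: the absolute difference (1e-4) is within atol + rtol * |expected|.
check_close(torch.tensor([1.0000]), torch.tensor([1.0001]))
```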