Unverified Commit c3d5fa3b authored by eric8607242, committed by GitHub

[shardformer] Support customized policy for llamav2-based model with HybridParallelPlugin (#4624)



* Enable policy assignment in HybridPlugin and enable llama policy for llamav2

* Remove Policy from Plugin

* revert changes of plugin (HybridParallelModule)

* revert changes in plugin

* upgrade transformers

* revert transformers version

---------
Co-authored-by: flybird11111 <1829166702@qq.com>
parent 9709b8f5
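
For context, this change lets llamav2-style checkpoints (which use grouped-query attention via a `num_key_value_heads` config field) be sharded with a user-supplied policy instead of only the auto-detected one. Below is a minimal usage sketch, not the verbatim API of this commit: `MyLlamaV2Policy` is a hypothetical subclass, and passing the policy through `ShardFormer.optimize(model, policy)` is one customization point; how HybridParallelPlugin forwards the policy internally (via HybridParallelModule, per the commit log) may differ.

```python
# Minimal sketch of plugging a customized policy into shardformer.
# Assumptions: MyLlamaV2Policy is hypothetical, and torch.distributed is
# already initialized so ShardConfig can size the tensor-parallel group.
from transformers import LlamaForCausalLM

from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.llama import LlamaPolicy


class MyLlamaV2Policy(LlamaPolicy):
    """Hypothetical customized policy: inherit LlamaPolicy, override hooks as needed."""


model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
shard_config = ShardConfig(enable_tensor_parallelism=True)
shardformer = ShardFormer(shard_config=shard_config)
# An explicitly passed policy overrides the one auto-detected from the model class.
sharded_model, shared_params = shardformer.optimize(model, policy=MyLlamaV2Policy())
```
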
@@ -40,14 +40,20 @@ class LlamaPolicy(Policy):
             self.shard_config.enable_sequence_parallelism = False
             warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
         if self.shard_config.enable_tensor_parallelism:
-            policy[LlamaDecoderLayer] = ModulePolicyDescription(
-                attribute_replacement={
+            decoder_attribute_replacement = {
                 "self_attn.hidden_size":
                     self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "self_attn.num_heads":
                     self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
-                },
+            }
+            if getattr(self.model.config, "num_key_value_heads", False):
+                decoder_attribute_replacement["self_attn.num_key_value_heads"] = \
+                    self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
+            policy[LlamaDecoderLayer] = ModulePolicyDescription(
+                attribute_replacement=decoder_attribute_replacement,
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
                         suffix="self_attn.q_proj",
...
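
Why the new `getattr` branch matters: llamav2-style configs add `num_key_value_heads` for grouped-query attention, and it is smaller than `num_attention_heads`. The old policy rescaled only `hidden_size` and `num_heads`, so after tensor parallelism split `k_proj`/`v_proj` across ranks, the module attribute still claimed the full KV head count. A worked example of the arithmetic in the diff, using illustrative values matching Llama-2-70B's published config:

```python
# Illustrative numbers from the Llama-2-70B config (GQA: 8 KV heads).
hidden_size = 8192
num_attention_heads = 64
num_key_value_heads = 8
tensor_parallel_size = 4

# Each tensor-parallel rank keeps 1/tp of the heads and hidden size:
assert hidden_size // tensor_parallel_size == 2048
assert num_attention_heads // tensor_parallel_size == 16
# The new branch rescales the KV head count too; without it, every rank
# would still report 8 KV heads while holding only a quarter of the
# k_proj/v_proj weights, breaking the attention reshape.
assert num_key_value_heads // tensor_parallel_size == 2

# Configs from older transformers releases lack num_key_value_heads
# entirely, hence the getattr(config, "num_key_value_heads", False) guard.
```
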