Commit c59d7aca authored by Kun Lin, committed by Hongxin Liu
Browse files

Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] finish ViT tests and add ViT support

* fix attention dropout
parent 0ceec8f9
from functools import partial
from typing import Callable, Dict, List, Union from typing import Callable, Dict, List, Union
import torch.nn as nn import torch.nn as nn
...@@ -36,7 +35,7 @@ class ViTPolicy(Policy): ...@@ -36,7 +35,7 @@ class ViTPolicy(Policy):
suffix="dropout", suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput, target_module=col_nn.DropoutForReplicatedInput,
) )
]) ])
policy[ViTLayer] = ModulePolicyDescription(attribute_replacement={ policy[ViTLayer] = ModulePolicyDescription(attribute_replacement={
"attention.attention.num_attention_heads": "attention.attention.num_attention_heads":
...@@ -44,45 +43,47 @@ class ViTPolicy(Policy): ...@@ -44,45 +43,47 @@ class ViTPolicy(Policy):
"attention.attention.all_head_size": "attention.attention.all_head_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size, self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
}, },
param_replacement=[], param_replacement=[],
sub_module_replacement=[ sub_module_replacement=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.attention.query", suffix="attention.attention.query",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.attention.key", suffix="attention.attention.key",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.attention.value", suffix="attention.attention.value",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.attention.dropout", suffix="attention.attention.dropout",
target_module=col_nn.DropoutForParallelInput, target_module=col_nn.DropoutForParallelInput,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.output.dense", suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row, target_module=col_nn.Linear1D_Row,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.output.dropout", suffix="attention.output.dropout",
target_module=col_nn.DropoutForReplicatedInput, target_module=col_nn.DropoutForReplicatedInput,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="intermediate.dense", suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="output.dense", suffix="output.dense",
target_module=col_nn.Linear1D_Row, target_module=col_nn.Linear1D_Row,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="output.dropout", suffix="output.dropout",
target_module=col_nn.DropoutForReplicatedInput, target_module=col_nn.DropoutForReplicatedInput,
), ),
]) ])
return policy
return policy return policy
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment