"docs/source/vscode:/vscode.git/clone" did not exist on "cd4f02bed8f3dccd22ab49d67ba96a5147a48bc0"
Unverified Commit 20190b49 authored by flybird11111's avatar flybird11111 Committed by GitHub
Browse files

[shardformer] to fix whisper test failed due to significant accuracy differences. (#4710)

* [shardformer] fix whisper test failed

* [shardformer] fix whisper test failed

* [shardformer] fix whisper test failed

* [shardformer] fix whisper test failed
parent e2c0e7f9
...@@ -114,7 +114,7 @@ We will follow this roadmap to develop Shardformer: ...@@ -114,7 +114,7 @@ We will follow this roadmap to develop Shardformer:
| bloom | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | | bloom | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
| chatglm2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | | chatglm2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
| vit | [x] | [x] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | | vit | [x] | [x] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| whisper | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] | | whisper | [x] | [x] | [x] | [x] | [x] | [ ] | [x] | [ ] | [ ] |
| sam | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | | sam | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| blip2 | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | | blip2 | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | | roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
......
...@@ -57,6 +57,11 @@ class WhisperPolicy(Policy): ...@@ -57,6 +57,11 @@ class WhisperPolicy(Policy):
warnings.warn( warnings.warn(
                "Whisper doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") "Whisper doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
        # TODO: using the JIT fused add_and_dropout affects the accuracy
if self.shard_config.enable_jit_fused:
self.shard_config.enable_jit_fused = False
            warnings.warn("Whisper doesn't support jit fused operator now, will ignore the jit fused operator flag.")
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={ policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={
"self_attn.embed_dim": "self_attn.embed_dim":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment