"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "d923f76203e742c4497eb061968a587c28062e74"
Unverified Commit 8b9ae455 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Set scale_embedding to False in some TF tests (#15952)



* set scale_embedding to False to avoid large (> 1e-5) output differences between PT/TF
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 38cc3506
...@@ -90,6 +90,7 @@ class TFSpeech2TextModelTester: ...@@ -90,6 +90,7 @@ class TFSpeech2TextModelTester:
eos_token_id=2, eos_token_id=2,
pad_token_id=1, pad_token_id=1,
bos_token_id=0, bos_token_id=0,
scale_embedding=False,
): ):
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
...@@ -115,6 +116,7 @@ class TFSpeech2TextModelTester: ...@@ -115,6 +116,7 @@ class TFSpeech2TextModelTester:
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id self.bos_token_id = bos_token_id
self.scale_embedding = scale_embedding
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
input_features = floats_tensor( input_features = floats_tensor(
...@@ -155,6 +157,7 @@ class TFSpeech2TextModelTester: ...@@ -155,6 +157,7 @@ class TFSpeech2TextModelTester:
eos_token_id=self.eos_token_id, eos_token_id=self.eos_token_id,
bos_token_id=self.bos_token_id, bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id, pad_token_id=self.pad_token_id,
scale_embedding=self.scale_embedding,
) )
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment