Commit 0a2bea47 authored by Sylvain Gugger's avatar Sylvain Gugger
Browse files

Fix repo consistency

parent 0645b07d
...@@ -349,8 +349,9 @@ class BigBirdPegasusOnnxConfig(OnnxSeq2SeqConfigWithPast): ...@@ -349,8 +349,9 @@ class BigBirdPegasusOnnxConfig(OnnxSeq2SeqConfigWithPast):
self._config.hidden_size // num_encoder_attention_heads, self._config.hidden_size // num_encoder_attention_heads,
) )
mask_dtype = common_inputs["attention_mask"].dtype
common_inputs["attention_mask"] = torch.cat( common_inputs["attention_mask"] = torch.cat(
[common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1 [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
) )
common_inputs["past_key_values"] = [ common_inputs["past_key_values"] = [
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers) (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment