Unverified Commit 53c710d1 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix failing torchscript tests for `CpmAnt` model (#22766)



* fix

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent d2ffc3fc
...@@ -69,8 +69,6 @@ class CpmAntConfig(PretrainedConfig): ...@@ -69,8 +69,6 @@ class CpmAntConfig(PretrainedConfig):
Whether to use cache. Whether to use cache.
init_std (`float`, *optional*, defaults to 1.0): init_std (`float`, *optional*, defaults to 1.0):
Initialize parameters with std = init_std. Initialize parameters with std = init_std.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
Example: Example:
...@@ -105,7 +103,6 @@ class CpmAntConfig(PretrainedConfig): ...@@ -105,7 +103,6 @@ class CpmAntConfig(PretrainedConfig):
prompt_length: int = 32, prompt_length: int = 32,
segment_types: int = 32, segment_types: int = 32,
use_cache: bool = True, use_cache: bool = True,
return_dict: bool = True,
**kwargs, **kwargs,
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
...@@ -123,5 +120,4 @@ class CpmAntConfig(PretrainedConfig): ...@@ -123,5 +120,4 @@ class CpmAntConfig(PretrainedConfig):
self.eps = eps self.eps = eps
self.use_cache = use_cache self.use_cache = use_cache
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.return_dict = return_dict
self.init_std = init_std self.init_std = init_std
...@@ -378,7 +378,7 @@ class CpmAntEncoder(nn.Module): ...@@ -378,7 +378,7 @@ class CpmAntEncoder(nn.Module):
""" """
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None all_self_attns = () if output_attentions else None
current_key_values = [] if use_cache else None current_key_values = () if use_cache else None
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
if output_hidden_states: if output_hidden_states:
...@@ -395,7 +395,7 @@ class CpmAntEncoder(nn.Module): ...@@ -395,7 +395,7 @@ class CpmAntEncoder(nn.Module):
if output_attentions: if output_attentions:
all_self_attns += (attn_weights,) all_self_attns += (attn_weights,)
if current_key_value is not None: if current_key_value is not None:
current_key_values.append(current_key_value) current_key_values = current_key_values + (current_key_value,)
hidden_states = self.output_layernorm(hidden_states) hidden_states = self.output_layernorm(hidden_states)
...@@ -659,7 +659,7 @@ class CpmAntModel(CpmAntPreTrainedModel): ...@@ -659,7 +659,7 @@ class CpmAntModel(CpmAntPreTrainedModel):
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
) )
return_dict = return_dict if return_dict is not None else self.config.return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = use_cache if use_cache is not None else self.config.use_cache use_cache = use_cache if use_cache is not None else self.config.use_cache
# add prompts ahead # add prompts ahead
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment