"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "cac877425c142f4ae7b99f1601ab49f7c29ed56f"
Unverified Commit 1562c04e authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

FlaxBartPretrainedModel -> FlaxBartPreTrainedModel (#12313)

parent ebe54135
...@@ -911,7 +911,7 @@ class FlaxBartModule(nn.Module): ...@@ -911,7 +911,7 @@ class FlaxBartModule(nn.Module):
) )
class FlaxBartPretrainedModel(FlaxPreTrainedModel): class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
config_class = BartConfig config_class = BartConfig
base_model_prefix: str = "model" base_model_prefix: str = "model"
module_class: nn.Module = None module_class: nn.Module = None
...@@ -1232,7 +1232,7 @@ class FlaxBartPretrainedModel(FlaxPreTrainedModel): ...@@ -1232,7 +1232,7 @@ class FlaxBartPretrainedModel(FlaxPreTrainedModel):
"The bare Bart Model transformer outputting raw hidden-states without any specific head on top.", "The bare Bart Model transformer outputting raw hidden-states without any specific head on top.",
BART_START_DOCSTRING, BART_START_DOCSTRING,
) )
class FlaxBartModel(FlaxBartPretrainedModel): class FlaxBartModel(FlaxBartPreTrainedModel):
config: BartConfig config: BartConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation dtype: jnp.dtype = jnp.float32 # the dtype of the computation
module_class = FlaxBartModule module_class = FlaxBartModule
...@@ -1318,7 +1318,7 @@ class FlaxBartForConditionalGenerationModule(nn.Module): ...@@ -1318,7 +1318,7 @@ class FlaxBartForConditionalGenerationModule(nn.Module):
@add_start_docstrings( @add_start_docstrings(
"The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
) )
class FlaxBartForConditionalGeneration(FlaxBartPretrainedModel): class FlaxBartForConditionalGeneration(FlaxBartPreTrainedModel):
module_class = FlaxBartForConditionalGenerationModule module_class = FlaxBartForConditionalGenerationModule
dtype: jnp.dtype = jnp.float32 dtype: jnp.dtype = jnp.float32
...@@ -1623,7 +1623,7 @@ class FlaxBartForSequenceClassificationModule(nn.Module): ...@@ -1623,7 +1623,7 @@ class FlaxBartForSequenceClassificationModule(nn.Module):
""", """,
BART_START_DOCSTRING, BART_START_DOCSTRING,
) )
class FlaxBartForSequenceClassification(FlaxBartPretrainedModel): class FlaxBartForSequenceClassification(FlaxBartPreTrainedModel):
module_class = FlaxBartForSequenceClassificationModule module_class = FlaxBartForSequenceClassificationModule
dtype = jnp.float32 dtype = jnp.float32
...@@ -1710,7 +1710,7 @@ class FlaxBartForQuestionAnsweringModule(nn.Module): ...@@ -1710,7 +1710,7 @@ class FlaxBartForQuestionAnsweringModule(nn.Module):
""", """,
BART_START_DOCSTRING, BART_START_DOCSTRING,
) )
class FlaxBartForQuestionAnswering(FlaxBartPretrainedModel): class FlaxBartForQuestionAnswering(FlaxBartPreTrainedModel):
module_class = FlaxBartForQuestionAnsweringModule module_class = FlaxBartForQuestionAnsweringModule
dtype = jnp.float32 dtype = jnp.float32
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment