Unverified Commit a4c91be7 authored by superhero-7, committed by GitHub

Modified altdiffusion pipeline to support altdiffusion-m18 (#2993)



* Modified altdiffusion pipeline to support altdiffusion-m18

---------
Co-authored-by: root <fulong_ye@163.com>
parent 3becd368
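Before the diff itself, a hedged sketch of how the m18 variant would be loaded through the existing AltDiffusionPipeline once this text-encoder change is in place. The checkpoint id "BAAI/AltDiffusion-m18", the GPU device, and the output filename are illustrative assumptions, not part of this commit.

```python
import torch
from diffusers import AltDiffusionPipeline

# Checkpoint id assumed to be "BAAI/AltDiffusion-m18"; adjust if the published repo differs.
pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m18", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# m18 targets 18 languages; the prompt is Chinese for "a photo of an astronaut riding a horse".
image = pipe("一张宇航员骑马的照片").images[0]
image.save("astronaut_rides_horse.png")
```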
@@ -56,7 +56,7 @@ class RobertaSeriesConfig(XLMRobertaConfig):
 class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler", r"logit_scale"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
     base_model_prefix = "roberta"
     config_class = RobertaSeriesConfig
@@ -65,6 +65,10 @@ class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
         super().__init__(config)
         self.roberta = XLMRobertaModel(config)
         self.transformation = nn.Linear(config.hidden_size, config.project_dim)
+        self.has_pre_transformation = getattr(config, "has_pre_transformation", False)
+        if self.has_pre_transformation:
+            self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
+            self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.post_init()

     def forward(
@@ -95,15 +99,26 @@ class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
             output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
+            output_hidden_states=True if self.has_pre_transformation else output_hidden_states,
             return_dict=return_dict,
         )

-        projection_state = self.transformation(outputs.last_hidden_state)
-
-        return TransformationModelOutput(
-            projection_state=projection_state,
-            last_hidden_state=outputs.last_hidden_state,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
+        if self.has_pre_transformation:
+            sequence_output2 = outputs["hidden_states"][-2]
+            sequence_output2 = self.pre_LN(sequence_output2)
+            projection_state2 = self.transformation_pre(sequence_output2)
+
+            return TransformationModelOutput(
+                projection_state=projection_state2,
+                last_hidden_state=outputs.last_hidden_state,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
+        else:
+            projection_state = self.transformation(outputs.last_hidden_state)
+            return TransformationModelOutput(
+                projection_state=projection_state,
+                last_hidden_state=outputs.last_hidden_state,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
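A minimal sketch that exercises the new has_pre_transformation branch directly. It assumes RobertaSeriesConfig and RobertaSeriesModelWithTransformation are importable from diffusers.pipelines.alt_diffusion.modeling_roberta_series (the module touched by this diff; the path may differ across diffusers releases), and the tiny config values are made up just to run a forward pass.

```python
import torch
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
    RobertaSeriesConfig,
    RobertaSeriesModelWithTransformation,
)

# Tiny, randomly initialized config purely to exercise the code path;
# real AltDiffusion-m18 checkpoints ship their own config values.
config = RobertaSeriesConfig(
    vocab_size=1000,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    project_dim=16,
)
config.has_pre_transformation = True  # enables the pre_LN + transformation_pre path

model = RobertaSeriesModelWithTransformation(config)
input_ids = torch.randint(5, config.vocab_size, (1, 8))
out = model(input_ids=input_ids)

# With has_pre_transformation=True, projection_state is the layer-normalized
# second-to-last hidden state projected by transformation_pre.
print(out.projection_state.shape)  # torch.Size([1, 8, 16])
```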