"docs/source/en/model_doc/data2vec.md" did not exist on "9d99489f2f79b81fa9131c9299c236006dff94fb"
Unverified commit 3224c0c1, authored by Yih-Dar, committed by GitHub

Remove some Kosmos-2 `copied from` (#27149)



* fix

* fix

* fix

* fix

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent cd19b193
@@ -52,7 +52,6 @@ KOSMOS2_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]
 
 
-# Copied from transformers.models.bart.modeling_bart._expand_mask
 def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
     """
     Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
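For context, `_expand_mask` mirrors the Bart helper that the removed `# Copied from` comment pointed to. A minimal sketch of that helper, following the standard Bart implementation of this era:

```python
from typing import Optional

import torch


def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """Expand a [bsz, seq_len] padding mask to an additive [bsz, 1, tgt_len, src_len] mask."""
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    # Broadcast the per-token keep/ignore mask across all query positions.
    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    # Invert so attended positions become 0 and masked positions become the
    # most negative representable value (added to the attention logits).
    inverted_mask = 1.0 - expanded_mask
    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
```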
@@ -67,7 +66,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
     return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
 
 
-# Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(
     input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
 ):
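`_make_causal_mask` likewise follows the Bart pattern. A sketch of the usual body, for reference:

```python
import torch


def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """Build an additive causal mask of shape [bsz, 1, tgt_len, tgt_len + past_len]."""
    bsz, tgt_len = input_ids_shape
    # Start fully masked, then zero out the lower triangle so each position
    # can attend to itself and everything before it.
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        # Cached past positions are always visible, so prepend zero columns.
        mask = torch.cat(
            [torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1
        )
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
```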
@@ -660,7 +658,7 @@ class Kosmos2VisionEncoder(nn.Module):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     encoder_layer.__call__,
                     hidden_states,
                     attention_mask,
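The `gradient_checkpointing_func` → `_gradient_checkpointing_func` rename follows the refactored checkpointing plumbing in the base `PreTrainedModel`, where enabling checkpointing binds a private callable on each supporting submodule. A self-contained toy sketch of that calling convention (the `TinyEncoder` class is hypothetical, for illustration only):

```python
from functools import partial

import torch
import torch.nn as nn
import torch.utils.checkpoint


class TinyEncoder(nn.Module):
    """Toy module illustrating the `_gradient_checkpointing_func` convention."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 8)
        self.gradient_checkpointing = False

    def forward(self, hidden_states):
        if self.gradient_checkpointing and self.training:
            # Recompute the layer's activations during backward instead of storing them.
            return self._gradient_checkpointing_func(self.layer.__call__, hidden_states)
        return self.layer(hidden_states)


encoder = TinyEncoder()
encoder.gradient_checkpointing = True
# Roughly what gradient_checkpointing_enable() arranges behind the scenes.
encoder._gradient_checkpointing_func = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
encoder.train()

out = encoder(torch.randn(2, 8, requires_grad=True))
out.sum().backward()
```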
@@ -1114,7 +1112,6 @@ class Kosmos2TextTransformer(nn.Module):
         self.gradient_checkpointing = False
 
-    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
     def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
         # create causal mask
         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
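`_prepare_decoder_attention_mask` combines the two helpers above: a causal mask for autoregressive decoding plus the expanded padding mask. A sketch of the Bart-style body, assuming the two module-level helpers shown earlier:

```python
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
    # A causal mask is only needed when the query length exceeds 1, i.e. not
    # for single-token decoding steps that reuse a key/value cache.
    combined_attention_mask = None
    if input_shape[-1] > 1:
        combined_attention_mask = _make_causal_mask(
            input_shape,
            inputs_embeds.dtype,
            device=inputs_embeds.device,
            past_key_values_length=past_key_values_length,
        )

    if attention_mask is not None:
        # Merge the padding mask into the causal mask by addition; both are
        # additive masks filled with 0 / dtype-min.
        expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
        combined_attention_mask = (
            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
        )

    return combined_attention_mask
```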
@@ -1268,7 +1265,7 @@ class Kosmos2TextTransformer(nn.Module):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
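From the user side nothing changes: checkpointing is still toggled through the public API. A usage sketch, assuming the standard `microsoft/kosmos-2-patch14-224` checkpoint and the `gradient_checkpointing_kwargs` argument available in transformers releases of this era:

```python
from transformers import Kosmos2ForConditionalGeneration

model = Kosmos2ForConditionalGeneration.from_pretrained("microsoft/kosmos-2-patch14-224")

# Enable activation recomputation in both the vision encoder and the text
# decoder; the kwargs are forwarded to torch.utils.checkpoint.checkpoint.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.train()
```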
@@ -1428,11 +1425,6 @@ class Kosmos2PreTrainedModel(PreTrainedModel):
             if module.embed_tokens.padding_idx is not None:
                 module.embed_tokens.weight.data[module.embed_tokens.padding_idx].zero_()
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, (Kosmos2TextTransformer, Kosmos2VisionEncoder)):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 class Kosmos2VisionModel(Kosmos2PreTrainedModel):
     config_class = Kosmos2VisionConfig
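Dropping the model-specific `_set_gradient_checkpointing` override works because the base-class version now discovers checkpointing-capable submodules generically instead of each model listing them by hand. A rough sketch of that base behaviour; this is an approximation of the refactored `PreTrainedModel` logic, not the verbatim source:

```python
from functools import partial

import torch.utils.checkpoint


def _set_gradient_checkpointing(self, enable=True, gradient_checkpointing_func=None):
    if gradient_checkpointing_func is None:
        gradient_checkpointing_func = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
    # Walk the module tree and flip every module that declares a
    # `gradient_checkpointing` attribute, binding the shared callable.
    for module in self.modules():
        if hasattr(module, "gradient_checkpointing"):
            module._gradient_checkpointing_func = gradient_checkpointing_func
            module.gradient_checkpointing = enable
```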