Unverified commit eeaa9c01 authored by Yih-Dar, committed by GitHub

Make CLIP models able to use newly added tokens with meaningful pooling (#24777)



* fix

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent d0154015
@@ -701,6 +701,9 @@ class CLIPTextTransformer(nn.Module):
         self.encoder = CLIPEncoder(config)
         self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
 
+        # For `pooled_output` computation
+        self.eos_token_id = config.eos_token_id
+
     @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
     def forward(
@@ -750,13 +753,26 @@ class CLIPTextTransformer(nn.Module):
         last_hidden_state = encoder_outputs[0]
         last_hidden_state = self.final_layer_norm(last_hidden_state)
 
-        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
-        # take features from the eot embedding (eot_token is the highest number in each sequence)
-        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
-        pooled_output = last_hidden_state[
-            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
-            input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
-        ]
+        if self.eos_token_id == 2:
+            # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
+            # A CLIP model with such an `eos_token_id` in the config can't work correctly with extra new tokens added.
+            # ------------------------------------------------------------
+            # text_embeds.shape = [batch_size, sequence_length, transformer.width]
+            # take features from the eot embedding (eot_token is the highest number in each sequence)
+            # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+            pooled_output = last_hidden_state[
+                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+                input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
+            ]
+        else:
+            # The config gets the updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible).
+            pooled_output = last_hidden_state[
+                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+                # We need the first position of the `eos_token_id` value (`pad_token_id` might equal `eos_token_id`).
+                (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
+                .int()
+                .argmax(dim=-1),
+            ]
 
         if not return_dict:
             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
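The heart of the change is how the EOS position is located. Below is a minimal sketch, not part of this commit, contrasting the two indexing strategies; the token ids are illustrative and assume a CLIP-style vocabulary where `<|endoftext|>` has id 49407 and a newly added token receives a higher id:

import torch

eos_token_id = 49407  # illustrative CLIP-style EOS id
# <bos> a <new_token> <eos> <pad (= eos id)>; the added token id 49408 is larger than the EOS id
input_ids = torch.tensor([[49406, 320, 49408, 49407, 49407]])

# Old behaviour: assumes EOS has the largest id in the sequence, so the new token wins the argmax
old_position = input_ids.to(torch.int).argmax(dim=-1)  # tensor([2]) -> pools the wrong hidden state

# New behaviour: take the first position whose id equals `eos_token_id`
new_position = (input_ids == eos_token_id).int().argmax(dim=-1)  # tensor([3]) -> the real EOS

print(old_position, new_position)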
@@ -487,6 +487,9 @@ class FlaxCLIPTextTransformer(nn.Module):
         self.encoder = FlaxCLIPEncoder(self.config, dtype=self.dtype)
         self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
 
+        # For `pooled_output` computation
+        self.eos_token_id = self.config.eos_token_id
+
     def __call__(
         self,
         input_ids,
@@ -517,9 +520,18 @@ class FlaxCLIPTextTransformer(nn.Module):
         last_hidden_state = encoder_outputs[0]
         last_hidden_state = self.final_layer_norm(last_hidden_state)
 
-        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
-        # take features from the EOS embedding (eos_token_id is the highest number in each sequence)
-        pooled_output = last_hidden_state[jnp.arange(last_hidden_state.shape[0]), input_ids.argmax(axis=-1)]
+        if self.eos_token_id == 2:
+            # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
+            # A CLIP model with such an `eos_token_id` in the config can't work correctly with extra new tokens added.
+            # ------------------------------------------------------------
+            # text_embeds.shape = [batch_size, sequence_length, transformer.width]
+            # take features from the EOS embedding (eos_token_id is the highest number in each sequence)
+            pooled_output = last_hidden_state[jnp.arange(last_hidden_state.shape[0]), input_ids.argmax(axis=-1)]
+        else:
+            # (no need to cast from bool to int after comparing to `eos_token_id`)
+            pooled_output = last_hidden_state[
+                jnp.arange(last_hidden_state.shape[0]), (input_ids == self.eos_token_id).argmax(axis=-1)
+            ]
 
         if not return_dict:
             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
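A small sketch, not part of the commit, of the NumPy-style behaviour the Flax branch relies on: `jnp.argmax` accepts a boolean array directly, so no explicit int cast is needed before locating the first EOS position (token ids are illustrative):

import jax.numpy as jnp

eos_token_id = 49407  # illustrative CLIP-style EOS id
input_ids = jnp.array([[49406, 320, 49408, 49407, 49407]])

# argmax over the boolean mask returns the first True position, i.e. the first EOS
first_eos = (input_ids == eos_token_id).argmax(axis=-1)  # Array([3])
print(first_eos)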
@@ -494,6 +494,9 @@ class TFCLIPTextTransformer(tf.keras.layers.Layer):
             epsilon=config.layer_norm_eps, name="final_layer_norm"
         )
 
+        # For `pooled_output` computation
+        self.eos_token_id = config.eos_token_id
+
     def call(
         self,
         input_ids: TFModelInputType,
@@ -530,14 +533,30 @@ class TFCLIPTextTransformer(tf.keras.layers.Layer):
         sequence_output = encoder_outputs[0]
         sequence_output = self.final_layer_norm(inputs=sequence_output)
 
-        # text_embeds.shape = [batch_size, n_ctx, transformer.width]
-        # take features from the eot embedding (eot_token is the highest number in each sequence)
-        pooled_output = tf.gather_nd(
-            params=sequence_output,
-            indices=tf.stack(
-                values=(tf.range(input_shape[0], dtype=tf.int64), tf.math.argmax(input_ids, axis=-1)), axis=1
-            ),
-        )
+        if self.eos_token_id == 2:
+            # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
+            # A CLIP model with such an `eos_token_id` in the config can't work correctly with extra new tokens added.
+            # ------------------------------------------------------------
+            # text_embeds.shape = [batch_size, n_ctx, transformer.width]
+            # take features from the eot embedding (eot_token is the highest number in each sequence)
+            pooled_output = tf.gather_nd(
+                params=sequence_output,
+                indices=tf.stack(
+                    values=(tf.range(input_shape[0], dtype=tf.int64), tf.math.argmax(input_ids, axis=-1)), axis=1
+                ),
+            )
+        else:
+            # The config gets the updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible).
+            pooled_output = tf.gather_nd(
+                params=sequence_output,
+                indices=tf.stack(
+                    values=(
+                        tf.range(input_shape[0], dtype=tf.int64),
+                        tf.math.argmax(tf.cast(input_ids == self.eos_token_id, dtype=tf.int8), axis=-1),
+                    ),
+                    axis=1,
+                ),
+            )
 
         if not return_dict:
             return (sequence_output, pooled_output) + encoder_outputs[1:]
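For reference, a self-contained sketch, not part of the commit, of the `tf.gather_nd` indexing used in the new branch: build one `(batch_index, eos_position)` pair per sequence and gather the corresponding hidden state. Token ids and the hidden size are illustrative:

import tensorflow as tf

eos_token_id = 49407  # illustrative CLIP-style EOS id
input_ids = tf.constant([[49406, 320, 49408, 49407, 49407]], dtype=tf.int64)
sequence_output = tf.random.normal((1, 5, 4))  # [batch_size, seq_len, width]

# first EOS position per sequence (boolean mask cast to int8 before argmax)
eos_positions = tf.math.argmax(tf.cast(input_ids == eos_token_id, dtype=tf.int8), axis=-1)

batch_size = input_ids.shape[0]  # static in this toy example
indices = tf.stack((tf.range(batch_size, dtype=tf.int64), eos_positions), axis=1)  # shape [batch_size, 2]
pooled_output = tf.gather_nd(params=sequence_output, indices=indices)  # shape [batch_size, width]
print(pooled_output.shape)  # (1, 4)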
@@ -97,8 +97,8 @@ class CLIPSegTextConfig(PretrainedConfig):
         initializer_range=0.02,
         initializer_factor=1.0,
         pad_token_id=1,
-        bos_token_id=0,
-        eos_token_id=2,
+        bos_token_id=49406,
+        eos_token_id=49407,
         **kwargs,
     ):
         super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
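The new defaults match the ids the CLIP BPE tokenizer actually assigns to `<|startoftext|>` and `<|endoftext|>`; the old values 0 and 2 did not correspond to the tokenizer's special tokens. A quick check, assuming the `openai/clip-vit-base-patch32` tokenizer can be loaded:

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
print(tokenizer.bos_token_id, tokenizer.eos_token_id)  # expected: 49406 49407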
@@ -712,6 +712,9 @@ class CLIPSegTextTransformer(nn.Module):
         self.encoder = CLIPSegEncoder(config)
         self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
 
+        # For `pooled_output` computation
+        self.eos_token_id = config.eos_token_id
+
     @add_start_docstrings_to_model_forward(CLIPSEG_TEXT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPSegTextConfig)
     # Copied from transformers.models.clip.modeling_clip.CLIPTextTransformer.forward with clip->clipseg, CLIP->CLIPSeg
@@ -762,13 +765,26 @@ class CLIPSegTextTransformer(nn.Module):
         last_hidden_state = encoder_outputs[0]
         last_hidden_state = self.final_layer_norm(last_hidden_state)
 
-        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
-        # take features from the eot embedding (eot_token is the highest number in each sequence)
-        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
-        pooled_output = last_hidden_state[
-            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
-            input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
-        ]
+        if self.eos_token_id == 2:
+            # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
+            # A CLIPSeg model with such an `eos_token_id` in the config can't work correctly with extra new tokens added.
+            # ------------------------------------------------------------
+            # text_embeds.shape = [batch_size, sequence_length, transformer.width]
+            # take features from the eot embedding (eot_token is the highest number in each sequence)
+            # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+            pooled_output = last_hidden_state[
+                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+                input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
+            ]
+        else:
+            # The config gets the updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible).
+            pooled_output = last_hidden_state[
+                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+                # We need the first position of the `eos_token_id` value (`pad_token_id` might equal `eos_token_id`).
+                (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
+                .int()
+                .argmax(dim=-1),
+            ]
 
         if not return_dict:
             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
@@ -106,8 +106,8 @@ class GroupViTTextConfig(PretrainedConfig):
         initializer_range=0.02,
         initializer_factor=1.0,
         pad_token_id=1,
-        bos_token_id=0,
-        eos_token_id=2,
+        bos_token_id=49406,
+        eos_token_id=49407,
         **kwargs,
     ):
         super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -1095,6 +1095,9 @@ class GroupViTTextTransformer(nn.Module):
         self.encoder = GroupViTTextEncoder(config)
         self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
 
+        # For `pooled_output` computation
+        self.eos_token_id = config.eos_token_id
+
     @add_start_docstrings_to_model_forward(GROUPVIT_TEXT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=GroupViTTextConfig)
     def forward(
@@ -1144,13 +1147,26 @@ class GroupViTTextTransformer(nn.Module):
         last_hidden_state = encoder_outputs[0]
         last_hidden_state = self.final_layer_norm(last_hidden_state)
 
-        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
-        # take features from the eot embedding (eot_token is the highest number in each sequence)
-        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
-        pooled_output = last_hidden_state[
-            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
-            input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
-        ]
+        if self.eos_token_id == 2:
+            # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
+            # A CLIP model with such an `eos_token_id` in the config can't work correctly with extra new tokens added.
+            # ------------------------------------------------------------
+            # text_embeds.shape = [batch_size, sequence_length, transformer.width]
+            # take features from the eot embedding (eot_token is the highest number in each sequence)
+            # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+            pooled_output = last_hidden_state[
+                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+                input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
+            ]
+        else:
+            # The config gets the updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible).
+            pooled_output = last_hidden_state[
+                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+                # We need the first position of the `eos_token_id` value (`pad_token_id` might equal `eos_token_id`).
+                (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
+                .int()
+                .argmax(dim=-1),
+            ]
 
         if not return_dict:
             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
@@ -1002,6 +1002,9 @@ class TFGroupViTTextTransformer(tf.keras.layers.Layer):
             epsilon=config.layer_norm_eps, name="final_layer_norm"
         )
 
+        # For `pooled_output` computation
+        self.eos_token_id = config.eos_token_id
+
     def call(
         self,
         input_ids: TFModelInputType,
@@ -1038,14 +1041,30 @@ class TFGroupViTTextTransformer(tf.keras.layers.Layer):
         sequence_output = encoder_outputs[0]
         sequence_output = self.final_layer_norm(inputs=sequence_output)
 
-        # text_embeds.shape = [batch_size, n_ctx, transformer.width]
-        # take features from the eot embedding (eot_token is the highest number in each sequence)
-        pooled_output = tf.gather_nd(
-            params=sequence_output,
-            indices=tf.stack(
-                values=(tf.range(input_shape[0], dtype=tf.int64), tf.math.argmax(input_ids, axis=-1)), axis=1
-            ),
-        )
+        if self.eos_token_id == 2:
+            # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
+            # A CLIP model with such an `eos_token_id` in the config can't work correctly with extra new tokens added.
+            # ------------------------------------------------------------
+            # text_embeds.shape = [batch_size, n_ctx, transformer.width]
+            # take features from the eot embedding (eot_token is the highest number in each sequence)
+            pooled_output = tf.gather_nd(
+                params=sequence_output,
+                indices=tf.stack(
+                    values=(tf.range(input_shape[0], dtype=tf.int64), tf.math.argmax(input_ids, axis=-1)), axis=1
+                ),
+            )
+        else:
+            # The config gets the updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible).
+            pooled_output = tf.gather_nd(
+                params=sequence_output,
+                indices=tf.stack(
+                    values=(
+                        tf.range(input_shape[0], dtype=tf.int64),
+                        tf.math.argmax(tf.cast(input_ids == self.eos_token_id, dtype=tf.int8), axis=-1),
+                    ),
+                    axis=1,
+                ),
+            )
 
         if not return_dict:
             return (sequence_output, pooled_output) + encoder_outputs[1:]