Unverified Commit f5cd2769 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[FlaxCLIP] allow passing params to image and text feature methods (#13099)

* allow passing params to image and text feature method

* ifx for hybrid clip as well
parent 9a498c37
...@@ -208,6 +208,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel): ...@@ -208,6 +208,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
attention_mask=None, attention_mask=None,
position_ids=None, position_ids=None,
token_type_ids=None, token_type_ids=None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None, dropout_rng: jax.random.PRNGKey = None,
train=False, train=False,
): ):
...@@ -254,7 +255,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel): ...@@ -254,7 +255,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
return text_features return text_features
return self.module.apply( return self.module.apply(
{"params": self.params}, {"params": params or self.params},
jnp.array(input_ids, dtype="i4"), jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"), jnp.array(attention_mask, dtype="i4"),
jnp.array(position_ids, dtype="i4"), jnp.array(position_ids, dtype="i4"),
...@@ -264,7 +265,9 @@ class FlaxHybridCLIP(FlaxPreTrainedModel): ...@@ -264,7 +265,9 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
rngs=rngs, rngs=rngs,
) )
def get_image_features(self, pixel_values, dropout_rng: jax.random.PRNGKey = None, train=False): def get_image_features(
self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train=False
):
r""" r"""
Args: Args:
pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`): pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
...@@ -289,7 +292,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel): ...@@ -289,7 +292,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
return image_features return image_features
return self.module.apply( return self.module.apply(
{"params": self.params}, {"params": params or self.params},
jnp.array(pixel_values, dtype=jnp.float32), jnp.array(pixel_values, dtype=jnp.float32),
not train, not train,
method=_get_features, method=_get_features,
......
...@@ -785,7 +785,13 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel): ...@@ -785,7 +785,13 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
) )
def get_text_features( def get_text_features(
self, input_ids, attention_mask=None, position_ids=None, dropout_rng: jax.random.PRNGKey = None, train=False self,
input_ids,
attention_mask=None,
position_ids=None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train=False,
): ):
r""" r"""
Args: Args:
...@@ -836,7 +842,7 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel): ...@@ -836,7 +842,7 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
return text_features return text_features
return self.module.apply( return self.module.apply(
{"params": self.params}, {"params": params or self.params},
jnp.array(input_ids, dtype="i4"), jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"), jnp.array(attention_mask, dtype="i4"),
jnp.array(position_ids, dtype="i4"), jnp.array(position_ids, dtype="i4"),
...@@ -845,7 +851,9 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel): ...@@ -845,7 +851,9 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
rngs=rngs, rngs=rngs,
) )
def get_image_features(self, pixel_values, dropout_rng: jax.random.PRNGKey = None, train=False): def get_image_features(
self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train=False
):
r""" r"""
Args: Args:
pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`): pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
...@@ -887,7 +895,7 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel): ...@@ -887,7 +895,7 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
return image_features return image_features
return self.module.apply( return self.module.apply(
{"params": self.params}, {"params": params or self.params},
jnp.array(pixel_values, dtype=jnp.float32), jnp.array(pixel_values, dtype=jnp.float32),
not train, not train,
method=_get_features, method=_get_features,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment