Unverified Commit f5cd2769 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[FlaxCLIP] allow passing params to image and text feature methods (#13099)

* allow passing params to image and text feature method

* ifx for hybrid clip as well
parent 9a498c37
......@@ -208,6 +208,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
attention_mask=None,
position_ids=None,
token_type_ids=None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train=False,
):
......@@ -254,7 +255,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
return text_features
return self.module.apply(
{"params": self.params},
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
......@@ -264,7 +265,9 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
rngs=rngs,
)
def get_image_features(self, pixel_values, dropout_rng: jax.random.PRNGKey = None, train=False):
def get_image_features(
self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train=False
):
r"""
Args:
pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
......@@ -289,7 +292,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
return image_features
return self.module.apply(
{"params": self.params},
{"params": params or self.params},
jnp.array(pixel_values, dtype=jnp.float32),
not train,
method=_get_features,
......
......@@ -785,7 +785,13 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
)
def get_text_features(
self, input_ids, attention_mask=None, position_ids=None, dropout_rng: jax.random.PRNGKey = None, train=False
self,
input_ids,
attention_mask=None,
position_ids=None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train=False,
):
r"""
Args:
......@@ -836,7 +842,7 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
return text_features
return self.module.apply(
{"params": self.params},
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
......@@ -845,7 +851,9 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
rngs=rngs,
)
def get_image_features(self, pixel_values, dropout_rng: jax.random.PRNGKey = None, train=False):
def get_image_features(
self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train=False
):
r"""
Args:
pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
......@@ -887,7 +895,7 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
return image_features
return self.module.apply(
{"params": self.params},
{"params": params or self.params},
jnp.array(pixel_values, dtype=jnp.float32),
not train,
method=_get_features,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment