Unverified commit c73e3532, authored by Patrick von Platen, committed by GitHub
Browse files

push (#11750)

parent 936b5715
...@@ -551,7 +551,7 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel): ...@@ -551,7 +551,7 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
# init input tensors # init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4") input_ids = jnp.zeros(input_shape, dtype="i4")
token_type_ids = jnp.ones_like(input_ids) token_type_ids = jnp.zeros_like(input_ids)
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
attention_mask = jnp.ones_like(input_ids) attention_mask = jnp.ones_like(input_ids)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment