Unverified Commit 21546e59 authored by Suraj Patil, committed by GitHub

fix docs (#14377)

parent ed5d1551
@@ -205,7 +205,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
     >>> from flax import traverse_util
     >>> model = FlaxBertModel.from_pretrained('bert-base-cased')
     >>> flat_params = traverse_util.flatten_dict(model.params)
-    >>> mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
+    >>> mask = {path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
     >>> mask = traverse_util.unflatten_dict(mask)
     >>> model.params = model.to_bf16(model.params, mask)
     """
@@ -255,10 +255,10 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
     Examples::

     >>> from transformers import FlaxBertModel
-    >>> # Download model and configuration from huggingface.co
+    >>> # load model
     >>> model = FlaxBertModel.from_pretrained('bert-base-cased')
     >>> # By default, the model params will be in fp32, to cast these to float16
-    >>> model.params = model.to_f16(model.params)
+    >>> model.params = model.to_fp16(model.params)
     >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale)
     >>> # then pass the mask as follows
     >>> from flax import traverse_util
@@ -266,7 +266,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
     >>> flat_params = traverse_util.flatten_dict(model.params)
     >>> mask = {path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
     >>> mask = traverse_util.unflatten_dict(mask)
-    >>> model.params = model.to_f16(model.params, mask)
+    >>> model.params = model.to_fp16(model.params, mask)
     """
     return self._cast_floating_to(params, jnp.float16, mask)
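The hunk above shows only the call site of _cast_floating_to; its body is not part of this commit. As a rough, hypothetical sketch of what such a helper could look like, inferred purely from the call signature shown here rather than from the actual transformers implementation: flatten the pytree, cast the floating-point leaves the mask selects, and rebuild the tree.

    import jax.numpy as jnp
    from flax import traverse_util

    def cast_floating_to(params, dtype, mask=None):
        # Hypothetical stand-in for FlaxPreTrainedModel._cast_floating_to;
        # the name and behaviour here are assumptions, not the library's code.
        flat_params = traverse_util.flatten_dict(params)
        flat_mask = traverse_util.flatten_dict(mask) if mask is not None else None
        for path, leaf in flat_params.items():
            # Cast only floating-point leaves; integer leaves (if any) stay as-is.
            if jnp.issubdtype(leaf.dtype, jnp.floating):
                if flat_mask is None or flat_mask[path]:
                    flat_params[path] = leaf.astype(dtype)
        return traverse_util.unflatten_dict(flat_params)

Under that assumption, model.to_fp16(model.params, mask) amounts to cast_floating_to(model.params, jnp.float16, mask), and to_bf16 is the same call with jnp.bfloat16.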