chenpangpang / transformers

Unverified commit 21546e59, authored Nov 12, 2021 by Suraj Patil and committed by GitHub on Nov 12, 2021
parent ed5d1551

fix docs (#14377)
Showing 1 changed file with 4 additions and 4 deletions

src/transformers/modeling_flax_utils.py  +4 -4
@@ -205,7 +205,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
             >>> from flax import traverse_util
             >>> model = FlaxBertModel.from_pretrained('bert-base-cased')
             >>> flat_params = traverse_util.flatten_dict(model.params)
-            >>> mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
+            >>> mask = {path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
             >>> mask = traverse_util.unflatten_dict(mask)
             >>> model.params = model.to_bf16(model.params, mask)
         """
@@ -255,10 +255,10 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         Examples::

             >>> from transformers import FlaxBertModel
-            >>> # Download model and configuration from huggingface.co
+            >>> # load model
             >>> model = FlaxBertModel.from_pretrained('bert-base-cased')
             >>> # By default, the model params will be in fp32, to cast these to float16
-            >>> model.params = model.to_f16(model.params)
+            >>> model.params = model.to_fp16(model.params)
             >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale)
             >>> # then pass the mask as follows
             >>> from flax import traverse_util
@@ -266,7 +266,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
             >>> flat_params = traverse_util.flatten_dict(model.params)
             >>> mask = {path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
             >>> mask = traverse_util.unflatten_dict(mask)
-            >>> model.params = model.to_f16(model.params, mask)
+            >>> model.params = model.to_fp16(model.params, mask)
         """
         return self._cast_floating_to(params, jnp.float16, mask)
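The docstring examples fixed here ultimately rely on the masked cast performed by self._cast_floating_to. The following is a hedged sketch of that idea under the assumption that it flattens the params, casts only the floating-point leaves whose mask entry is True, and re-nests the result; it is not the actual transformers implementation.

# Hedged sketch of the masked-cast idea behind to_fp16/to_bf16 (assumption only).
import jax.numpy as jnp
from flax import traverse_util

def cast_floating_to(params, dtype, mask=None):
    # Flatten params (and the mask, if given) into {path_tuple: leaf} dicts.
    flat_params = traverse_util.flatten_dict(params)
    flat_mask = (
        traverse_util.flatten_dict(mask)
        if mask is not None
        else {path: True for path in flat_params}
    )
    # Cast only the floating-point leaves that the mask selects.
    for path, leaf in flat_params.items():
        if flat_mask[path] and jnp.issubdtype(leaf.dtype, jnp.floating):
            flat_params[path] = leaf.astype(dtype)
    return traverse_util.unflatten_dict(flat_params)

# Hypothetical usage with the docstring's objects:
#   model.params = cast_floating_to(model.params, jnp.float16, mask)

After such a cast, jax.tree_util.tree_map(lambda x: x.dtype, model.params) should report float16 for every leaf the mask selected and fp32 for the excluded LayerNorm parameters.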