chenpangpang / transformers · Commits

Commit 4b919657 (unverified)
Authored Feb 17, 2021 by Lysandre Debut; committed by GitHub on Feb 17, 2021

Factor out methods (#10215)
Parent: e94d63f6

Showing 1 changed file, with 34 additions and 26 deletions:

src/transformers/modeling_utils.py (+34, -26)
```diff
@@ -86,6 +86,36 @@ def find_pruneable_heads_and_indices(
     return heads, index
 
 
+def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
+    try:
+        return next(parameter.parameters()).device
+    except StopIteration:
+        # For nn.DataParallel compatibility in PyTorch 1.5
+
+        def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
+            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+            return tuples
+
+        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+        first_tuple = next(gen)
+        return first_tuple[1].device
+
+
+def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
+    try:
+        return next(parameter.parameters()).dtype
+    except StopIteration:
+        # For nn.DataParallel compatibility in PyTorch 1.5
+
+        def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
+            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+            return tuples
+
+        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+        first_tuple = next(gen)
+        return first_tuple[1].dtype
+
+
 class ModuleUtilsMixin:
     """
     A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin.
```
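The `except StopIteration` branch exists because, on `nn.DataParallel` replicas in PyTorch 1.5, `parameters()` can come back empty: replication turns the parameters into plain tensor attributes, so the device/dtype has to be read from a tensor found in the module's `__dict__`. Below is a standalone sketch of that fallback; the `Replica` class and the `sketch_get_device` name are illustrative, not part of the library:

```python
from typing import List, Tuple

import torch
from torch import Tensor, nn


def sketch_get_device(module: nn.Module) -> torch.device:
    """Mirrors the fallback in the diff above (illustrative, not the library API)."""
    try:
        return next(module.parameters()).device
    except StopIteration:
        # DataParallel replicas hold weights as plain tensor attributes rather
        # than nn.Parameters, so scan each submodule's __dict__ for tensors.
        def find_tensor_attributes(m: nn.Module) -> List[Tuple[str, Tensor]]:
            return [(k, v) for k, v in m.__dict__.items() if torch.is_tensor(v)]

        gen = module._named_members(get_members_fn=find_tensor_attributes)
        return next(gen)[1].device


class Replica(nn.Module):
    """Mimics an nn.DataParallel replica: a tensor attribute, no Parameters."""

    def __init__(self):
        super().__init__()
        self.weight = torch.zeros(2, 2)  # plain tensor, not nn.Parameter


print(sketch_get_device(Replica()))  # device(type='cpu')
```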
```diff
@@ -145,36 +175,14 @@ class ModuleUtilsMixin:
         :obj:`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
         device).
         """
-        try:
-            return next(self.parameters()).device
-        except StopIteration:
-            # For nn.DataParallel compatibility in PyTorch 1.5
-
-            def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
-                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
-                return tuples
-
-            gen = self._named_members(get_members_fn=find_tensor_attributes)
-            first_tuple = next(gen)
-            return first_tuple[1].device
+        return get_parameter_device(self)
 
     @property
     def dtype(self) -> dtype:
         """
         :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
         """
-        try:
-            return next(self.parameters()).dtype
-        except StopIteration:
-            # For nn.DataParallel compatibility in PyTorch 1.5
-
-            def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
-                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
-                return tuples
-
-            gen = self._named_members(get_members_fn=find_tensor_attributes)
-            first_tuple = next(gen)
-            return first_tuple[1].dtype
+        return get_parameter_dtype(self)
 
     def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
         """
```
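With the helpers factored out, the mixin's `device` and `dtype` properties become one-line delegations, so existing callers see identical behavior. A quick equivalence sketch, assuming a transformers version that includes this commit; the randomly initialized BERT model is only illustrative:

```python
from transformers import BertConfig, BertModel
from transformers.modeling_utils import get_parameter_device, get_parameter_dtype

model = BertModel(BertConfig())  # randomly initialized; no download required

# ModuleUtilsMixin.device / .dtype now delegate to the module-level helpers,
# so both spellings agree.
assert model.device == get_parameter_device(model)  # e.g. device(type='cpu')
assert model.dtype == get_parameter_dtype(model)    # e.g. torch.float32
```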
```diff
@@ -1238,7 +1246,7 @@ class PoolerStartLogits(nn.Module):
         x = self.dense(hidden_states).squeeze(-1)
 
         if p_mask is not None:
-            if next(self.parameters()).dtype == torch.float16:
+            if get_parameter_dtype(self) == torch.float16:
                 x = x * (1 - p_mask) - 65500 * p_mask
             else:
                 x = x * (1 - p_mask) - 1e30 * p_mask
```
```diff
@@ -1305,7 +1313,7 @@ class PoolerEndLogits(nn.Module):
         x = self.dense_1(x).squeeze(-1)
 
         if p_mask is not None:
-            if next(self.parameters()).dtype == torch.float16:
+            if get_parameter_dtype(self) == torch.float16:
                 x = x * (1 - p_mask) - 65500 * p_mask
             else:
                 x = x * (1 - p_mask) - 1e30 * p_mask
```
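In both `PoolerStartLogits` and `PoolerEndLogits`, masked positions are pushed to a large negative logit, and the constant has to fit the parameter dtype: float16's largest finite value is 65504, so the float32 constant 1e30 overflows to inf (and an inf bias can surface as NaN in later arithmetic), while 65500 stays finite. A short sketch of the numeric limits behind the branch:

```python
# float16 cannot represent the float32 masking constant 1e30, so it overflows
# to inf; 65500 rounds to 65504, the largest finite float16 value.
import torch

print(torch.finfo(torch.float16).max)               # 65504.0
print(torch.tensor(1e30, dtype=torch.float16))      # tensor(inf, dtype=torch.float16)
print(torch.tensor(-65500.0, dtype=torch.float16))  # tensor(-65504., dtype=torch.float16)
```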