Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
d4575e5b
"docs/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "133d00654cd52a43158e235c2003ed7f3322d70d"
Unverified
Commit
d4575e5b
authored
Feb 14, 2023
by
Nicolas Hug
Committed by
GitHub
Feb 14, 2023
Browse files
Let LinearTransformation return datapoints instead of tensors (#7244)
parent
3a0e028f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
11 deletions
+10
-11
torchvision/prototype/transforms/_misc.py
torchvision/prototype/transforms/_misc.py
+10
-11
No files found.
torchvision/prototype/transforms/_misc.py
View file @
d4575e5b
...
@@ -76,12 +76,7 @@ class LinearTransformation(Transform):
...
@@ -76,12 +76,7 @@ class LinearTransformation(Transform):
if
has_any
(
sample
,
PIL
.
Image
.
Image
):
if
has_any
(
sample
,
PIL
.
Image
.
Image
):
raise
TypeError
(
"LinearTransformation does not work on PIL Images"
)
raise
TypeError
(
"LinearTransformation does not work on PIL Images"
)
def
_transform
(
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
self
,
inpt
:
Union
[
datapoints
.
TensorImageType
,
datapoints
.
TensorVideoType
],
params
:
Dict
[
str
,
Any
]
)
->
torch
.
Tensor
:
# Image instance after linear transformation is not Image anymore due to unknown data range
# Thus we will return Tensor for input Image
shape
=
inpt
.
shape
shape
=
inpt
.
shape
n
=
shape
[
-
3
]
*
shape
[
-
2
]
*
shape
[
-
1
]
n
=
shape
[
-
3
]
*
shape
[
-
2
]
*
shape
[
-
1
]
if
n
!=
self
.
transformation_matrix
.
shape
[
0
]:
if
n
!=
self
.
transformation_matrix
.
shape
[
0
]:
...
@@ -97,11 +92,15 @@ class LinearTransformation(Transform):
...
@@ -97,11 +92,15 @@ class LinearTransformation(Transform):
f
"Got
{
inpt
.
device
}
vs
{
self
.
mean_vector
.
device
}
"
f
"Got
{
inpt
.
device
}
vs
{
self
.
mean_vector
.
device
}
"
)
)
flat_tensor
=
inpt
.
reshape
(
-
1
,
n
)
-
self
.
mean_vector
flat_inpt
=
inpt
.
reshape
(
-
1
,
n
)
-
self
.
mean_vector
transformation_matrix
=
self
.
transformation_matrix
.
to
(
flat_inpt
.
dtype
)
output
=
torch
.
mm
(
flat_inpt
,
transformation_matrix
)
output
=
output
.
reshape
(
shape
)
transformation_matrix
=
self
.
transformation_matrix
.
to
(
flat_tensor
.
dtype
)
if
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
transformed_tensor
=
torch
.
mm
(
flat_tensor
,
transformation_matrix
)
output
=
type
(
inpt
).
wrap_like
(
inpt
,
output
)
# type: ignore[arg-type]
return
transformed_tensor
.
reshape
(
shape
)
return
output
class
Normalize
(
Transform
):
class
Normalize
(
Transform
):
...
@@ -120,7 +119,7 @@ class Normalize(Transform):
...
@@ -120,7 +119,7 @@ class Normalize(Transform):
def
_transform
(
def
_transform
(
self
,
inpt
:
Union
[
datapoints
.
TensorImageType
,
datapoints
.
TensorVideoType
],
params
:
Dict
[
str
,
Any
]
self
,
inpt
:
Union
[
datapoints
.
TensorImageType
,
datapoints
.
TensorVideoType
],
params
:
Dict
[
str
,
Any
]
)
->
torch
.
Tensor
:
)
->
Any
:
return
F
.
normalize
(
inpt
,
mean
=
self
.
mean
,
std
=
self
.
std
,
inplace
=
self
.
inplace
)
return
F
.
normalize
(
inpt
,
mean
=
self
.
mean
,
std
=
self
.
std
,
inplace
=
self
.
inplace
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment