Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
bb4f816a
Unverified
Commit
bb4f816a
authored
Feb 29, 2024
by
NielsRogge
Committed by
GitHub
Feb 29, 2024
Browse files
Patch YOLOS and others (#29353)
Fix issue
parent
44fe1a1c
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
45 additions
and
36 deletions
+45
-36
src/transformers/models/conditional_detr/modeling_conditional_detr.py
...mers/models/conditional_detr/modeling_conditional_detr.py
+4
-3
src/transformers/models/deformable_detr/modeling_deformable_detr.py
...ormers/models/deformable_detr/modeling_deformable_detr.py
+4
-3
src/transformers/models/deta/modeling_deta.py
src/transformers/models/deta/modeling_deta.py
+4
-3
src/transformers/models/detr/modeling_detr.py
src/transformers/models/detr/modeling_detr.py
+4
-3
src/transformers/models/mask2former/modeling_mask2former.py
src/transformers/models/mask2former/modeling_mask2former.py
+7
-6
src/transformers/models/maskformer/modeling_maskformer.py
src/transformers/models/maskformer/modeling_maskformer.py
+7
-6
src/transformers/models/oneformer/modeling_oneformer.py
src/transformers/models/oneformer/modeling_oneformer.py
+7
-6
src/transformers/models/table_transformer/modeling_table_transformer.py
...rs/models/table_transformer/modeling_table_transformer.py
+4
-3
src/transformers/models/yolos/modeling_yolos.py
src/transformers/models/yolos/modeling_yolos.py
+4
-3
No files found.
src/transformers/models/conditional_detr/modeling_conditional_detr.py
View file @
bb4f816a
...
...
@@ -2514,9 +2514,10 @@ class ConditionalDetrLoss(nn.Module):
num_boxes
=
torch
.
as_tensor
([
num_boxes
],
dtype
=
torch
.
float
,
device
=
next
(
iter
(
outputs
.
values
())).
device
)
world_size
=
1
if
PartialState
.
_shared_state
!=
{}:
num_boxes
=
reduce
(
num_boxes
)
world_size
=
PartialState
().
num_processes
if
is_accelerate_available
():
if
PartialState
.
_shared_state
!=
{}:
num_boxes
=
reduce
(
num_boxes
)
world_size
=
PartialState
().
num_processes
num_boxes
=
torch
.
clamp
(
num_boxes
/
world_size
,
min
=
1
).
item
()
# Compute all the requested losses
...
...
src/transformers/models/deformable_detr/modeling_deformable_detr.py
View file @
bb4f816a
...
...
@@ -2282,9 +2282,10 @@ class DeformableDetrLoss(nn.Module):
num_boxes
=
sum
(
len
(
t
[
"class_labels"
])
for
t
in
targets
)
num_boxes
=
torch
.
as_tensor
([
num_boxes
],
dtype
=
torch
.
float
,
device
=
next
(
iter
(
outputs
.
values
())).
device
)
world_size
=
1
if
PartialState
.
_shared_state
!=
{}:
num_boxes
=
reduce
(
num_boxes
)
world_size
=
PartialState
().
num_processes
if
is_accelerate_available
():
if
PartialState
.
_shared_state
!=
{}:
num_boxes
=
reduce
(
num_boxes
)
world_size
=
PartialState
().
num_processes
num_boxes
=
torch
.
clamp
(
num_boxes
/
world_size
,
min
=
1
).
item
()
# Compute all the requested losses
...
...
src/transformers/models/deta/modeling_deta.py
View file @
bb4f816a
...
...
@@ -2345,9 +2345,10 @@ class DetaLoss(nn.Module):
num_boxes
=
torch
.
as_tensor
([
num_boxes
],
dtype
=
torch
.
float
,
device
=
next
(
iter
(
outputs
.
values
())).
device
)
# Check that we have initialized the distributed state
world_size
=
1
if
PartialState
.
_shared_state
!=
{}:
num_boxes
=
reduce
(
num_boxes
)
world_size
=
PartialState
().
num_processes
if
is_accelerate_available
():
if
PartialState
.
_shared_state
!=
{}:
num_boxes
=
reduce
(
num_boxes
)
world_size
=
PartialState
().
num_processes
num_boxes
=
torch
.
clamp
(
num_boxes
/
world_size
,
min
=
1
).
item
()
# Compute all the requested losses
...
...
src/transformers/models/detr/modeling_detr.py
View file @
bb4f816a
...
...
@@ -2210,9 +2210,10 @@ class DetrLoss(nn.Module):
num_boxes
=
sum
(
len
(
t
[
"class_labels"
])
for
t
in
targets
)
num_boxes
=
torch
.
as_tensor
([
num_boxes
],
dtype
=
torch
.
float
,
device
=
next
(
iter
(
outputs
.
values
())).
device
)
world_size
=
1
if
PartialState
.
_shared_state
!=
{}:
num_boxes
=
reduce
(
num_boxes
)
world_size
=
PartialState
().
num_processes
if
is_accelerate_available
():
if
PartialState
.
_shared_state
!=
{}:
num_boxes
=
reduce
(
num_boxes
)
world_size
=
PartialState
().
num_processes
num_boxes
=
torch
.
clamp
(
num_boxes
/
world_size
,
min
=
1
).
item
()
# Compute all the requested losses
...
...
src/transformers/models/mask2former/modeling_mask2former.py
View file @
bb4f816a
...
...
@@ -791,14 +791,15 @@ class Mask2FormerLoss(nn.Module):
Computes the average number of target masks across the batch, for normalization purposes.
"""
num_masks
=
sum
([
len
(
classes
)
for
classes
in
class_labels
])
num_masks
_pt
=
torch
.
as_tensor
(
num_masks
,
dtype
=
torch
.
float
,
device
=
device
)
num_masks
=
torch
.
as_tensor
(
num_masks
,
dtype
=
torch
.
float
,
device
=
device
)
world_size
=
1
if
PartialState
.
_shared_state
!=
{}:
num_masks_pt
=
reduce
(
num_masks_pt
)
world_size
=
PartialState
().
num_processes
if
is_accelerate_available
():
if
PartialState
.
_shared_state
!=
{}:
num_masks
=
reduce
(
num_masks
)
world_size
=
PartialState
().
num_processes
num_masks
_pt
=
torch
.
clamp
(
num_masks
_pt
/
world_size
,
min
=
1
)
return
num_masks
_pt
num_masks
=
torch
.
clamp
(
num_masks
/
world_size
,
min
=
1
)
return
num_masks
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
...
...
src/transformers/models/maskformer/modeling_maskformer.py
View file @
bb4f816a
...
...
@@ -1198,14 +1198,15 @@ class MaskFormerLoss(nn.Module):
Computes the average number of target masks across the batch, for normalization purposes.
"""
num_masks
=
sum
([
len
(
classes
)
for
classes
in
class_labels
])
num_masks
_pt
=
torch
.
as_tensor
(
num_masks
,
dtype
=
torch
.
float
,
device
=
device
)
num_masks
=
torch
.
as_tensor
(
num_masks
,
dtype
=
torch
.
float
,
device
=
device
)
world_size
=
1
if
PartialState
.
_shared_state
!=
{}:
num_masks_pt
=
reduce
(
num_masks_pt
)
world_size
=
PartialState
().
num_processes
if
is_accelerate_available
():
if
PartialState
.
_shared_state
!=
{}:
num_masks
=
reduce
(
num_masks
)
world_size
=
PartialState
().
num_processes
num_masks
_pt
=
torch
.
clamp
(
num_masks
_pt
/
world_size
,
min
=
1
)
return
num_masks
_pt
num_masks
=
torch
.
clamp
(
num_masks
/
world_size
,
min
=
1
)
return
num_masks
class
MaskFormerFPNConvLayer
(
nn
.
Module
):
...
...
src/transformers/models/oneformer/modeling_oneformer.py
View file @
bb4f816a
...
...
@@ -727,14 +727,15 @@ class OneFormerLoss(nn.Module):
Computes the average number of target masks across the batch, for normalization purposes.
"""
num_masks
=
sum
([
len
(
classes
)
for
classes
in
class_labels
])
num_masks
_pt
=
torch
.
as_tensor
([
num_masks
],
dtype
=
torch
.
float
,
device
=
device
)
num_masks
=
torch
.
as_tensor
([
num_masks
],
dtype
=
torch
.
float
,
device
=
device
)
world_size
=
1
if
PartialState
.
_shared_state
!=
{}:
num_masks_pt
=
reduce
(
num_masks_pt
)
world_size
=
PartialState
().
num_processes
if
is_accelerate_available
():
if
PartialState
.
_shared_state
!=
{}:
num_masks
=
reduce
(
num_masks
)
world_size
=
PartialState
().
num_processes
num_masks
_pt
=
torch
.
clamp
(
num_masks
_pt
/
world_size
,
min
=
1
)
return
num_masks
_pt
num_masks
=
torch
.
clamp
(
num_masks
/
world_size
,
min
=
1
)
return
num_masks
@
dataclass
...
...
src/transformers/models/table_transformer/modeling_table_transformer.py
View file @
bb4f816a
...
...
@@ -1757,9 +1757,10 @@ class TableTransformerLoss(nn.Module):
num_boxes
=
sum
(
len
(
t
[
"class_labels"
])
for
t
in
targets
)
num_boxes
=
torch
.
as_tensor
([
num_boxes
],
dtype
=
torch
.
float
,
device
=
next
(
iter
(
outputs
.
values
())).
device
)
world_size
=
1
if
PartialState
.
_shared_state
!=
{}:
num_boxes
=
reduce
(
num_boxes
)
world_size
=
PartialState
().
num_processes
if
is_accelerate_available
():
if
PartialState
.
_shared_state
!=
{}:
num_boxes
=
reduce
(
num_boxes
)
world_size
=
PartialState
().
num_processes
num_boxes
=
torch
.
clamp
(
num_boxes
/
world_size
,
min
=
1
).
item
()
# Compute all the requested losses
...
...
src/transformers/models/yolos/modeling_yolos.py
View file @
bb4f816a
...
...
@@ -1079,9 +1079,10 @@ class YolosLoss(nn.Module):
num_boxes
=
sum
(
len
(
t
[
"class_labels"
])
for
t
in
targets
)
num_boxes
=
torch
.
as_tensor
([
num_boxes
],
dtype
=
torch
.
float
,
device
=
next
(
iter
(
outputs
.
values
())).
device
)
world_size
=
1
if
PartialState
.
_shared_state
!=
{}:
num_boxes
=
reduce
(
num_boxes
)
world_size
=
PartialState
().
num_processes
if
is_accelerate_available
():
if
PartialState
.
_shared_state
!=
{}:
num_boxes
=
reduce
(
num_boxes
)
world_size
=
PartialState
().
num_processes
num_boxes
=
torch
.
clamp
(
num_boxes
/
world_size
,
min
=
1
).
item
()
# Compute all the requested losses
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment