Unverified Commit 29c74f58 authored by Marc Sun's avatar Marc Sun Committed by GitHub
Browse files

fix detr device map (#27089)

* fix detr device map

* add comments
parent ffff9e70
......@@ -1823,6 +1823,8 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
# When using clones, all layers > 0 will be clones, but layer 0 *is* required
_tied_weights_keys = [r"bbox_embed\.[1-9]\d*", r"class_embed\.[1-9]\d*"]
# We can't initialize the model on meta device as some weights are modified during the initialization
_no_split_modules = None
def __init__(self, config: DeformableDetrConfig):
super().__init__(config)
......
......@@ -1775,6 +1775,8 @@ class DetaModel(DetaPreTrainedModel):
class DetaForObjectDetection(DetaPreTrainedModel):
# When using clones, all layers > 0 will be clones, but layer 0 *is* required
_tied_weights_keys = [r"bbox_embed\.\d+"]
# We can't initialize the model on meta device as some weights are modified during the initialization
_no_split_modules = None
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrForObjectDetection.__init__ with DeformableDetr->Deta
def __init__(self, config: DetaConfig):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment