"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "b424f0b4a301abcbf3c282114159371ee44c3e01"
Unverified Commit 29c74f58 authored by Marc Sun's avatar Marc Sun Committed by GitHub
Browse files

fix detr device map (#27089)

* fix detr device map

* add comments
parent ffff9e70
...@@ -1823,6 +1823,8 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel): ...@@ -1823,6 +1823,8 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel): class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
# When using clones, all layers > 0 will be clones, but layer 0 *is* required # When using clones, all layers > 0 will be clones, but layer 0 *is* required
_tied_weights_keys = [r"bbox_embed\.[1-9]\d*", r"class_embed\.[1-9]\d*"] _tied_weights_keys = [r"bbox_embed\.[1-9]\d*", r"class_embed\.[1-9]\d*"]
# We can't initialize the model on meta device as some weights are modified during the initialization
_no_split_modules = None
def __init__(self, config: DeformableDetrConfig): def __init__(self, config: DeformableDetrConfig):
super().__init__(config) super().__init__(config)
......
...@@ -1775,6 +1775,8 @@ class DetaModel(DetaPreTrainedModel): ...@@ -1775,6 +1775,8 @@ class DetaModel(DetaPreTrainedModel):
class DetaForObjectDetection(DetaPreTrainedModel): class DetaForObjectDetection(DetaPreTrainedModel):
# When using clones, all layers > 0 will be clones, but layer 0 *is* required # When using clones, all layers > 0 will be clones, but layer 0 *is* required
_tied_weights_keys = [r"bbox_embed\.\d+"] _tied_weights_keys = [r"bbox_embed\.\d+"]
# We can't initialize the model on meta device as some weights are modified during the initialization
_no_split_modules = None
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrForObjectDetection.__init__ with DeformableDetr->Deta # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrForObjectDetection.__init__ with DeformableDetr->Deta
def __init__(self, config: DetaConfig): def __init__(self, config: DetaConfig):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment